Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e3f6ba3fc1 | |||
| f442955c6e | |||
| ff9a387618 | |||
| fb8722e9ab | |||
| 512044c5ea | |||
| 03c3f69aa5 | |||
| f20aba3e87 | |||
| 6c85fcd899 | |||
| 085e9cba36 | |||
| 919ee1aee3 | |||
| 9cda45701c | |||
| c678e8a445 | |||
| ccf2c31188 | |||
| d1342d7464 | |||
| 7b10e4ae65 | |||
| 3c0531bc50 | |||
| a8e47978c6 | |||
| 50e18ee698 | |||
| 4b17fa2a2e | |||
| d45199a2f1 | |||
| 061163142d | |||
| 5780776c8a | |||
| f19421e27c | |||
| 9a0cc463ee | |||
| ef4e373a65 | |||
| 1b4af6b7ef | |||
| 69cdc25746 | |||
| cfd6ec7465 | |||
| ea77fdc4b4 | |||
| 1082c46afa | |||
| ba2ba9019f | |||
| fa4c0e5e2e | |||
| b793debd9d | |||
| 377057126c | |||
| 5937e11d85 | |||
| 9c1d4e3be1 | |||
| 7ea065c507 | |||
| 7a7a487396 | |||
| 4efb4db9d0 | |||
| 639fd12a20 | |||
| 69a9828f4d | |||
| 11d22e0e80 | |||
| 9a38fab5ae | |||
| cb8e61ed2f | |||
| 8e53cd959e | |||
| 359b605f4b | |||
| 6febc08bfc | |||
| 9a2eaed002 | |||
| 0c71189abe | |||
| 58d2b10a2e | |||
| 255c5742aa | |||
| 4524d43279 | |||
| b6dc0b75f4 | |||
| 966a2ff8df | |||
| 201da97dd0 | |||
| 4423097b23 | |||
| 60d1b81023 |
@@ -25,7 +25,7 @@ jobs:
|
||||
group: aws-g6e-4xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
@@ -61,7 +61,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -178,7 +178,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -222,7 +222,7 @@ jobs:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -270,7 +270,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-minimum-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -333,7 +333,7 @@ jobs:
|
||||
additional_deps: ["peft"]
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
additional_deps: ["peft"]
|
||||
additional_deps: ["peft", "kernels"]
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
@@ -344,7 +344,7 @@ jobs:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "20gb" --ipc host --gpus 0
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -396,7 +396,7 @@ jobs:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "20gb" --ipc host --gpus 0
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "src/diffusers/modular_pipelines/**.py"
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/modular_pipelines/**.py"
|
||||
- ".github/**.yml"
|
||||
- "utils/**.py"
|
||||
- "setup.py"
|
||||
push:
|
||||
branches:
|
||||
- ci-*
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -183,7 +183,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -253,7 +253,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -210,7 +210,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -252,7 +252,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -163,7 +163,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-minimum-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -222,7 +222,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -265,7 +265,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -307,7 +307,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
|
||||
steps:
|
||||
- name: Validate test files input
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
group: "${{ github.event.inputs.runner_type }}"
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0 --privileged
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus all --privileged
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
|
||||
@@ -366,6 +366,8 @@
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/qwenimage_transformer2d
|
||||
title: QwenImageTransformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
@@ -418,6 +420,8 @@
|
||||
title: AutoencoderKLMagvit
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/autoencoderkl_qwenimage
|
||||
title: AutoencoderKLQwenImage
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/consistency_decoder_vae
|
||||
@@ -554,6 +558,8 @@
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
|
||||
@@ -30,6 +30,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
<Tip>
|
||||
@@ -105,6 +106,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
|
||||
|
||||
## QwenImageLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
@@ -0,0 +1,35 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLQwenImage
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae")
|
||||
```
|
||||
|
||||
## AutoencoderKLQwenImage
|
||||
|
||||
[[autodoc]] AutoencoderKLQwenImage
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,28 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# QwenImageTransformer2DModel
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/QwenImage-20B", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## QwenImageTransformer2DModel
|
||||
|
||||
[[autodoc]] QwenImageTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,92 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# QwenImage
|
||||
|
||||
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
|
||||
|
||||
Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## LoRA for faster inference
|
||||
|
||||
Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
|
||||
number of steps. Refer to the code snippet below:
|
||||
|
||||
<details>
|
||||
<summary>Code</summary>
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
|
||||
import torch
|
||||
import math
|
||||
|
||||
ckpt_id = "Qwen/Qwen-Image"
|
||||
|
||||
# From
|
||||
# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
|
||||
scheduler_config = {
|
||||
"base_image_seq_len": 256,
|
||||
"base_shift": math.log(3), # We use shift=3 in distillation
|
||||
"invert_sigmas": False,
|
||||
"max_image_seq_len": 8192,
|
||||
"max_shift": math.log(3), # We use shift=3 in distillation
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 1.0,
|
||||
"shift_terminal": None, # set shift_terminal to None
|
||||
"stochastic_sampling": False,
|
||||
"time_shift_type": "exponential",
|
||||
"use_beta_sigmas": False,
|
||||
"use_dynamic_shifting": True,
|
||||
"use_exponential_sigmas": False,
|
||||
"use_karras_sigmas": False,
|
||||
}
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
pipe.load_lora_weights(
|
||||
"lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
|
||||
negative_prompt = " "
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_inference_steps=8,
|
||||
true_cfg_scale=1.0,
|
||||
generator=torch.manual_seed(0),
|
||||
).images[0]
|
||||
image.save("qwen_fewsteps.png")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
|
||||
@@ -29,6 +29,7 @@
|
||||
You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.
|
||||
|
||||
The following Wan models are supported in Diffusers:
|
||||
|
||||
- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
|
||||
- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
|
||||
- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
|
||||
@@ -36,6 +37,9 @@ The following Wan models are supported in Diffusers:
|
||||
- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
|
||||
- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)
|
||||
- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)
|
||||
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
|
||||
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
|
||||
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Wan2.1 models in the right sidebar for more examples of video generation.
|
||||
@@ -327,6 +331,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
|
||||
- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images.
|
||||
|
||||
- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
|
||||
+75
-102
@@ -12,183 +12,156 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Installation
|
||||
|
||||
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
Diffusers is tested on Python 3.8+, PyTorch 1.4+, and Flax 0.4.1+. Follow the installation instructions for the deep learning library you're using, [PyTorch](https://pytorch.org/get-started/locally/) or [Flax](https://flax.readthedocs.io/en/latest/).
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions
|
||||
|
||||
## Install with pip
|
||||
|
||||
You should install 🤗 Diffusers in a [virtual environment](https://docs.python.org/3/library/venv.html).
|
||||
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
|
||||
|
||||
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
Create a [virtual environment](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) for easier management of separate projects and to avoid compatibility issues between dependencies. Use [uv](https://docs.astral.sh/uv/), a Rust-based Python package and project manager, to create a virtual environment and install Diffusers.
|
||||
|
||||
```bash
|
||||
uv venv my-env
|
||||
source my-env/bin/activate
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Python">
|
||||
Install Diffusers with one of the following methods.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="pip">
|
||||
|
||||
PyTorch only supports Python 3.8 - 3.11 on Windows.
|
||||
|
||||
```bash
|
||||
python -m venv my-env
|
||||
source my-env/bin/activate
|
||||
uv pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
|
||||
|
||||
```bash
|
||||
uv install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
</pt>
|
||||
<jax>
|
||||
|
||||
Install Diffusers with uv.
|
||||
Use the command below for Flax.
|
||||
|
||||
```bash
|
||||
uv pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
## Install with conda
|
||||
|
||||
After activating your virtual environment, with `conda` (maintained by the community):
|
||||
</hfoption>
|
||||
<hfoption id="conda">
|
||||
|
||||
```bash
|
||||
conda install -c conda-forge diffusers
|
||||
```
|
||||
|
||||
## Install from source
|
||||
</hfoption>
|
||||
<hfoption id="source">
|
||||
|
||||
Before installing 🤗 Diffusers from source, make sure you have PyTorch and 🤗 Accelerate installed.
|
||||
A source install installs the `main` version instead of the latest `stable` version. The `main` version is useful for staying updated with the latest changes but it may not always be stable. If you run into a problem, open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and we will try to resolve it as soon as possible.
|
||||
|
||||
To install 🤗 Accelerate:
|
||||
Make sure [Accelerate](https://huggingface.co/docs/accelerate/index) is installed.
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
uv pip install accelerate
|
||||
```
|
||||
|
||||
Then install 🤗 Diffusers from source:
|
||||
Install Diffusers from source with the command below.
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/diffusers
|
||||
uv pip install git+https://github.com/huggingface/diffusers
|
||||
```
|
||||
|
||||
This command installs the bleeding edge `main` version rather than the latest `stable` version.
|
||||
The `main` version is useful for staying up-to-date with the latest developments.
|
||||
For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet.
|
||||
However, this means the `main` version may not always be stable.
|
||||
We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day.
|
||||
If you run into a problem, please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can fix it even sooner!
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Editable install
|
||||
|
||||
You will need an editable install if you'd like to:
|
||||
An editable install is recommended for development workflows or if you're using the `main` version of the source code. A special link is created between the cloned repository and the Python library paths. This avoids reinstalling a package after every change.
|
||||
|
||||
* Use the `main` version of the source code.
|
||||
* Contribute to 🤗 Diffusers and need to test changes in the code.
|
||||
Clone the repository and install Diffusers with the following commands.
|
||||
|
||||
Clone the repository and install 🤗 Diffusers with the following commands:
|
||||
<hfoptions id="editable">
|
||||
<hfoption id="PyTorch">
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
uv pip install -e ".[torch]"
|
||||
```
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
</hfoption>
|
||||
<hfoption id="Flax">
|
||||
|
||||
```bash
|
||||
pip install -e ".[torch]"
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
uv pip install -e ".[flax]"
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
```bash
|
||||
pip install -e ".[flax]"
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||
Python will now look inside the folder you cloned to in addition to the normal library paths.
|
||||
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.10/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip warning={true}>
|
||||
> [!WARNING]
|
||||
> You must keep the `diffusers` folder if you want to keep using the library with the editable install.
|
||||
|
||||
You must keep the `diffusers` folder if you want to keep using the library.
|
||||
|
||||
</Tip>
|
||||
|
||||
Now you can easily update your clone to the latest version of 🤗 Diffusers with the following command:
|
||||
Update your cloned repository to the latest version of Diffusers with the command below.
|
||||
|
||||
```bash
|
||||
cd ~/diffusers/
|
||||
git pull
|
||||
```
|
||||
|
||||
Your Python environment will find the `main` version of 🤗 Diffusers on the next run.
|
||||
|
||||
## Cache
|
||||
|
||||
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
|
||||
Model weights and files are downloaded from the Hub to a cache, which is usually your home directory. Change the cache location with the [HF_HOME](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome) or [HF_HUB_CACHE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache) environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
|
||||
<hfoptions id="cache">
|
||||
<hfoption id="env variable">
|
||||
|
||||
```bash
|
||||
export HF_HOME="/path/to/your/cache"
|
||||
export HF_HUB_CACHE="/path/to/your/hub/cache"
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="from_pretrained">
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
cache_dir="/path/to/your/cache"
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Cached files allow you to use Diffusers offline. Set the [HF_HUB_OFFLINE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhuboffline) environment variable to `1` to prevent Diffusers from connecting to the internet.
|
||||
|
||||
```shell
|
||||
export HF_HUB_OFFLINE=1
|
||||
```
|
||||
|
||||
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
|
||||
For more details about managing and cleaning the cache, take a look at the [Understand caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
|
||||
|
||||
## Telemetry logging
|
||||
|
||||
Our library gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
|
||||
The data gathered includes the version of 🤗 Diffusers and PyTorch/Flax, the requested model or pipeline class,
|
||||
and the path to a pretrained checkpoint if it is hosted on the Hugging Face Hub.
|
||||
Diffusers gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
|
||||
The data gathered includes the Diffusers and PyTorch/Flax version, the requested model or pipeline class,
|
||||
and the path to a pretrained checkpoint if it is hosted on the Hub.
|
||||
|
||||
This usage data helps us debug issues and prioritize new features.
|
||||
Telemetry is only sent when loading models and pipelines from the Hub,
|
||||
and it is not collected if you're loading local files.
|
||||
|
||||
We understand that not everyone wants to share additional information,and we respect your privacy.
|
||||
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
|
||||
Opt-out and disable telemetry collection with the [HF_HUB_DISABLE_TELEMETRY](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubdisabletelemetry) environment variable.
|
||||
|
||||
On Linux/MacOS:
|
||||
<hfoptions id="telemetry">
|
||||
<hfoption id="Linux/macOS">
|
||||
|
||||
```bash
|
||||
export HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
On Windows:
|
||||
</hfoption>
|
||||
<hfoption id="Windows">
|
||||
|
||||
```bash
|
||||
set HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
## Using Optimized CUDA Kernels with GGUF
|
||||
|
||||
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
|
||||
|
||||
```shell
|
||||
pip install -U kernels
|
||||
```
|
||||
|
||||
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
- BF16
|
||||
@@ -67,3 +77,44 @@ image.save("flux-gguf.png")
|
||||
- Q5_K
|
||||
- Q6_K
|
||||
|
||||
## Convert to GGUF
|
||||
|
||||
Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
|
||||
run conversion:
|
||||
|
||||
<iframe
|
||||
src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
config="black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
When using Diffusers format GGUF checkpoints, it's a must to provide the model `config` path. If the
|
||||
model config resides in a `subfolder`, that needs to be specified, too.
|
||||
@@ -0,0 +1,136 @@
|
||||
# DreamBooth training example for Qwen Image
|
||||
|
||||
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_lora_qwen_image.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Qwen Image](https://huggingface.co/Qwen/Qwen-Image).
|
||||
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sana.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="Qwen/Qwen-Image"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-sana-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sana.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=2e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
For using `push_to_hub`, make you're logged into your Hugging Face account:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
## Notes
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
|
||||
|
||||
We provide several options for optimizing memory optimization:
|
||||
|
||||
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage) of the `QwenImagePipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
|
||||
|
||||
## Using quantization
|
||||
|
||||
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
|
||||
|
||||
```json
|
||||
{
|
||||
"load_in_4bit": true,
|
||||
"bnb_4bit_quant_type": "nf4"
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,248 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
|
||||
transformer_layer_type = "transformer_blocks.0.attn.to_k"
|
||||
|
||||
def test_dreambooth_lora_qwen(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -116,7 +116,7 @@ _deps = [
|
||||
"librosa",
|
||||
"numpy",
|
||||
"parameterized",
|
||||
"peft>=0.15.0",
|
||||
"peft>=0.17.0",
|
||||
"protobuf>=3.20.3,<4",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
|
||||
@@ -139,6 +139,7 @@ else:
|
||||
"AutoGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"FrequencyDecoupledGuidance",
|
||||
"PerturbedAttentionGuidance",
|
||||
"SkipLayerGuidance",
|
||||
"SmoothedEnergyGuidance",
|
||||
@@ -174,6 +175,7 @@ else:
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLQwenImage",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
@@ -215,6 +217,7 @@ else:
|
||||
"OmniGenTransformer2DModel",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"QwenImageTransformer2DModel",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SD3ControlNetModel",
|
||||
@@ -486,6 +489,7 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"QwenImagePipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaPAGPipeline",
|
||||
@@ -801,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
@@ -832,6 +837,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLQwenImage,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
@@ -873,6 +879,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3ControlNetModel,
|
||||
@@ -1119,6 +1126,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
QwenImagePipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaControlNetPipeline,
|
||||
SanaPAGPipeline,
|
||||
|
||||
@@ -23,7 +23,7 @@ deps = {
|
||||
"librosa": "librosa",
|
||||
"numpy": "numpy",
|
||||
"parameterized": "parameterized",
|
||||
"peft": "peft>=0.15.0",
|
||||
"peft": "peft>=0.17.0",
|
||||
"protobuf": "protobuf>=3.20.3,<4",
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
|
||||
@@ -22,6 +22,7 @@ if is_torch_available():
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
@@ -32,6 +33,7 @@ if is_torch_available():
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..utils import is_kornia_available
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
_CAN_USE_KORNIA = is_kornia_available()
|
||||
|
||||
|
||||
if _CAN_USE_KORNIA:
|
||||
from kornia.geometry import pyrup as upsample_and_blur_func
|
||||
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
|
||||
else:
|
||||
upsample_and_blur_func = None
|
||||
build_laplacian_pyramid_func = None
|
||||
|
||||
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
|
||||
(Algorithm 2).
|
||||
"""
|
||||
# v0 shape: [B, ...]
|
||||
# v1 shape: [B, ...]
|
||||
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
|
||||
all_dims_but_first = list(range(1, len(v0.shape)))
|
||||
if upcast_to_double:
|
||||
dtype = v0.dtype
|
||||
v0, v1 = v0.double(), v1.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
|
||||
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
if upcast_to_double:
|
||||
v0_parallel = v0_parallel.to(dtype)
|
||||
v0_orthogonal = v0_orthogonal.to(dtype)
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
|
||||
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
|
||||
(Algorihtm 2).
|
||||
"""
|
||||
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
|
||||
img = pyramid[-1]
|
||||
for i in range(len(pyramid) - 2, -1, -1):
|
||||
img = upsample_and_blur_func(img) + pyramid[i]
|
||||
return img
|
||||
|
||||
|
||||
class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
"""
|
||||
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
|
||||
|
||||
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
|
||||
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
|
||||
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
|
||||
how CFG works, you can check out the CFG guider.)
|
||||
|
||||
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
|
||||
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
|
||||
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
|
||||
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
|
||||
to form the final FDG prediction.
|
||||
|
||||
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
|
||||
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
|
||||
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
|
||||
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
|
||||
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
|
||||
|
||||
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
|
||||
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
|
||||
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
|
||||
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
|
||||
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
|
||||
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
|
||||
descending order).
|
||||
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
|
||||
`guidance_scales`.
|
||||
parallel_weights (`float` or `List[float]`, *optional*):
|
||||
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
|
||||
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
|
||||
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
|
||||
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float` or `List[float]`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
stop (`float` or `List[float]`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
guidance_rescale_space (`str`, defaults to `"data"`):
|
||||
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
|
||||
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
|
||||
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
|
||||
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
|
||||
upcast_to_double (`bool`, defaults to `True`):
|
||||
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
|
||||
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
|
||||
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
|
||||
use_original_formulation: bool = False,
|
||||
start: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
):
|
||||
if not _CAN_USE_KORNIA:
|
||||
raise ImportError(
|
||||
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
|
||||
"it depends is not available in the current environment. You can install `kornia` with `pip install "
|
||||
"kornia`."
|
||||
)
|
||||
|
||||
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
||||
min_start = start if isinstance(start, float) else min(start)
|
||||
max_stop = stop if isinstance(stop, float) else max(stop)
|
||||
super().__init__(min_start, max_stop)
|
||||
|
||||
self.guidance_scales = guidance_scales
|
||||
self.levels = len(guidance_scales)
|
||||
|
||||
if isinstance(guidance_rescale, float):
|
||||
self.guidance_rescale = [guidance_rescale] * self.levels
|
||||
elif len(guidance_rescale) == self.levels:
|
||||
self.guidance_rescale = guidance_rescale
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
|
||||
f"`guidance_scales` ({len(self.guidance_scales)})"
|
||||
)
|
||||
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
|
||||
# transforming from frequency space back to data space)
|
||||
if guidance_rescale_space not in ["data", "freq"]:
|
||||
raise ValueError(
|
||||
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
|
||||
)
|
||||
self.guidance_rescale_space = guidance_rescale_space
|
||||
|
||||
if parallel_weights is None:
|
||||
# Use normal CFG shift (equal weights for parallel and orthogonal components)
|
||||
self.parallel_weights = [1.0] * self.levels
|
||||
elif isinstance(parallel_weights, float):
|
||||
self.parallel_weights = [parallel_weights] * self.levels
|
||||
elif len(parallel_weights) == self.levels:
|
||||
self.parallel_weights = parallel_weights
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
|
||||
f"`guidance_scales` ({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.upcast_to_double = upcast_to_double
|
||||
|
||||
if isinstance(start, float):
|
||||
self.guidance_start = [start] * self.levels
|
||||
elif len(start) == self.levels:
|
||||
self.guidance_start = start
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
if isinstance(stop, float):
|
||||
self.guidance_stop = [stop] * self.levels
|
||||
elif len(stop) == self.levels:
|
||||
self.guidance_stop = stop
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_fdg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
|
||||
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
|
||||
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
|
||||
|
||||
# From high frequencies to low frequencies, following the paper implementation
|
||||
pred_guided_pyramid = []
|
||||
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
|
||||
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
|
||||
if self._is_fdg_enabled_for_level(level):
|
||||
# Get the cond/uncond preds (in freq space) at the current frequency level
|
||||
pred_cond_freq = pred_cond_pyramid[level]
|
||||
pred_uncond_freq = pred_uncond_pyramid[level]
|
||||
|
||||
shift = pred_cond_freq - pred_uncond_freq
|
||||
|
||||
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
|
||||
if not math.isclose(parallel_weight, 1.0):
|
||||
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
|
||||
shift = parallel_weight * shift_parallel + shift_orthogonal
|
||||
|
||||
# Apply CFG update for the current frequency level
|
||||
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
|
||||
pred = pred + guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
|
||||
|
||||
# Add the current FDG guided level to the FDG prediction pyramid
|
||||
pred_guided_pyramid.append(pred)
|
||||
else:
|
||||
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
|
||||
pred_guided_pyramid.append(pred_cond_freq)
|
||||
|
||||
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
|
||||
pred = build_image_from_pyramid(pred_guided_pyramid)
|
||||
|
||||
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
|
||||
# across all freq levels
|
||||
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_fdg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_fdg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
|
||||
else:
|
||||
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_fdg_enabled_for_level(self, level: int) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
|
||||
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scales[level], 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scales[level], 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
@@ -133,6 +133,7 @@ def _register_attention_processors_metadata():
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
|
||||
# FluxAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=FluxAttnProcessor,
|
||||
@@ -153,6 +154,7 @@ def _register_transformer_blocks_metadata():
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
|
||||
# BasicTransformerBlock
|
||||
@@ -255,6 +257,15 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# QwenImage
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=QwenImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
|
||||
@@ -95,7 +95,7 @@ class ModuleGroup:
|
||||
self.offload_to_disk_path = offload_to_disk_path
|
||||
self._is_offloaded_to_disk = False
|
||||
|
||||
if self.offload_to_disk_path:
|
||||
if self.offload_to_disk_path is not None:
|
||||
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
|
||||
self.group_id = group_id if group_id is not None else str(id(self))
|
||||
short_hash = _compute_group_hash(self.group_id)
|
||||
@@ -115,6 +115,12 @@ class ModuleGroup:
|
||||
else:
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
|
||||
self._torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
@@ -138,112 +144,76 @@ class ModuleGroup:
|
||||
|
||||
@contextmanager
|
||||
def _pinned_memory_tensors(self):
|
||||
pinned_dict = {}
|
||||
try:
|
||||
for param, tensor in self.cpu_param_dict.items():
|
||||
if not tensor.is_pinned():
|
||||
pinned_dict[param] = tensor.pin_memory()
|
||||
else:
|
||||
pinned_dict[param] = tensor
|
||||
|
||||
pinned_dict = {
|
||||
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
|
||||
for param, tensor in self.cpu_param_dict.items()
|
||||
}
|
||||
yield pinned_dict
|
||||
|
||||
finally:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor):
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream and current_stream is not None:
|
||||
tensor.data.record_stream(current_stream)
|
||||
if self.record_stream:
|
||||
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
|
||||
def _process_tensors_from_modules(self, pinned_memory=None):
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
source = pinned_memory[param] if pinned_memory else param.data
|
||||
self._transfer_tensor_to_device(param, source, current_stream)
|
||||
self._transfer_tensor_to_device(param, source)
|
||||
for buffer in group_module.buffers():
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, current_stream)
|
||||
self._transfer_tensor_to_device(buffer, source)
|
||||
|
||||
for param in self.parameters:
|
||||
source = pinned_memory[param] if pinned_memory else param.data
|
||||
self._transfer_tensor_to_device(param, source, current_stream)
|
||||
self._transfer_tensor_to_device(param, source)
|
||||
|
||||
for buffer in self.buffers:
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, current_stream)
|
||||
|
||||
def _onload_from_disk(self, current_stream):
|
||||
if self.stream is not None:
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
|
||||
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
|
||||
|
||||
self.cpu_param_dict.clear()
|
||||
|
||||
else:
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
|
||||
def _onload_from_memory(self, current_stream):
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
self._process_tensors_from_modules(pinned_memory, current_stream)
|
||||
else:
|
||||
self._process_tensors_from_modules(None, current_stream)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
|
||||
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
|
||||
|
||||
if self.offload_to_disk_path:
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
if self.stream is not None:
|
||||
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
|
||||
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
tensor_obj.data.record_stream(current_stream)
|
||||
else:
|
||||
# Load directly to the target device (synchronous)
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
return
|
||||
self._transfer_tensor_to_device(buffer, source)
|
||||
|
||||
def _onload_from_disk(self):
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
||||
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
|
||||
|
||||
with context:
|
||||
if self.offload_to_disk_path:
|
||||
self._onload_from_disk(current_stream)
|
||||
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
|
||||
device = str(self.onload_device) if self.stream is None else "cpu"
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
|
||||
|
||||
if self.stream is not None:
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
pinned_tensor = loaded_tensors[key].pin_memory()
|
||||
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
tensor_obj.data.record_stream(current_stream)
|
||||
else:
|
||||
self._onload_from_memory(current_stream)
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
|
||||
def _onload_from_memory(self):
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
||||
with context:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
self._process_tensors_from_modules(pinned_memory)
|
||||
else:
|
||||
self._process_tensors_from_modules(None)
|
||||
|
||||
def _offload_to_disk(self):
|
||||
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
||||
@@ -264,14 +234,10 @@ class ModuleGroup:
|
||||
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
|
||||
|
||||
def _offload_to_memory(self):
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
if self.stream is not None:
|
||||
if not self.record_stream:
|
||||
torch_accelerator_module.current_stream().synchronize()
|
||||
self._torch_accelerator_module.current_stream().synchronize()
|
||||
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
@@ -279,18 +245,25 @@ class ModuleGroup:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
r"""Onloads the group of parameters to the onload_device."""
|
||||
if self.offload_to_disk_path is not None:
|
||||
self._onload_from_disk()
|
||||
else:
|
||||
self._onload_from_memory()
|
||||
|
||||
@torch.compiler.disable()
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
r"""Offloads the group of parameters to the offload_device."""
|
||||
if self.offload_to_disk_path:
|
||||
self._offload_to_disk()
|
||||
else:
|
||||
@@ -307,11 +280,9 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(
|
||||
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
|
||||
) -> None:
|
||||
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
|
||||
self.group = group
|
||||
self.next_group = next_group
|
||||
self.next_group: Optional[ModuleGroup] = None
|
||||
self.config = config
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
@@ -331,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
|
||||
if self.group.onload_leader == module:
|
||||
if self.group.onload_self:
|
||||
self.group.onload_()
|
||||
if self.next_group is not None and not self.next_group.onload_self:
|
||||
|
||||
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
||||
if should_onload_next_group:
|
||||
self.next_group.onload_()
|
||||
|
||||
should_synchronize = (
|
||||
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
|
||||
)
|
||||
if should_synchronize:
|
||||
# If this group didn't onload itself, it means it was asynchronously onloaded by the
|
||||
# previous group. We need to synchronize the side stream to ensure parameters
|
||||
# are completely loaded to proceed with forward pass. Without this, uninitialized
|
||||
# weights will be used in the computation, leading to incorrect results
|
||||
# Also, we should only do this synchronization if we don't already do it from the sync call in
|
||||
# self.next_group.onload_, hence the `not should_onload_next_group` check.
|
||||
self.group.stream.synchronize()
|
||||
|
||||
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
|
||||
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
|
||||
return args, kwargs
|
||||
@@ -459,8 +444,8 @@ class LayerExecutionTrackerHook(ModelHook):
|
||||
|
||||
def apply_group_offloading(
|
||||
module: torch.nn.Module,
|
||||
onload_device: torch.device,
|
||||
offload_device: torch.device = torch.device("cpu"),
|
||||
onload_device: Union[str, torch.device],
|
||||
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
||||
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
@@ -546,6 +531,8 @@ def apply_group_offloading(
|
||||
```
|
||||
"""
|
||||
|
||||
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
|
||||
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
|
||||
offload_type = GroupOffloadingType(offload_type)
|
||||
|
||||
stream = None
|
||||
@@ -633,7 +620,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
|
||||
# Apply group offloading hooks to the module groups
|
||||
for i, group in enumerate(matched_module_groups):
|
||||
for group_module in group.modules:
|
||||
_apply_group_offloading_hook(group_module, group, None, config=config)
|
||||
_apply_group_offloading_hook(group_module, group, config=config)
|
||||
|
||||
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
||||
# when the forward pass of this module is called. This is because the top-level module is not
|
||||
@@ -662,9 +649,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
|
||||
group_id=f"{module.__class__.__name__}_unmatched_group",
|
||||
)
|
||||
if config.stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
_apply_group_offloading_hook(module, unmatched_group, config=config)
|
||||
else:
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
||||
|
||||
|
||||
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
@@ -693,7 +680,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
|
||||
onload_self=True,
|
||||
group_id=name,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None, config=config)
|
||||
_apply_group_offloading_hook(submodule, group, config=config)
|
||||
modules_with_group_offloading.add(name)
|
||||
|
||||
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
||||
@@ -740,7 +727,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
|
||||
onload_self=True,
|
||||
group_id=name,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None, config=config)
|
||||
_apply_group_offloading_hook(parent_module, group, config=config)
|
||||
|
||||
if config.stream is not None:
|
||||
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
||||
@@ -762,13 +749,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
|
||||
onload_self=True,
|
||||
group_id=_GROUP_ID_LAZY_LEAF,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
||||
|
||||
|
||||
def _apply_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
*,
|
||||
config: GroupOffloadingConfig,
|
||||
) -> None:
|
||||
@@ -777,14 +763,13 @@ def _apply_group_offloading_hook(
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group, config=config)
|
||||
hook = GroupOffloadingHook(group, config=config)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
|
||||
def _apply_lazy_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
*,
|
||||
config: GroupOffloadingConfig,
|
||||
) -> None:
|
||||
@@ -793,7 +778,7 @@ def _apply_lazy_group_offloading_hook(
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group, config=config)
|
||||
hook = GroupOffloadingHook(group, config=config)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
|
||||
|
||||
|
||||
def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
|
||||
module_list_with_transformer_blocks = []
|
||||
for name, submodule in module.named_modules():
|
||||
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
|
||||
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
|
||||
if name_endswith_identifier and is_modulelist:
|
||||
module_list_with_transformer_blocks.append((name, submodule))
|
||||
return module_list_with_transformer_blocks
|
||||
|
||||
|
||||
def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
|
||||
attention_layers = []
|
||||
for name, submodule in module.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES):
|
||||
attention_layers.append((name, submodule))
|
||||
return attention_layers
|
||||
|
||||
|
||||
def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
|
||||
feedforward_layers = []
|
||||
for name, submodule in module.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
feedforward_layers.append((name, submodule))
|
||||
return feedforward_layers
|
||||
@@ -79,6 +79,7 @@ if is_torch_available():
|
||||
"WanLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
"QwenImageLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -118,6 +119,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
QwenImageLoraLoaderMixin,
|
||||
SanaLoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
SkyReelsV2LoraLoaderMixin,
|
||||
|
||||
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
# has both `peft` and non-peft state dict.
|
||||
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
|
||||
if has_peft_state_dict:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
|
||||
state_dict = {
|
||||
k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith("transformer.")
|
||||
}
|
||||
return state_dict
|
||||
|
||||
# Another weird one.
|
||||
@@ -1974,6 +1978,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
|
||||
if bias_key_theirs in original_state_dict:
|
||||
bias_key = converted_key.removesuffix(".weight") + ".bias"
|
||||
converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
diff = all(".diff" in k for k in original_state_dict)
|
||||
@@ -2069,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
up_key = ".lora_up.weight"
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for k in all_keys:
|
||||
if k.endswith(down_key):
|
||||
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
|
||||
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
|
||||
alpha_key = k.replace(down_key, ".alpha")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(k.replace(down_key, up_key))
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[diffusers_down_key] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_qwen_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
@@ -6538,6 +6539,351 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
return_lora_metadata (`bool`, *optional*, defaults to False):
|
||||
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
if has_alphas_in_sd:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`QwenImageTransformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
metadata (`dict`):
|
||||
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
||||
from the state dict.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the transformer.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -61,6 +61,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -319,7 +320,9 @@ class PeftAdapterMixin:
|
||||
# it to None
|
||||
incompatible_keys = None
|
||||
else:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
inject_adapter_in_model(
|
||||
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
|
||||
)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
if self._prepare_lora_hotswap_kwargs is not None:
|
||||
|
||||
@@ -153,9 +153,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"QwenImageTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
||||
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
@@ -381,19 +389,23 @@ class FromOriginalModelMixin:
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
|
||||
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
diffusers_format_checkpoint = checkpoint
|
||||
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
|
||||
@@ -60,6 +60,7 @@ if is_accelerate_available():
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CHECKPOINT_KEY_NAMES = {
|
||||
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
|
||||
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
|
||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
@@ -88,6 +89,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -126,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLQwenImage,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
@@ -177,6 +180,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
SkyReelsV2Transformer3DModel,
|
||||
|
||||
@@ -8,6 +8,7 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
|
||||
@@ -90,7 +90,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
force_upcast: float = True,
|
||||
force_upcast: bool = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
mid_block_add_attention: bool = True,
|
||||
|
||||
@@ -168,7 +168,9 @@ class CosmosPatchEmbed3d(nn.Module):
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p = self.patch_size
|
||||
|
||||
hidden_states = torch.reshape(batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4).contiguous()
|
||||
return hidden_states
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -913,38 +913,21 @@ def patchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
# x shape: [batch_size, channels, height, width]
|
||||
batch_size, channels, height, width = x.shape
|
||||
|
||||
# Ensure height and width are divisible by patch_size
|
||||
if height % patch_size != 0 or width % patch_size != 0:
|
||||
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
||||
|
||||
# Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
|
||||
x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
|
||||
|
||||
# Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
|
||||
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
|
||||
x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
|
||||
|
||||
elif x.dim() == 5:
|
||||
# x shape: [batch_size, channels, frames, height, width]
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# Ensure height and width are divisible by patch_size
|
||||
if height % patch_size != 0 or width % patch_size != 0:
|
||||
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
||||
|
||||
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
|
||||
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
|
||||
|
||||
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
|
||||
x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
|
||||
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
|
||||
|
||||
else:
|
||||
if x.dim() != 5:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
# x shape: [batch_size, channels, frames, height, width]
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# Ensure height and width are divisible by patch_size
|
||||
if height % patch_size != 0 or width % patch_size != 0:
|
||||
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
||||
|
||||
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
|
||||
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
|
||||
|
||||
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
|
||||
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
|
||||
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
|
||||
|
||||
return x
|
||||
|
||||
@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
# x shape: [b, (c * patch_size * patch_size), h, w]
|
||||
batch_size, c_patches, height, width = x.shape
|
||||
channels = c_patches // (patch_size * patch_size)
|
||||
if x.dim() != 5:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
|
||||
batch_size, c_patches, frames, height, width = x.shape
|
||||
channels = c_patches // (patch_size * patch_size)
|
||||
|
||||
# Reshape to [b, c, patch_size, patch_size, h, w]
|
||||
x = x.view(batch_size, channels, patch_size, patch_size, height, width)
|
||||
# Reshape to [b, c, patch_size, patch_size, f, h, w]
|
||||
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
|
||||
|
||||
# Rearrange to [b, c, h * patch_size, w * patch_size]
|
||||
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
|
||||
x = x.view(batch_size, channels, height * patch_size, width * patch_size)
|
||||
|
||||
elif x.dim() == 5:
|
||||
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
|
||||
batch_size, c_patches, frames, height, width = x.shape
|
||||
channels = c_patches // (patch_size * patch_size)
|
||||
|
||||
# Reshape to [b, c, patch_size, patch_size, f, h, w]
|
||||
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
|
||||
|
||||
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
|
||||
x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
|
||||
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
|
||||
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
|
||||
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
|
||||
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
|
||||
|
||||
return x
|
||||
|
||||
@@ -1044,7 +1016,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
patch_size: Optional[int] = None,
|
||||
scale_factor_temporal: Optional[int] = 4,
|
||||
scale_factor_spatial: Optional[int] = 8,
|
||||
clip_output: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1244,10 +1215,11 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
if self.config.clip_output:
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
if self.config.patch_size is not None:
|
||||
out = unpatchify(out, patch_size=self.config.patch_size)
|
||||
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
|
||||
self.clear_cache()
|
||||
if not return_dict:
|
||||
return (out,)
|
||||
|
||||
@@ -30,6 +30,7 @@ if is_torch_available():
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
|
||||
@@ -0,0 +1,628 @@
|
||||
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent).to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def apply_rotary_emb_qwen(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(1)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
||||
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class QwenEmbedRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.rope_cache = {}
|
||||
|
||||
# 是否使用 scale rope
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
frame, height, width = video_fhw
|
||||
rope_key = f"{frame}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
||||
implements joint attention computation where text and image streams are processed together.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor, # Image stream
|
||||
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
|
||||
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# Compute QKV for image stream (sample projections)
|
||||
img_query = attn.to_q(hidden_states)
|
||||
img_key = attn.to_k(hidden_states)
|
||||
img_value = attn.to_v(hidden_states)
|
||||
|
||||
# Compute QKV for text stream (context projections)
|
||||
txt_query = attn.add_q_proj(encoder_hidden_states)
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states)
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
||||
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
||||
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
||||
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
||||
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
# Apply QK normalization
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
if attn.norm_added_q is not None:
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# Apply RoPE
|
||||
if image_rotary_emb is not None:
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
|
||||
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
|
||||
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
|
||||
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
|
||||
|
||||
# Concatenate for joint attention
|
||||
# Order: [text, image]
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
# Compute joint attention
|
||||
joint_hidden_states = dispatch_attention_fn(
|
||||
joint_query,
|
||||
joint_key,
|
||||
joint_value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
joint_hidden_states = joint_hidden_states.flatten(2, 3)
|
||||
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
|
||||
|
||||
# Split attention outputs back
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
||||
|
||||
# Apply output projections
|
||||
img_attn_output = attn.to_out[0](img_attn_output)
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
||||
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class QwenImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
# Image processing modules
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
||||
)
|
||||
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None, # Enable cross attention for joint computation
|
||||
added_kv_proj_dim=dim, # Enable added KV projections for text stream
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=QwenDoubleStreamAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
)
|
||||
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
# Text processing modules
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
||||
)
|
||||
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
# Text doesn't need separate attention - it's handled by img_attn joint computation
|
||||
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
"""Apply modulation to input tensor"""
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Get modulation parameters for both streams
|
||||
img_mod_params = self.img_mod(temb) # [B, 6*dim]
|
||||
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
|
||||
|
||||
# Split modulation parameters for norm1 and norm2
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
||||
|
||||
# Process image stream - norm1 + modulation
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
|
||||
# Process text stream - norm1 + modulation
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# Use QwenAttnProcessor2_0 for joint attention computation
|
||||
# This directly implements the DoubleStreamLayerMegatron logic:
|
||||
# 1. Computes QKV for both streams
|
||||
# 2. Applies QK normalization and RoPE
|
||||
# 3. Concatenates and runs joint attention
|
||||
# 4. Splits results back to separate streams
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=img_modulated, # Image stream (will be processed as "sample")
|
||||
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
|
||||
img_attn_output, txt_attn_output = attn_output
|
||||
|
||||
# Apply attention gates and add residual (like in Megatron)
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# Process image stream - norm2 + MLP
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
img_mlp_output = self.img_mlp(img_modulated2)
|
||||
hidden_states = hidden_states + img_gate2 * img_mlp_output
|
||||
|
||||
# Process text stream - norm2 + MLP
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
txt_mlp_output = self.txt_mlp(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
# Clip to prevent overflow for fp16
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
"""
|
||||
The Transformer model introduced in Qwen.
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `64`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `60`):
|
||||
The number of layers of dual stream DiT blocks to use.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of attention heads to use.
|
||||
joint_attention_dim (`int`, defaults to `3584`):
|
||||
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
||||
`encoder_hidden_states`).
|
||||
guidance_embeds (`bool`, defaults to `False`):
|
||||
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
||||
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_layers: int = 60,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
guidance_embeds: bool = False, # TODO: this should probably be removed
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
||||
|
||||
self.img_in = nn.Linear(in_channels, self.inner_dim)
|
||||
self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`QwenTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
|
||||
Mask of the input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_mask,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=attention_kwargs,
|
||||
)
|
||||
|
||||
# Use only the image part (hidden_states) from the dual-stream blocks
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -180,6 +180,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
cross_attention_dim_head: Optional[int] = None,
|
||||
processor=None,
|
||||
is_cross_attention=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -207,6 +208,8 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def fuse_projections(self):
|
||||
@@ -324,7 +327,7 @@ class WanTimeTextImageEmbedding(nn.Module):
|
||||
):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
if timestep_seq_len is not None:
|
||||
timestep = timestep.unflatten(0, (1, timestep_seq_len))
|
||||
timestep = timestep.unflatten(0, (-1, timestep_seq_len))
|
||||
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
|
||||
@@ -25,7 +25,6 @@ else:
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"ModularPipelineBlocks",
|
||||
"ModularPipeline",
|
||||
"PipelineBlock",
|
||||
"AutoPipelineBlocks",
|
||||
"SequentialPipelineBlocks",
|
||||
"LoopSequentialPipelineBlocks",
|
||||
@@ -59,7 +58,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -103,6 +104,62 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Adapted from the original implementation.
|
||||
def prepare_latents_img2img(
|
||||
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
latent_channels = vae.config.latent_channels
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != latent_channels:
|
||||
image_latents = _encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = scheduler.scale_noise(image_latents, timestep, noise)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
@@ -125,7 +182,56 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FluxInputStep(PipelineBlock):
|
||||
# Cannot use "# Copied from" because it introduces weird indentation errors.
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
def _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
vae_scale_factor,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
sigmas,
|
||||
device,
|
||||
):
|
||||
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
return timesteps, num_inference_steps, sigmas, guidance
|
||||
|
||||
|
||||
class FluxInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -143,11 +249,6 @@ class FluxInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -216,7 +317,7 @@ class FluxInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxSetTimestepsStep(PipelineBlock):
|
||||
class FluxSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -235,17 +336,15 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"latents",
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
)
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -264,39 +363,127 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
|
||||
latents = block_state.latents
|
||||
image_seq_len = latents.shape[1]
|
||||
|
||||
num_inference_steps = block_state.num_inference_steps
|
||||
sigmas = block_state.sigmas
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
|
||||
)
|
||||
if components.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
block_state.guidance = guidance
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(PipelineBlock):
|
||||
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for inference"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("strength", default=0.6),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam(
|
||||
"latent_timestep",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image generation",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
|
||||
def get_timesteps(scheduler, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
if hasattr(scheduler, "set_begin_index"):
|
||||
scheduler.set_begin_index(t_start * scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
scheduler, num_inference_steps, block_state.strength, block_state.device
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
block_state.guidance = guidance
|
||||
|
||||
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -305,7 +492,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
return "Prepare latents step that prepares the latents for the text-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
@@ -314,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -402,10 +584,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
@@ -418,3 +600,90 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the latents for the image-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
|
||||
),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam(
|
||||
"latent_image_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from the image sequence needed for RoPE",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
block_state.device = components._execution_device
|
||||
|
||||
# TODO: implement `check_inputs`
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
if block_state.latents is None:
|
||||
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
|
||||
components.vae,
|
||||
components.scheduler,
|
||||
block_state.image_latents,
|
||||
block_state.latent_timestep,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import logging
|
||||
from ...video_processor import VaeImageProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
return latents
|
||||
|
||||
|
||||
class FluxDecodeStep(PipelineBlock):
|
||||
class FluxDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam("height", default=1024),
|
||||
InputParam("width", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxLoopDenoiser(PipelineBlock):
|
||||
class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [InputParam("joint_attention_kwargs")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class FluxLoopAfterDenoiser(PipelineBlock):
|
||||
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -226,5 +223,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `FluxLoopDenoiser`\n"
|
||||
" - `FluxLoopAfterDenoiser`\n"
|
||||
"This block supports text2image tasks."
|
||||
"This block supports both text2image and img2img tasks."
|
||||
)
|
||||
|
||||
@@ -19,9 +19,12 @@ import regex as re
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -50,7 +53,110 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class FluxTextEncoderStep(PipelineBlock):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class FluxVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Encoder step that encode the input image into a latent representation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
"preprocess_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components.vae, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -297,7 +403,7 @@ class FluxTextEncoderStep(PipelineBlock):
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
device=block_state.device,
|
||||
num_images_per_prompt=1, # hardcoded for now.
|
||||
num_images_per_prompt=1, # TODO: hardcoded for now.
|
||||
lora_scale=block_state.text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,16 +15,38 @@
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxInputStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep
|
||||
from .encoders import FluxTextEncoderStep
|
||||
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# before_denoise: text2vid
|
||||
# vae encoder (run before before_denoise)
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img, img2img
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
FluxInputStep,
|
||||
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2vid,)
|
||||
# before_denoise: img2img
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxBeforeDenoiseStep]
|
||||
block_names = ["text2image"]
|
||||
block_trigger_inputs = [None]
|
||||
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
|
||||
block_names = ["text2image", "img2img"]
|
||||
block_trigger_inputs = [None, "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image tasks."
|
||||
"This is a auto pipeline block that works for text2image and img2img tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
|
||||
)
|
||||
|
||||
|
||||
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
|
||||
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
|
||||
|
||||
|
||||
# text2image
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
|
||||
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
|
||||
block_classes = [
|
||||
FluxTextEncoderStep,
|
||||
FluxAutoVaeEncoderStep,
|
||||
FluxAutoBeforeDenoiseStep,
|
||||
FluxAutoDenoiseStep,
|
||||
FluxAutoDecodeStep,
|
||||
]
|
||||
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
|
||||
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
|
||||
)
|
||||
|
||||
|
||||
@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
# Setting it after preparation of latents because we rely on `latents`
|
||||
# to calculate `img_seq_len` for `shift`.
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxVaeEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxAutoVaeEncoderStep),
|
||||
("before_denoise", FluxAutoBeforeDenoiseStep),
|
||||
("denoise", FluxAutoDenoiseStep),
|
||||
("decode", FluxAutoDecodeStep),
|
||||
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for Flux.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
intermediate_inputs,
|
||||
outputs,
|
||||
description="",
|
||||
class_name=None,
|
||||
@@ -664,7 +663,7 @@ def make_doc_string(
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
|
||||
output += format_input_params(inputs, indent_level=2)
|
||||
|
||||
# Add outputs section
|
||||
output += "\n\n"
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("denoising_start"),
|
||||
# YiYi TODO: do we need num_images_per_prompt here?
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
|
||||
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
add_noise=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
elif latents is None and not is_strength_max:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(components, image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
outputs = (latents, noise, image_latents)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
|
||||
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
|
||||
|
||||
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
|
||||
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
components.num_channels_latents,
|
||||
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=block_state.latent_timestep,
|
||||
is_strength_max=block_state.is_strength_max,
|
||||
add_noise=block_state.add_noise,
|
||||
return_noise=True,
|
||||
return_image_latents=False,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width"),
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("aesthetic_score", default=6.0),
|
||||
InputParam("negative_aesthetic_score", default=2.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("negative_crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1747,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
InputParam("image"),
|
||||
InputParam("mask_image"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
required=True,
|
||||
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.batch_size = image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
block_state.crops_coords = None
|
||||
block_state.resize_mode = "default"
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
crops_coords=block_state.crops_coords,
|
||||
resize_mode=block_state.resize_mode,
|
||||
)
|
||||
block_state.image = block_state.image.to(dtype=torch.float32)
|
||||
image = image.to(dtype=torch.float32)
|
||||
|
||||
block_state.mask = components.mask_processor.preprocess(
|
||||
mask = components.mask_processor.preprocess(
|
||||
block_state.mask_image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
resize_mode=block_state.resize_mode,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
|
||||
block_state.masked_image = image * (mask < 0.5)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.batch_size = image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
|
||||
components,
|
||||
block_state.mask,
|
||||
mask,
|
||||
block_state.masked_image,
|
||||
block_state.batch_size,
|
||||
block_state.height,
|
||||
|
||||
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -94,7 +94,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanInputStep(PipelineBlock):
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(PipelineBlock):
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareLatentsStep(PipelineBlock):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanDecodeStep(PipelineBlock):
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopDenoiser(PipelineBlock):
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(PipelineBlock):
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -51,7 +51,7 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class WanTextEncoderStep(PipelineBlock):
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -387,6 +387,7 @@ else:
|
||||
"SkyReelsV2ImageToVideoPipeline",
|
||||
"SkyReelsV2Pipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = ["QwenImagePipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -703,6 +704,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .qwenimage import QwenImagePipeline
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
|
||||
@@ -312,15 +312,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.0.dev0"):
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
|
||||
@@ -310,7 +310,7 @@ class FluxPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -324,7 +324,7 @@ class FluxControlPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -374,7 +374,7 @@ class FluxControlInpaintPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -419,7 +419,7 @@ class FluxFillPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -333,7 +333,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -358,7 +358,7 @@ class FluxKontextPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -391,7 +391,7 @@ class FluxKontextInpaintPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["QwenImagePipelineOutput", "QwenImagePriorReduxPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
|
||||
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_qwenimage import QwenImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class QwenImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -0,0 +1,735 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import QwenImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import QwenImagePipeline
|
||||
|
||||
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
>>> image.save("qwenimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
r"""
|
||||
The QwenImage pipeline for text-to-image generation.
|
||||
|
||||
Args:
|
||||
transformer ([`QwenImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`QwenTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
transformer: QwenImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 1024
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.prompt_template_encode_start_idx = 34
|
||||
self.default_sample_size = 128
|
||||
|
||||
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
|
||||
return split_result
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = self.prompt_template_encode
|
||||
drop_idx = self.prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = self.tokenizer(
|
||||
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(device)
|
||||
encoder_hidden_states = self.text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_hidden_states.hidden_states[-1]
|
||||
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
||||
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, 1, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
true_cfg_scale: float = 4.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 1.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||
not greater than `1`).
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_true_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_txt_seq_lens,
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return QwenImagePipelineOutput(images=image)
|
||||
@@ -1034,7 +1034,8 @@ class StableDiffusionPipeline(
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
if hasattr(self.scheduler, "scale_model_input"):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
|
||||
@@ -125,15 +125,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: Optional[WanTransformer3DModel] = None,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
@@ -526,7 +526,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
@@ -536,7 +536,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_channels_latents = (
|
||||
self.transformer.config.in_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.in_channels
|
||||
)
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
|
||||
@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
|
||||
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
image_processor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModel = None,
|
||||
transformer: WanTransformer3DModel = None,
|
||||
transformer_2: WanTransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False,
|
||||
@@ -669,12 +669,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)
|
||||
|
||||
# Encode image embedding
|
||||
transformer_dtype = self.transformer.dtype
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
if self.config.boundary_ratio is None and not self.config.expand_timesteps:
|
||||
# only wan 2.1 i2v transformer accepts image_embeds
|
||||
if self.transformer is not None and self.transformer.config.image_dim is not None:
|
||||
if image_embeds is None:
|
||||
if last_image is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
@@ -709,6 +710,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
last_image,
|
||||
)
|
||||
if self.config.expand_timesteps:
|
||||
# wan 2.2 5b i2v use firt_frame_mask to mask timesteps
|
||||
latents, condition, first_frame_mask = latents_outputs
|
||||
else:
|
||||
latents, condition = latents_outputs
|
||||
|
||||
@@ -12,15 +12,15 @@
|
||||
# # See the License for the specific language governing permissions and
|
||||
# # limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import is_accelerate_available
|
||||
from ...utils import is_accelerate_available, is_kernels_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -29,6 +29,82 @@ if is_accelerate_available():
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
|
||||
|
||||
can_use_cuda_kernels = (
|
||||
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 7
|
||||
)
|
||||
if can_use_cuda_kernels and is_kernels_available():
|
||||
from kernels import get_kernel
|
||||
|
||||
ops = get_kernel("Isotr0py/ggml")
|
||||
else:
|
||||
ops = None
|
||||
|
||||
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
gguf.GGMLQuantizationType.Q4_0,
|
||||
gguf.GGMLQuantizationType.Q4_1,
|
||||
gguf.GGMLQuantizationType.Q5_0,
|
||||
gguf.GGMLQuantizationType.Q5_1,
|
||||
gguf.GGMLQuantizationType.Q8_0,
|
||||
gguf.GGMLQuantizationType.Q8_1,
|
||||
}
|
||||
KQUANT_TYPES = {
|
||||
gguf.GGMLQuantizationType.Q2_K,
|
||||
gguf.GGMLQuantizationType.Q3_K,
|
||||
gguf.GGMLQuantizationType.Q4_K,
|
||||
gguf.GGMLQuantizationType.Q5_K,
|
||||
gguf.GGMLQuantizationType.Q6_K,
|
||||
}
|
||||
IMATRIX_QUANT_TYPES = {
|
||||
gguf.GGMLQuantizationType.IQ1_M,
|
||||
gguf.GGMLQuantizationType.IQ1_S,
|
||||
gguf.GGMLQuantizationType.IQ2_XXS,
|
||||
gguf.GGMLQuantizationType.IQ2_XS,
|
||||
gguf.GGMLQuantizationType.IQ2_S,
|
||||
gguf.GGMLQuantizationType.IQ3_XXS,
|
||||
gguf.GGMLQuantizationType.IQ3_S,
|
||||
gguf.GGMLQuantizationType.IQ4_XS,
|
||||
gguf.GGMLQuantizationType.IQ4_NL,
|
||||
}
|
||||
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
|
||||
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
|
||||
# MMQ kernel for I-Matrix quantization.
|
||||
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
|
||||
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
|
||||
# contiguous batching and inefficient with diffusers' batching,
|
||||
# so we disabled it now.
|
||||
|
||||
# elif qweight_type in MMVQ_QUANT_TYPES:
|
||||
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# elif qweight_type in MMQ_QUANT_TYPES:
|
||||
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
|
||||
# If there is no available MMQ kernel, fallback to dequantize
|
||||
if qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
|
||||
y = x @ weight.to(x.dtype).T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
# Might be useful if llama.cpp adds a new quantization type.
|
||||
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
|
||||
qweight_type = gguf.GGMLQuantizationType(qweight_type)
|
||||
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
return y.as_tensor()
|
||||
|
||||
|
||||
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
|
||||
def _create_accelerate_new_hook(old_hook):
|
||||
r"""
|
||||
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, device)
|
||||
self.compute_dtype = compute_dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
|
||||
return self.forward_cuda(inputs)
|
||||
return self.forward_native(inputs)
|
||||
|
||||
def forward_native(self, inputs: torch.Tensor):
|
||||
weight = dequantize_gguf_tensor(self.weight)
|
||||
weight = weight.to(self.compute_dtype)
|
||||
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
|
||||
|
||||
output = torch.nn.functional.linear(inputs, weight, bias)
|
||||
return output
|
||||
|
||||
def forward_cuda(self, inputs: torch.Tensor):
|
||||
quant_type = self.weight.quant_type
|
||||
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
|
||||
if self.bias is not None:
|
||||
output += self.bias.to(self.compute_dtype)
|
||||
return output
|
||||
|
||||
@@ -81,6 +81,8 @@ from .import_utils import (
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_kernels_available,
|
||||
is_kornia_available,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
|
||||
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FrequencyDecoupledGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -423,6 +438,21 @@ class AutoencoderKLMochi(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLQwenImage(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1038,6 +1068,21 @@ class PriorTransformer(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SanaControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1742,6 +1742,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ReduxImageEncoder(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
|
||||
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
|
||||
_transformers_available, _transformers_version = _is_package_available("transformers")
|
||||
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
|
||||
_kernels_available, _kernels_version = _is_package_available("kernels")
|
||||
_inflect_available, _inflect_version = _is_package_available("inflect")
|
||||
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
|
||||
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
|
||||
@@ -223,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
|
||||
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
|
||||
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
|
||||
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -277,6 +279,10 @@ def is_accelerate_available():
|
||||
return _accelerate_available
|
||||
|
||||
|
||||
def is_kernels_available():
|
||||
return _kernels_available
|
||||
|
||||
|
||||
def is_k_diffusion_available():
|
||||
return _k_diffusion_available
|
||||
|
||||
@@ -393,6 +399,10 @@ def is_flash_attn_3_available():
|
||||
return _flash_attn_3_available
|
||||
|
||||
|
||||
def is_kornia_available():
|
||||
return _kornia_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -197,20 +197,6 @@ def get_peft_kwargs(
|
||||
"lora_bias": lora_bias,
|
||||
}
|
||||
|
||||
# Example: try load FusionX LoRA into Wan VACE
|
||||
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
|
||||
if exclude_modules:
|
||||
if not is_peft_version(">=", "0.14.0"):
|
||||
msg = """
|
||||
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
|
||||
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
|
||||
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
|
||||
https://github.com/huggingface/diffusers/issues/new
|
||||
"""
|
||||
logger.debug(msg)
|
||||
else:
|
||||
lora_config_kwargs.update({"exclude_modules": exclude_modules})
|
||||
|
||||
return lora_config_kwargs
|
||||
|
||||
|
||||
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
|
||||
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
|
||||
"""
|
||||
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
|
||||
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
|
||||
doesn't exist in `peft_state_dict`.
|
||||
"""
|
||||
if model_state_dict is None:
|
||||
return
|
||||
all_modules = set()
|
||||
string_to_replace = f"{adapter_name}." if adapter_name else ""
|
||||
|
||||
for name in model_state_dict.keys():
|
||||
if string_to_replace:
|
||||
name = name.replace(string_to_replace, "")
|
||||
if "." in name:
|
||||
module_name = name.rsplit(".", 1)[0]
|
||||
all_modules.add(module_name)
|
||||
|
||||
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
|
||||
exclude_modules = list(all_modules - target_modules_set)
|
||||
|
||||
return exclude_modules
|
||||
|
||||
@@ -36,6 +36,7 @@ from .import_utils import (
|
||||
is_compel_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_kernels_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
@@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_kernels_version_greater_or_equal(kernels_version):
|
||||
def decorator(test_case):
|
||||
correct_kernels_version = is_kernels_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("kernels")).base_version
|
||||
) >= version.parse(kernels_version)
|
||||
return unittest.skipUnless(
|
||||
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
@@ -17,7 +17,9 @@ import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers.hooks import HookRegistry, ModelHook
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.utils import get_logger
|
||||
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
|
||||
return x
|
||||
|
||||
|
||||
# Test for https://github.com/huggingface/diffusers/pull/12077
|
||||
class DummyModelWithLayerNorm(ModelMixin):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
|
||||
)
|
||||
self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
|
||||
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = self.activation(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.layer_norm(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class DummyPipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "model"
|
||||
|
||||
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
|
||||
return x
|
||||
|
||||
|
||||
class LayerOutputTrackerHook(ModelHook):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.outputs = []
|
||||
|
||||
def post_forward(self, module, output):
|
||||
self.outputs.append(output)
|
||||
return output
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class GroupOffloadTests(unittest.TestCase):
|
||||
in_features = 64
|
||||
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
|
||||
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
|
||||
if torch.device(torch_device).type not in ["cuda", "xpu"]:
|
||||
return
|
||||
|
||||
model = DummyModelWithMultipleBlocks(
|
||||
in_features=self.in_features,
|
||||
hidden_features=self.hidden_features,
|
||||
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):
|
||||
|
||||
with context:
|
||||
model(self.input)
|
||||
|
||||
@parameterized.expand([("block_level",), ("leaf_level",)])
|
||||
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
|
||||
if torch.device(torch_device).type not in ["cuda", "xpu"]:
|
||||
return
|
||||
|
||||
def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
|
||||
for name, module in model.named_modules():
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = LayerOutputTrackerHook()
|
||||
registry.register_hook(hook, "layer_output_tracker")
|
||||
|
||||
model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
|
||||
model = DummyModelWithLayerNorm(128, 256, 128, 2)
|
||||
|
||||
model.load_state_dict(model_ref.state_dict(), strict=True)
|
||||
|
||||
model_ref.to(torch_device)
|
||||
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
|
||||
|
||||
apply_layer_output_tracker_hook(model_ref)
|
||||
apply_layer_output_tracker_hook(model)
|
||||
|
||||
x = torch.randn(2, 128).to(torch_device)
|
||||
|
||||
out_ref = model_ref(x)
|
||||
out = model(x)
|
||||
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
|
||||
|
||||
num_repeats = 4
|
||||
for i in range(num_repeats):
|
||||
out_ref = model_ref(x)
|
||||
out = model(x)
|
||||
|
||||
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
|
||||
|
||||
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
|
||||
assert ref_name == name
|
||||
ref_outputs = (
|
||||
HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
|
||||
)
|
||||
outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
|
||||
cumulated_absmax = 0.0
|
||||
for i in range(len(outputs)):
|
||||
diff = ref_outputs[0] - outputs[i]
|
||||
absdiff = diff.abs()
|
||||
absmax = absdiff.max().item()
|
||||
cumulated_absmax += absmax
|
||||
self.assertLess(
|
||||
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLQwenImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
QwenImagePipeline,
|
||||
QwenImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = QwenImagePipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
transformer_cls = QwenImageTransformer2DModel
|
||||
z_dim = 4
|
||||
vae_kwargs = {
|
||||
"base_dim": z_dim * 6,
|
||||
"z_dim": z_dim,
|
||||
"dim_mult": [1, 2, 4],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True],
|
||||
"latents_mean": [0.0] * 4,
|
||||
"latents_std": [1.0] * 4,
|
||||
}
|
||||
vae_cls = AutoencoderKLQwenImage
|
||||
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
|
||||
text_encoder_cls, text_encoder_id = (
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
"hf-internal-testing/tiny-random-Qwen25VLForCondGen",
|
||||
)
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 8, 8, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 10
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 0.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
+4
-68
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def _get_exclude_modules(self, pipe):
|
||||
from diffusers.utils.peft_utils import _derive_exclude_modules
|
||||
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
|
||||
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
|
||||
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
|
||||
pipe.unload_lora_weights()
|
||||
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
|
||||
exclude_modules = _derive_exclude_modules(
|
||||
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
|
||||
)
|
||||
return exclude_modules
|
||||
|
||||
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules(self):
|
||||
"""
|
||||
Test to check if `exclude_modules` works or not. It works in the following way:
|
||||
we first create a pipeline and insert LoRA config into it. We then derive a `set`
|
||||
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
|
||||
state dict.
|
||||
|
||||
We then create a new LoRA config to include the `exclude_modules` and perform tests.
|
||||
"""
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# only supported for `denoiser` now
|
||||
pipe_cp = copy.deepcopy(pipe)
|
||||
pipe_cp, _ = self.add_adapters_to_pipeline(
|
||||
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
|
||||
pipe_cp.to("cpu")
|
||||
del pipe_cp
|
||||
|
||||
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should change outputs.",
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Lora outputs should match.",
|
||||
)
|
||||
|
||||
def test_inference_load_delete_load_adapters(self):
|
||||
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
|
||||
@@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
|
||||
num_blocks_per_group=1,
|
||||
use_stream=use_stream,
|
||||
)
|
||||
# Place other model-level components on `torch_device`.
|
||||
for _, component in pipe.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to(torch_device)
|
||||
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
|
||||
self.assertTrue(group_offload_hook_1 is not None)
|
||||
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# The test exists for cases like
|
||||
# https://github.com/huggingface/diffusers/issues/11874
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_exclude_modules(self):
|
||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
lora_rank = 4
|
||||
target_module = "single_transformer_blocks.0.proj_out"
|
||||
adapter_name = "foo"
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
||||
lora_state_dict = {
|
||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
||||
}
|
||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
||||
config = LoraConfig(
|
||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
||||
)
|
||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
||||
|
||||
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
|
||||
+462
@@ -0,0 +1,462 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SDXLModularTests:
|
||||
"""
|
||||
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
|
||||
"""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
"mask_image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_default_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
sd_pipe = self.get_pipeline()
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs, output="images")
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == expected_image_shape
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
|
||||
"Image Slice does not match expected slice"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularIPAdapterTests:
|
||||
"""
|
||||
This mixin is designed to test IP Adapter.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs_and_blocks(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
|
||||
assert "ip_adapter_image" in parameters, (
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
|
||||
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
assert "ip_adapter_image" not in parameters, (
|
||||
"`ip_adapter_image` argument must be removed from the `__call__` method"
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
return _masks
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
if "image" in parameters and "strength" in parameters:
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "np"
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for IP-Adapter.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs, output="images")
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
# 1. Single IP-Adapter test cases
|
||||
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
assert max_diff_without_adapter_scale < expected_max_diff, (
|
||||
"Output without ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
|
||||
# 2. Multi IP-Adapter test cases
|
||||
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
|
||||
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
|
||||
"Output without multi-ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_multi_adapter_scale > 1e-2, (
|
||||
"Output with multi-ip-adapter scale must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularControlNetTests:
|
||||
"""
|
||||
This mixin is designed to test ControlNet.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
|
||||
assert "controlnet_conditioning_scale" in parameters, (
|
||||
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = torch.randn(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
device=torch_device,
|
||||
)
|
||||
inputs["control_image"] = image
|
||||
return inputs
|
||||
|
||||
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for ControlNet.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single ControlNet with scale=0 should produce same output as no ControlNet.
|
||||
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass without controlnet
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_controlnet = pipe(**inputs, output="images")
|
||||
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 0.0
|
||||
output_without_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 42.0
|
||||
output_with_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
|
||||
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
|
||||
|
||||
assert max_diff_without_controlnet_scale < expected_max_diff, (
|
||||
"Output without controlnet must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
|
||||
|
||||
def test_controlnet_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularGuiderTests:
|
||||
def test_guider_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.5966781,
|
||||
0.62939394,
|
||||
0.48465094,
|
||||
0.51573336,
|
||||
0.57593524,
|
||||
0.47035995,
|
||||
0.53410417,
|
||||
0.51436996,
|
||||
0.47313565,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLImg2ImgModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
inputs["image"] = image
|
||||
inputs["strength"] = 0.8
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.56943184,
|
||||
0.4702148,
|
||||
0.48048905,
|
||||
0.6235963,
|
||||
0.551138,
|
||||
0.49629188,
|
||||
0.60031277,
|
||||
0.5688907,
|
||||
0.43996853,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLInpaintingModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
# create mask
|
||||
image[8:, 8:, :] = 255
|
||||
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
|
||||
|
||||
inputs["image"] = init_image
|
||||
inputs["mask_image"] = mask_image
|
||||
inputs["strength"] = 1.0
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.40872607,
|
||||
0.38842705,
|
||||
0.34893104,
|
||||
0.47837183,
|
||||
0.43792963,
|
||||
0.5332134,
|
||||
0.3716843,
|
||||
0.47274873,
|
||||
0.45000193,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
@@ -0,0 +1,358 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
require_torch,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
including:
|
||||
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
|
||||
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
|
||||
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
|
||||
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
|
||||
- test_to_device: check if the pipeline's __call__ method can handle different devices
|
||||
"""
|
||||
|
||||
# Canonical parameters that are passed to `__call__` regardless
|
||||
# of the type of pipeline. They are always optional and have common
|
||||
# sense default values.
|
||||
optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_images_per_prompt",
|
||||
"latents",
|
||||
"output_type",
|
||||
]
|
||||
)
|
||||
# this is modular specific: generator needs to be a intermediate input because it's mutable
|
||||
intermediate_params = frozenset(
|
||||
[
|
||||
"generator",
|
||||
]
|
||||
)
|
||||
|
||||
def get_generator(self, seed):
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
return generator
|
||||
|
||||
@property
|
||||
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def repo(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_pipeline(self):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_pipeline(self)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `params` in the child test class. "
|
||||
"`params` are checked for if all values are present in `__call__`'s signature."
|
||||
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
|
||||
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
|
||||
"image pipelines, including prompts and prompt embedding overrides."
|
||||
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
|
||||
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
|
||||
"with non-configurable height and width arguments should set the attribute as "
|
||||
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `batch_params` in the child test class. "
|
||||
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
|
||||
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
|
||||
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
|
||||
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
|
||||
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
|
||||
"image pipeline `negative_prompt` is not batched should set the attribute as "
|
||||
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
assert len(remaining_parameters) == 0, (
|
||||
f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
)
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# prepare batched inputs
|
||||
batched_inputs = []
|
||||
for batch_size in batch_sizes:
|
||||
batched_input = {}
|
||||
batched_input.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_input[name] = batch_size * [value]
|
||||
|
||||
if batch_generator and "generator" in inputs:
|
||||
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_input["batch_size"] = batch_size
|
||||
|
||||
batched_inputs.append(batched_input)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
|
||||
output = pipe(**batched_input, output="images")
|
||||
assert len(output) == batch_size, "Output is different from expected batch size"
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
expected_max_diff=1e-4,
|
||||
):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
output = pipe(**inputs, output="images")
|
||||
output_batch = pipe(**batched_inputs, output="images")
|
||||
|
||||
assert output_batch.shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
|
||||
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe_fp16 = self.get_pipeline()
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
output = pipe(**inputs, output="images")
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
|
||||
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.cpu()
|
||||
output_fp16 = output_fp16.cpu()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
|
||||
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
|
||||
|
||||
@require_accelerator
|
||||
def test_to_device(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == torch_device for device in model_devices), (
|
||||
"All pipeline components are not on accelerator device"
|
||||
)
|
||||
|
||||
def test_inference_is_not_nan_cpu(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to("cpu")
|
||||
|
||||
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
|
||||
|
||||
@require_accelerator
|
||||
def test_inference_is_not_nan(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
pipe = self.get_pipeline()
|
||||
|
||||
if "num_images_per_prompt" not in pipe.blocks.input_names:
|
||||
return
|
||||
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
@require_accelerator
|
||||
def test_components_auto_cpu_offload_inference_consistent(self):
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
|
||||
cm = ComponentsManager()
|
||||
cm.enable_auto_cpu_offload(device=torch_device)
|
||||
offload_pipe = self.get_pipeline(components_manager=cm)
|
||||
|
||||
image_slices = []
|
||||
for pipe in [base_pipe, offload_pipe]:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
@@ -45,6 +45,7 @@ from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import is_transformers_version
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
@@ -220,6 +221,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
}
|
||||
return inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.54.1"),
|
||||
reason="Test currently fails on Transformers version 4.54.1.",
|
||||
strict=False,
|
||||
)
|
||||
def test_audioldm2_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
@@ -312,7 +318,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
@@ -371,6 +376,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
assert np.abs(audio_1 - audio_2).max() < 1e-2
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.54.1"),
|
||||
reason="Test currently fails on Transformers version 4.54.1.",
|
||||
strict=False,
|
||||
)
|
||||
def test_audioldm2_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"image",
|
||||
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# Text guided image variation with an image mask
|
||||
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# image variation with an image mask
|
||||
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
"example_image",
|
||||
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
# image params
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
|
||||
# batch params
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
|
||||
|
||||
# callback params
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user