Compare commits

..

43 Commits

Author SHA1 Message Date
DN6 e3f6ba3fc1 Merge branch 'main' into custom-code-updates 2025-08-11 11:55:57 +05:30
Sayak Paul f442955c6e [lora] support loading loras from lightx2v/Qwen-Image-Lightning (#12119)
* feat: support qwen lightning lora.

* add docs.

* fix
2025-08-11 09:27:10 +05:30
Sayak Paul ff9a387618 [core] add modular support for Flux I2I (#12086)
* start

* encoder.

* up

* up

* up

* up

* up

* up
2025-08-11 07:23:23 +05:30
Dhruv Nair fb8722e9ab update 2025-08-09 16:00:24 +02:00
Dhruv Nair 512044c5ea update 2025-08-09 15:06:18 +02:00
Sayak Paul 03c3f69aa5 [docs] diffusers gguf checkpoints (#12092)
* feat: support loading diffusers format gguf checkpoints.

* update

* update

* qwen

* up

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-08-09 08:49:49 +05:30
Sayak Paul f20aba3e87 [GGUF] feat: support loading diffusers format gguf checkpoints. (#11684)
* feat: support loading diffusers format gguf checkpoints.

* update

* update

* qwen

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-08 22:27:15 +05:30
Dhruv Nair 6c85fcd899 update 2025-08-08 18:52:55 +02:00
Dhruv Nair 085e9cba36 update 2025-08-08 18:22:44 +02:00
DN6 919ee1aee3 Merge branch 'main' into custom-code-updates 2025-08-08 19:53:30 +05:30
DN6 9cda45701c Merge branch 'main' into custom-code-updates 2025-08-08 19:48:18 +05:30
DN6 c678e8a445 update 2025-08-08 19:47:50 +05:30
YiYi Xu ccf2c31188 [Modular] Fast Tests (#11937)
* rearrage the params to groups: default params /image params /batch params / callback params

* make style

* add names property to pipeline blocks

* style

* remove more unused func

* prepare_latents_inpaint always return noise and image_latents

* up

* up

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-08 19:42:13 +05:30
Dhruv Nair d1342d7464 update 2025-08-08 12:10:06 +02:00
Sayak Paul 7b10e4ae65 [tests] device placement for non-denoiser components in group offloading LoRA tests (#12103)
up
2025-08-08 13:34:29 +05:30
Beinsezii 3c0531bc50 lora_conversion_utils: replace lora up/down with a/b even if transformer. in key (#12101)
lora_conversion_utils: replace lora up/down with a/b even if transformer. in key

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-08-08 11:21:47 +05:30
Sayak Paul a8e47978c6 [lora] adapt new LoRA config injection method (#11999)
* use state dict when setting up LoRA.

* up

* up

* up

* comment

* up

* up
2025-08-08 09:22:48 +05:30
YiYi Xu 50e18ee698 [qwen] device typo (#12099)
up
2025-08-07 12:27:39 -10:00
DefTruth 4b17fa2a2e fix flux type hint (#12089)
fix-flux-type-hint
2025-08-07 13:00:15 +05:30
dg845 d45199a2f1 Implement Frequency-Decoupled Guidance (FDG) as a Guider (#11976)
* Initial commit implementing frequency-decoupled guidance (FDG) as a guider

* Update FrequencyDecoupledGuidance docstring to describe FDG

* Update project so that it accepts any number of non-batch dims

* Change guidance_scale and other params to accept a list of params for each freq level

* Add comment with Laplacian pyramid shapes

* Add function to import_utils to check if the kornia package is available

* Only import from kornia if package is available

* Fix bug: use pred_cond/uncond in freq space rather than data space

* Allow guidance rescaling to be done in data space or frequency space (speculative)

* Add kornia install instructions to kornia import error message

* Add config to control whether operations are upcast to fp64

* Add parallel_weights recommended values to docstring

* Apply style fixes

* make fix-copies

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-07 11:21:02 +05:30
Sayak Paul 061163142d [tests] tighten compilation tests for quantization (#12002)
* tighten compilation tests for quantization

* up

* up
2025-08-07 10:13:14 +05:30
Dhruv Nair 5780776c8a Make prompt_2 optional in Flux Pipelines (#12073)
* update

* update
2025-08-06 15:40:12 -10:00
Aryan f19421e27c Helper functions to return skip-layer compatible layers (#12048)
update

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-08-06 07:55:16 -10:00
Dhruv Nair 9a0cc463ee update 2025-08-06 19:32:23 +02:00
Dhruv Nair ef4e373a65 Merge branch 'main' into custom-code-updates 2025-08-06 19:31:05 +02:00
Dhruv Nair 1b4af6b7ef update 2025-08-06 17:43:21 +02:00
Aryan 69cdc25746 Fix group offloading synchronization bug for parameter-only GroupModule's (#12077)
* update

* update

* refactor

* fuck yeah

* make style

* Update src/diffusers/hooks/group_offloading.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/hooks/group_offloading.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-08-06 21:11:00 +05:30
Aryan cfd6ec7465 [refactor] condense group offloading (#11990)
* update

* update

* refactor

* add test

* address review comment

* nit
2025-08-06 20:01:02 +05:30
DN6 ea77fdc4b4 update 2025-08-06 17:17:51 +05:30
jiqing-feng 1082c46afa fix input shape for WanGGUFTexttoVideoSingleFileTests (#12081)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-08-06 14:12:40 +05:30
Isotr0py ba2ba9019f Add cuda kernel support for GGUF inference (#11869)
* add gguf kernel support

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix

Signed-off-by: Isotr0py <2037008807@qq.com>

* optimize

Signed-off-by: Isotr0py <2037008807@qq.com>

* update

* update

* update

* update

* update

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-05 21:36:48 +05:30
C fa4c0e5e2e optimize QwenImagePipeline to reduce unnecessary CUDA synchronization (#12072) 2025-08-05 04:12:47 -10:00
Sayak Paul b793debd9d [tests] deal with the failing AudioLDM2 tests (#12069)
up
2025-08-05 15:54:25 +05:30
Aryan 377057126c [tests] Fix Qwen test_inference slices (#12070)
update
2025-08-05 14:10:22 +05:30
Sayak Paul 5937e11d85 [docs] small corrections to the example in the Qwen docs (#12068)
* up

* up
2025-08-05 09:47:21 +05:30
Sayak Paul 9c1d4e3be1 [wip] feat: support lora in qwen image and training script (#12056)
* feat: support lora in qwen image and training script

* up

* up

* up

* up

* up

* up

* add lora tests

* fix

* add tests

* fix

* reviewer feedback

* up[

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-05 07:06:02 +05:30
Dhruv Nair 255c5742aa update 2025-07-30 08:33:51 +02:00
Dhruv Nair 4524d43279 update 2025-07-30 08:24:25 +02:00
Dhruv Nair b6dc0b75f4 Merge branch 'custom-code-updates' of https://github.com/huggingface/diffusers into custom-code-updates 2025-07-30 08:19:38 +02:00
Dhruv Nair 966a2ff8df update 2025-07-29 21:06:40 +02:00
YiYi Xu 201da97dd0 Merge branch 'main' into custom-code-updates 2025-07-23 10:23:35 -10:00
DN6 4423097b23 update 2025-07-22 19:31:22 +05:30
Dhruv Nair 60d1b81023 update 2025-07-21 18:44:44 +02:00
78 changed files with 6350 additions and 1867 deletions
+1 -1
View File
@@ -333,7 +333,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: ["peft"]
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
+141
View File
@@ -0,0 +1,141 @@
name: Fast PR tests for Modular
on:
pull_request:
branches: [main]
paths:
- "src/diffusers/modular_pipelines/**.py"
- "src/diffusers/models/modeling_utils.py"
- "src/diffusers/models/model_loading_utils.py"
- "src/diffusers/pipelines/pipeline_utils.py"
- "src/diffusers/pipeline_loading_utils.py"
- "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py"
- "tests/modular_pipelines/**.py"
- ".github/**.yml"
- "utils/**.py"
- "setup.py"
push:
branches:
- ci-*
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
+5
View File
@@ -30,6 +30,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
<Tip>
@@ -105,6 +106,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## QwenImageLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
+61 -2
View File
@@ -14,7 +14,9 @@
# QwenImage
<!-- TODO: update this section when model is out -->
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
<Tip>
@@ -22,12 +24,69 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
## LoRA for faster inference
Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
number of steps. Refer to the code snippet below:
<details>
<summary>Code</summary>
```py
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import torch
import math
ckpt_id = "Qwen/Qwen-Image"
# From
# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3), # We use shift=3 in distillation
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3), # We use shift=3 in distillation
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None, # set shift_terminal to None
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
pipe = DiffusionPipeline.from_pretrained(
ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
"lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
)
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
negative_prompt = " "
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=1024,
height=1024,
num_inference_steps=8,
true_cfg_scale=1.0,
generator=torch.manual_seed(0),
).images[0]
image.save("qwen_fewsteps.png")
```
</details>
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
- all
- __call__
## QwenImagePipeline
## QwenImagePipelineOutput
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
+51
View File
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Using Optimized CUDA Kernels with GGUF
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
```shell
pip install -U kernels
```
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
## Supported Quantization Types
- BF16
@@ -67,3 +77,44 @@ image.save("flux-gguf.png")
- Q5_K
- Q6_K
## Convert to GGUF
Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
run conversion:
<iframe
src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"
frameborder="0"
width="850"
height="450"
></iframe>
```py
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
config="black-forest-labs/FLUX.1-dev",
subfolder="transformer",
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
When using Diffusers format GGUF checkpoints, it's a must to provide the model `config` path. If the
model config resides in a `subfolder`, that needs to be specified, too.
+136
View File
@@ -0,0 +1,136 @@
# DreamBooth training example for Qwen Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_lora_qwen_image.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Qwen Image](https://huggingface.co/Qwen/Qwen-Image).
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sana.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
export MODEL_NAME="Qwen/Qwen-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sana-lora"
accelerate launch train_dreambooth_lora_sana.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
hf auth login
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
## Notes
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
We provide several options for optimizing memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage) of the `QwenImagePipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
## Using quantization
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
```json
{
"load_in_4bit": true,
"bnb_4bit_quant_type": "nf4"
}
```
@@ -0,0 +1,248 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
transformer_layer_type = "transformer_blocks.0.attn.to_k"
def test_dreambooth_lora_qwen(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -116,7 +116,7 @@ _deps = [
"librosa",
"numpy",
"parameterized",
"peft>=0.15.0",
"peft>=0.17.0",
"protobuf>=3.20.3,<4",
"pytest",
"pytest-timeout",
+2
View File
@@ -139,6 +139,7 @@ else:
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
+1 -1
View File
@@ -23,7 +23,7 @@ deps = {
"librosa": "librosa",
"numpy": "numpy",
"parameterized": "parameterized",
"peft": "peft>=0.15.0",
"peft": "peft>=0.17.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
+2
View File
@@ -22,6 +22,7 @@ if is_torch_available():
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@ if is_torch_available():
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
@@ -0,0 +1,327 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..utils import is_kornia_available
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
_CAN_USE_KORNIA = is_kornia_available()
if _CAN_USE_KORNIA:
from kornia.geometry import pyrup as upsample_and_blur_func
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
else:
upsample_and_blur_func = None
build_laplacian_pyramid_func = None
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
(Algorithm 2).
"""
# v0 shape: [B, ...]
# v1 shape: [B, ...]
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
all_dims_but_first = list(range(1, len(v0.shape)))
if upcast_to_double:
dtype = v0.dtype
v0, v1 = v0.double(), v1.double()
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
if upcast_to_double:
v0_parallel = v0_parallel.to(dtype)
v0_orthogonal = v0_orthogonal.to(dtype)
return v0_parallel, v0_orthogonal
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
"""
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
(Algorihtm 2).
"""
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
img = pyramid[-1]
for i in range(len(pyramid) - 2, -1, -1):
img = upsample_and_blur_func(img) + pyramid[i]
return img
class FrequencyDecoupledGuidance(BaseGuidance):
"""
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
how CFG works, you can check out the CFG guider.)
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
to form the final FDG prediction.
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
descending order).
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
`guidance_scales`.
parallel_weights (`float` or `List[float]`, *optional*):
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float` or `List[float]`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
should be the same length as `guidance_scales`.
stop (`float` or `List[float]`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
should be the same length as `guidance_scales`.
guidance_rescale_space (`str`, defaults to `"data"`):
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
upcast_to_double (`bool`, defaults to `True`):
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
use_original_formulation: bool = False,
start: Union[float, List[float], Tuple[float]] = 0.0,
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
"it depends is not available in the current environment. You can install `kornia` with `pip install "
"kornia`."
)
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
if isinstance(guidance_rescale, float):
self.guidance_rescale = [guidance_rescale] * self.levels
elif len(guidance_rescale) == self.levels:
self.guidance_rescale = guidance_rescale
else:
raise ValueError(
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
f"`guidance_scales` ({len(self.guidance_scales)})"
)
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
# transforming from frequency space back to data space)
if guidance_rescale_space not in ["data", "freq"]:
raise ValueError(
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
)
self.guidance_rescale_space = guidance_rescale_space
if parallel_weights is None:
# Use normal CFG shift (equal weights for parallel and orthogonal components)
self.parallel_weights = [1.0] * self.levels
elif isinstance(parallel_weights, float):
self.parallel_weights = [parallel_weights] * self.levels
elif len(parallel_weights) == self.levels:
self.parallel_weights = parallel_weights
else:
raise ValueError(
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
f"`guidance_scales` ({len(self.guidance_scales)})"
)
self.use_original_formulation = use_original_formulation
self.upcast_to_double = upcast_to_double
if isinstance(start, float):
self.guidance_start = [start] * self.levels
elif len(start) == self.levels:
self.guidance_start = start
else:
raise ValueError(
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
f"({len(self.guidance_scales)})"
)
if isinstance(stop, float):
self.guidance_stop = [stop] * self.levels
elif len(stop) == self.levels:
self.guidance_stop = stop
else:
raise ValueError(
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
f"({len(self.guidance_scales)})"
)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_fdg_enabled():
pred = pred_cond
else:
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
# From high frequencies to low frequencies, following the paper implementation
pred_guided_pyramid = []
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
if self._is_fdg_enabled_for_level(level):
# Get the cond/uncond preds (in freq space) at the current frequency level
pred_cond_freq = pred_cond_pyramid[level]
pred_uncond_freq = pred_uncond_pyramid[level]
shift = pred_cond_freq - pred_uncond_freq
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
if not math.isclose(parallel_weight, 1.0):
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
shift = parallel_weight * shift_parallel + shift_orthogonal
# Apply CFG update for the current frequency level
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
pred = pred + guidance_scale * shift
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
# Add the current FDG guided level to the FDG prediction pyramid
pred_guided_pyramid.append(pred)
else:
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
pred_guided_pyramid.append(pred_cond_freq)
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
pred = build_image_from_pyramid(pred_guided_pyramid)
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
# across all freq levels
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_fdg_enabled():
num_conditions += 1
return num_conditions
def _is_fdg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
else:
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
return is_within_range and not is_close
def _is_fdg_enabled_for_level(self, level: int) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scales[level], 0.0)
else:
is_close = math.isclose(self.guidance_scales[level], 1.0)
return is_within_range and not is_close
+1
View File
@@ -133,6 +133,7 @@ def _register_attention_processors_metadata():
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
),
)
# FluxAttnProcessor
AttentionProcessorRegistry.register(
model_class=FluxAttnProcessor,
+95 -110
View File
@@ -95,7 +95,7 @@ class ModuleGroup:
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
if self.offload_to_disk_path is not None:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ class ModuleGroup:
else:
self.cpu_param_dict = self._init_cpu_param_dict()
self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -138,112 +144,76 @@ class ModuleGroup:
@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict
finally:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
def _transfer_tensor_to_device(self, tensor, source_tensor):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream and current_stream is not None:
tensor.data.record_stream(current_stream)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
def _process_tensors_from_modules(self, pinned_memory=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)
for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
def _onload_from_disk(self, current_stream):
if self.stream is not None:
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
with self._pinned_memory_tensors() as pinned_memory:
for key, tensor_obj in self.key_to_tensor.items():
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
self.cpu_param_dict.clear()
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self, current_stream):
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory, current_stream)
else:
self._process_tensors_from_modules(None, current_stream)
@torch.compiler.disable()
def onload_(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
self._transfer_tensor_to_device(buffer, source)
def _onload_from_disk(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
with context:
if self.offload_to_disk_path:
self._onload_from_disk(current_stream)
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
if self.stream is not None:
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
self._onload_from_memory(current_stream)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
else:
self._process_tensors_from_modules(None)
def _offload_to_disk(self):
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -264,14 +234,10 @@ class ModuleGroup:
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch_accelerator_module.current_stream().synchronize()
self._torch_accelerator_module.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -279,18 +245,25 @@ class ModuleGroup:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):
r"""Onloads the group of parameters to the onload_device."""
if self.offload_to_disk_path is not None:
self._onload_from_disk()
else:
self._onload_from_memory()
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
r"""Offloads the group of parameters to the offload_device."""
if self.offload_to_disk_path:
self._offload_to_disk()
else:
@@ -307,11 +280,9 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
self.group = group
self.next_group = next_group
self.next_group: Optional[ModuleGroup] = None
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -331,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
if self.next_group is not None and not self.next_group.onload_self:
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:
self.next_group.onload_()
should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
if should_synchronize:
# If this group didn't onload itself, it means it was asynchronously onloaded by the
# previous group. We need to synchronize the side stream to ensure parameters
# are completely loaded to proceed with forward pass. Without this, uninitialized
# weights will be used in the computation, leading to incorrect results
# Also, we should only do this synchronization if we don't already do it from the sync call in
# self.next_group.onload_, hence the `not should_onload_next_group` check.
self.group.stream.synchronize()
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs
@@ -459,8 +444,8 @@ class LayerExecutionTrackerHook(ModelHook):
def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
@@ -546,6 +531,8 @@ def apply_group_offloading(
```
"""
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
offload_type = GroupOffloadingType(offload_type)
stream = None
@@ -633,7 +620,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None, config=config)
_apply_group_offloading_hook(group_module, group, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +649,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +680,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +727,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
_apply_group_offloading_hook(parent_module, group, config=config)
if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +749,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -777,14 +763,13 @@ def _apply_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -793,7 +778,7 @@ def _apply_lazy_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
+43
View File
@@ -0,0 +1,43 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
module_list_with_transformer_blocks = []
for name, submodule in module.named_modules():
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
if name_endswith_identifier and is_modulelist:
module_list_with_transformer_blocks.append((name, submodule))
return module_list_with_transformer_blocks
def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
attention_layers = []
for name, submodule in module.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES):
attention_layers.append((name, submodule))
return attention_layers
def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
feedforward_layers = []
for name, submodule in module.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
feedforward_layers.append((name, submodule))
return feedforward_layers
+2
View File
@@ -79,6 +79,7 @@ if is_torch_available():
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -118,6 +119,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
QwenImageLoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
SkyReelsV2LoraLoaderMixin,
+41 -1
View File
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
state_dict = {
k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
for k, v in state_dict.items()
if k.startswith("transformer.")
}
return state_dict
# Another weird one.
@@ -2073,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
up_key = ".lora_up.weight"
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
+346
View File
@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
@@ -6538,6 +6539,351 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class QwenImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
if has_alphas_in_sd:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`QwenImageTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
+4 -1
View File
@@ -61,6 +61,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
}
@@ -319,7 +320,9 @@ class PeftAdapterMixin:
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
inject_adapter_in_model(
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if self._prepare_lora_hotswap_kwargs is not None:
+21 -9
View File
@@ -153,9 +153,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
}
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +389,23 @@ class FromOriginalModelMixin:
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
else:
diffusers_format_checkpoint = checkpoint
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -60,6 +60,7 @@ if is_accelerate_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CHECKPOINT_KEY_NAMES = {
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -25,7 +25,6 @@ else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
@@ -59,7 +58,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LoopSequentialPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
@@ -13,15 +13,16 @@
# limitations under the License.
import inspect
from typing import List, Optional, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from ...models import AutoencoderKL
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -103,6 +104,62 @@ def calculate_shift(
return mu
# Adapted from the original implementation.
def prepare_latents_img2img(
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
latent_channels = vae.config.latent_channels
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
image = image.to(device=device, dtype=dtype)
if image.shape[1] != latent_channels:
image_latents = _encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = scheduler.scale_noise(image_latents, timestep, noise)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -125,7 +182,56 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
return latent_image_ids.to(device=device, dtype=dtype)
class FluxInputStep(PipelineBlock):
# Cannot use "# Copied from" because it introduces weird indentation errors.
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
return image_latents
def _get_initial_timesteps_and_optionals(
transformer,
scheduler,
batch_size,
height,
width,
vae_scale_factor,
num_inference_steps,
guidance_scale,
sigmas,
device,
):
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
mu = calculate_shift(
image_seq_len,
scheduler.config.get("base_image_seq_len", 256),
scheduler.config.get("max_image_seq_len", 4096),
scheduler.config.get("base_shift", 0.5),
scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
if transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
else:
guidance = None
return timesteps, num_inference_steps, sigmas, guidance
class FluxInputStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -143,11 +249,6 @@ class FluxInputStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
@@ -216,7 +317,7 @@ class FluxInputStep(PipelineBlock):
return components, state
class FluxSetTimestepsStep(PipelineBlock):
class FluxSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -235,17 +336,15 @@ class FluxSetTimestepsStep(PipelineBlock):
InputParam("sigmas"),
InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"latents",
"batch_size",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
)
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
),
]
@property
@@ -264,39 +363,127 @@ class FluxSetTimestepsStep(PipelineBlock):
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
transformer = components.transformer
latents = block_state.latents
image_seq_len = latents.shape[1]
num_inference_steps = block_state.num_inference_steps
sigmas = block_state.sigmas
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
batch_size = block_state.batch_size * block_state.num_images_per_prompt
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
transformer,
scheduler,
batch_size,
block_state.height,
block_state.width,
components.vae_scale_factor,
block_state.num_inference_steps,
block_state.guidance_scale,
block_state.sigmas,
block_state.device,
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
block_state.sigmas = sigmas
mu = calculate_shift(
image_seq_len,
scheduler.config.get("base_image_seq_len", 256),
scheduler.config.get("max_image_seq_len", 4096),
scheduler.config.get("base_shift", 0.5),
scheduler.config.get("max_shift", 1.15),
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
)
if components.transformer.config.guidance_embeds:
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state
class FluxPrepareLatentsStep(PipelineBlock):
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("strength", default=0.6),
InputParam("guidance_scale", default=3.5),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam(
"latent_timestep",
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image generation",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
def get_timesteps(scheduler, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
if hasattr(scheduler, "set_begin_index"):
scheduler.set_begin_index(t_start * scheduler.order)
return timesteps, num_inference_steps - t_start
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
transformer = components.transformer
batch_size = block_state.batch_size * block_state.num_images_per_prompt
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
transformer,
scheduler,
batch_size,
block_state.height,
block_state.width,
components.vae_scale_factor,
block_state.num_inference_steps,
block_state.guidance_scale,
block_state.sigmas,
block_state.device,
)
timesteps, num_inference_steps = self.get_timesteps(
scheduler, num_inference_steps, block_state.strength, block_state.device
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
block_state.sigmas = sigmas
block_state.guidance = guidance
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
self.set_block_state(state, block_state)
return components, state
class FluxPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -305,7 +492,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
return "Prepare latents step that prepares the latents for the text-to-image generation process"
@property
def inputs(self) -> List[InputParam]:
@@ -314,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -402,10 +584,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
batch_size = block_state.batch_size * block_state.num_images_per_prompt
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
components,
block_state.batch_size * block_state.num_images_per_prompt,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
@@ -418,3 +600,90 @@ class FluxPrepareLatentsStep(PipelineBlock):
self.set_block_state(state, block_state)
return components, state
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return "Step that prepares the latents for the image-to-image generation process"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
InputParam("generator"),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
),
InputParam(
"latent_timestep",
required=True,
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam(
"latent_image_ids",
type_hint=torch.Tensor,
description="IDs computed from the image sequence needed for RoPE",
),
]
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
block_state.num_channels_latents = components.num_channels_latents
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
# TODO: implement `check_inputs`
batch_size = block_state.batch_size * block_state.num_images_per_prompt
if block_state.latents is None:
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
components.vae,
components.scheduler,
block_state.image_latents,
block_state.latent_timestep,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.set_block_state(state, block_state)
return components, state
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL
from ...utils import logging
from ...video_processor import VaeImageProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
return latents
class FluxDecodeStep(PipelineBlock):
class FluxDecodeStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
InputParam("output_type", default="pil"),
InputParam("height", default=1024),
InputParam("width", default=1024),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
]
@property
@@ -22,7 +22,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxLoopDenoiser(PipelineBlock):
class FluxLoopDenoiser(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [InputParam("joint_attention_kwargs")]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
return components, block_state
class FluxLoopAfterDenoiser(PipelineBlock):
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -226,5 +223,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `FluxLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
"This block supports text2image tasks."
"This block supports both text2image and img2img tasks."
)
@@ -19,9 +19,12 @@ import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -50,7 +53,110 @@ def prompt_clean(text):
return text
class FluxTextEncoderStep(PipelineBlock):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class FluxVaeEncoderStep(ModularPipelineBlocks):
model_name = "flux"
@property
def description(self) -> str:
return "Vae Encoder step that encode the input image into a latent representation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
"preprocess_kwargs",
type_hint=Optional[dict],
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation",
)
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
return image_latents
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(
components.vae, image=block_state.image, generator=block_state.generator
)
self.set_block_state(state, block_state)
return components, state
class FluxTextEncoderStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -297,7 +403,7 @@ class FluxTextEncoderStep(PipelineBlock):
prompt_embeds=None,
pooled_prompt_embeds=None,
device=block_state.device,
num_images_per_prompt=1, # hardcoded for now.
num_images_per_prompt=1, # TODO: hardcoded for now.
lora_scale=block_state.text_encoder_lora_scale,
)
@@ -15,16 +15,38 @@
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
from .before_denoise import (
FluxImg2ImgPrepareLatentsStep,
FluxImg2ImgSetTimestepsStep,
FluxInputStep,
FluxPrepareLatentsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import FluxTextEncoderStep
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# before_denoise: text2vid
# vae encoder (run before before_denoise)
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [FluxVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if `image` is provided, step will be skipped."
)
# before_denoise: text2img, img2img
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
FluxInputStep,
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
)
# before_denoise: all task (text2vid,)
# before_denoise: img2img
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
block_names = ["input", "set_timesteps", "prepare_latents"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
)
# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxBeforeDenoiseStep]
block_names = ["text2image"]
block_trigger_inputs = [None]
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
block_names = ["text2image", "img2img"]
block_trigger_inputs = [None, "image_latents"]
@property
def description(self):
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image tasks."
" - `FluxDenoiseStep` (denoise) for text2image tasks."
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
@property
def description(self):
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
# text2image
class FluxAutoBlocks(SequentialPipelineBlocks):
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
block_classes = [
FluxTextEncoderStep,
FluxAutoVaeEncoderStep,
FluxAutoBeforeDenoiseStep,
FluxAutoDenoiseStep,
FluxAutoDecodeStep,
]
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image using Flux.\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
)
@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("input", FluxInputStep),
("prepare_latents", FluxPrepareLatentsStep),
# Setting it after preparation of latents because we rely on `latents`
# to calculate `img_seq_len` for `shift`.
("set_timesteps", FluxSetTimestepsStep),
("prepare_latents", FluxPrepareLatentsStep),
("denoise", FluxDenoiseStep),
("decode", FluxDecodeStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("image_encoder", FluxVaeEncoderStep),
("input", FluxInputStep),
("set_timesteps", FluxImg2ImgSetTimestepsStep),
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
("denoise", FluxDenoiseStep),
("decode", FluxDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("image_encoder", FluxAutoVaeEncoderStep),
("before_denoise", FluxAutoBeforeDenoiseStep),
("denoise", FluxAutoDenoiseStep),
("decode", FluxAutoDecodeStep),
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
)
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
@@ -13,7 +13,7 @@
# limitations under the License.
from ...loaders import FluxLoraLoaderMixin
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
"""
A ModularPipeline for Flux.
File diff suppressed because it is too large Load Diff
@@ -91,10 +91,7 @@ class ComponentSpec:
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict] = None
# YiYi TODO: currently required is only used to mark optional components that the block can run without, in the future:
# 1. the spec for an optional component should has lower priority when combined in sequential/auto blocks
# 2. should not need to define default_creation_method for optional components
required: bool = True
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default="", metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
@@ -621,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
def make_doc_string(
inputs,
intermediate_inputs,
outputs,
description="",
class_name=None,
@@ -667,7 +663,7 @@ def make_doc_string(
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
output += format_input_params(inputs, indent_level=2)
# Add outputs section
output += "\n\n"
File diff suppressed because it is too large Load Diff
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ..modular_pipeline import (
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
]
@property
@@ -105,9 +100,9 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if needs_upcasting:
if block_state.needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
@@ -117,27 +112,29 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = (
block_state.has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
has_latents_std = (
block_state.has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if has_latents_mean and has_latents_std:
latents_mean = (
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents_std = (
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * latents_std / components.vae.config.scaling_factor + latents_mean
latents = (
latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
)
else:
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
if block_state.needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
@@ -155,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -182,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
InputParam("image"),
InputParam("mask_image"),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
@@ -25,7 +25,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
)
@property
def intermediate_inputs(self) -> List[str]:
def inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -67,13 +67,13 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
return components, block_state
# loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
)
@property
def intermediate_inputs(self) -> List[str]:
def inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -134,17 +134,17 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
self.check_inputs(components, block_state)
block_state.latent_model_input = components.scheduler.scale_model_input(block_state.latents, t)
block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
if components.num_channels_unet == 9:
block_state.latent_model_input = torch.cat(
[block_state.latent_model_input, block_state.mask, block_state.masked_image_latents], dim=1
block_state.scaled_latents = torch.cat(
[block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
)
return components, block_state
# loop step (2): denoise the latents with guidance
class StableDiffusionXLLoopDenoiser(PipelineBlock):
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"num_inference_steps",
required=True,
@@ -232,7 +227,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.latent_model_input,
block_state.scaled_latents,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=block_state.timestep_cond,
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
# loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"controlnet_cond",
required=True,
@@ -410,7 +400,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
mid_block_res_sample = block_state.mid_block_res_sample_zeros
else:
down_block_res_samples, mid_block_res_sample = components.controlnet(
block_state.latent_model_input,
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
controlnet_cond=block_state.controlnet_cond,
@@ -430,7 +420,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
# Predict the noise
# store the noise_pred in guider_state_batch so we can apply guidance across all batches
guider_state_batch.noise_pred = components.unet(
block_state.latent_model_input,
block_state.scaled_latents,
t,
encoder_hidden_states=guider_state_batch.prompt_embeds,
timestep_cond=block_state.timestep_cond,
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents (with inpainting)
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
"timesteps",
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -15,7 +15,6 @@
from typing import List, Optional, Tuple
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -36,7 +35,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
@@ -58,92 +57,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
def get_clip_prompt_embeds(
prompt,
text_encoder,
tokenizer,
device,
clip_skip=None,
max_length=None,
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length if max_length is not None else tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only using the pooled output of the text_encoder_2, which has 2 dimensions
# (pooled output for text_encoder has 3 dimensions)
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
return prompt_embeds, pooled_prompt_embeds
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
def encode_vae_image(
image: torch.Tensor, vae: AutoencoderKL, generator: torch.Generator, dtype: torch.dtype, device: torch.device
):
latents_mean = latents_std = None
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
image = image.to(device=device, dtype=dtype)
if vae.config.force_upcast:
image = image.float()
vae.to(dtype=torch.float32)
if isinstance(generator, list) and len(generator) != image.shape[0]:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {image.shape[0]}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
if vae.config.force_upcast:
vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=device, dtype=dtype)
latents_std = latents_std.to(device=device, dtype=dtype)
image_latents = (image_latents - latents_mean) * vae.config.scaling_factor / latents_std
else:
image_latents = vae.config.scaling_factor * image_latents
return image_latents
class StableDiffusionXLIPAdapterStep(PipelineBlock):
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -172,7 +86,6 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
required=False,
),
]
@@ -190,16 +103,10 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"ip_adapter_embeds",
type_hint=List[torch.Tensor],
kwargs_type="guider_input_fields",
description="IP adapter image embeddings",
),
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
kwargs_type="guider_input_fields",
type_hint=torch.Tensor,
description="Negative IP adapter image embeddings",
),
]
@@ -230,41 +137,85 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self,
components,
ip_adapter_image,
ip_adapter_image_embeds,
device,
num_images_per_prompt,
prepare_unconditional_embeds,
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
block_state.ip_adapter_embeds = []
if components.requires_unconditional_embeds:
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
if not isinstance(block_state.ip_adapter_image, list):
block_state.ip_adapter_image = [block_state.ip_adapter_image]
if len(block_state.ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(block_state.ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
block_state.ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
block_state.ip_adapter_embeds.append(single_image_embeds[None, :])
if components.requires_unconditional_embeds:
block_state.negative_ip_adapter_embeds.append(single_negative_image_embeds[None, :])
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLTextEncoderStep(PipelineBlock):
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -274,16 +225,15 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel, required=False),
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer, required=False),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config",
required=False,
),
]
@@ -294,7 +244,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt", required=True),
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
@@ -332,22 +282,15 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
]
@staticmethod
def check_inputs(prompt, prompt_2, negative_prompt, negative_prompt_2):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if negative_prompt_2 is not None and (
not isinstance(negative_prompt_2, str) and not isinstance(negative_prompt_2, list)
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (
not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
):
raise ValueError(f"`negative_prompt_2` has to be of type `str` or `list` but is {type(negative_prompt_2)}")
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
@staticmethod
def encode_prompt(
@@ -355,9 +298,14 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
requires_unconditional_embeds: bool = True,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
@@ -383,17 +331,52 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
dtype = components.text_encoder_2.dtype
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = (
@@ -406,56 +389,58 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
if components.text_encoder is not None
else [components.text_encoder_2]
)
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
for text_encoder in text_encoders:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(text_encoder, lora_scale)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
scale_lora_layers(text_encoder, lora_scale)
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
# Define prompts
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
prompts = [prompt, prompt_2]
prompt_embeds_list.append(prompt_embeds)
# generate prompt_embeds & pooled_prompt_embeds
prompt_embeds_list = []
pooled_prompt_embeds_list = []
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
prompt_embeds, pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=clip_skip,
max_length=tokenizer.model_max_length,
)
prompt_embeds_list.append(prompt_embeds)
if pooled_prompt_embeds.ndim == 2:
pooled_prompt_embeds_list.append(pooled_prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = torch.concat(pooled_prompt_embeds_list, dim=0)
negative_prompt_embeds = None
negative_pooled_prompt_embeds = None
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
# generate negative_prompt_embeds & negative_pooled_prompt_embeds
if requires_unconditional_embeds and zero_out_negative_prompt:
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif requires_unconditional_embeds:
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
@@ -466,52 +451,87 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
)
uncond_tokens: List[str]
if batch_size != len(negative_prompt):
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
if batch_size != len(negative_prompt_2):
raise ValueError(
f"`negative_prompt_2`: {negative_prompt_2} has batch size {len(negative_prompt_2)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt_2` matches"
" the batch size of `prompt`."
)
uncond_tokens = [negative_prompt, negative_prompt_2]
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
negative_pooled_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
negative_prompt_embeds, negative_pooled_prompt_embeds = get_clip_prompt_embeds(
prompt=negative_prompt,
text_encoder=text_encoder,
tokenizer=tokenizer,
device=device,
clip_skip=None,
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
if negative_pooled_prompt_embeds.ndim == 2:
negative_pooled_prompt_embeds_list.append(negative_pooled_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
if requires_unconditional_embeds:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
for text_encoder in text_encoders:
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=components.text_encoder_2.dtype, device=device
)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(text_encoder, lora_scale)
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -519,15 +539,13 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
self.check_inputs(
block_state.prompt, block_state.prompt_2, block_state.negative_prompt, block_state.negative_prompt_2
)
device = components._execution_device
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
lora_scale = (
block_state.text_encoder_lora_scale = (
block_state.cross_attention_kwargs.get("scale", None)
if block_state.cross_attention_kwargs is not None
else None
@@ -539,13 +557,18 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
prompt=block_state.prompt,
prompt_2=block_state.prompt_2,
device=device,
requires_unconditional_embeds=components.requires_unconditional_embeds,
negative_prompt=block_state.negative_prompt,
negative_prompt_2=block_state.negative_prompt_2,
lora_scale=lora_scale,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
@@ -553,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
return components, state
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -576,13 +599,15 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("height"),
InputParam("width"),
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
"preprocess_kwargs",
type_hint=Optional[dict],
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
),
]
@property
@@ -595,26 +620,70 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
)
]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
image = components.image_processor.preprocess(block_state.image)
# Encode image into latents
block_state.image_latents = encode_vae_image(
image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
image = image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -649,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@@ -664,6 +728,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
OutputParam(
"image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam(
"masked_image_latents",
type_hint=torch.Tensor,
@@ -674,83 +739,149 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
type_hint=Optional[Tuple[int, int]],
description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
),
OutputParam(
"mask",
type_hint=torch.Tensor,
description="The mask to apply on the latents for the inpainting generation.",
),
]
def check_inputs(self, image, mask_image, padding_mask_crop):
if padding_mask_crop is not None and not isinstance(image, Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type {type(mask_image)}."
)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state.image, block_state.mask_image, block_state.padding_mask_crop)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
device = components._execution_device
height = block_state.height if block_state.height is not None else components.default_height
width = block_state.width if block_state.width is not None else components.default_width
if block_state.height is None:
block_state.height = components.default_height
if block_state.width is None:
block_state.width = components.default_width
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(
mask_image=block_state.mask_image, width=width, height=height, pad=block_state.padding_mask_crop
block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
)
resize_mode = "fill"
block_state.resize_mode = "fill"
else:
block_state.crops_coords = None
resize_mode = "default"
block_state.resize_mode = "default"
image = components.image_processor.preprocess(
block_state.image,
height=height,
width=width,
height=block_state.height,
width=block_state.width,
crops_coords=block_state.crops_coords,
resize_mode=resize_mode,
resize_mode=block_state.resize_mode,
)
image = image.to(dtype=torch.float32)
mask_image = components.mask_processor.preprocess(
mask = components.mask_processor.preprocess(
block_state.mask_image,
height=height,
width=width,
resize_mode=resize_mode,
height=block_state.height,
width=block_state.width,
resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords,
)
block_state.masked_image = image * (mask < 0.5)
masked_image = image * (mask_image < 0.5)
block_state.batch_size = image.shape[0]
image = image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
# Prepare image latent variables
block_state.image_latents = encode_vae_image(
image=image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
# Prepare masked image latent variables
block_state.masked_image_latents = encode_vae_image(
image=masked_image, vae=components.vae, generator=block_state.generator, dtype=dtype, device=device
)
# resize mask to match the image latents
_, _, height_latents, width_latents = block_state.image_latents.shape
block_state.mask = torch.nn.functional.interpolate(
mask_image,
size=(height_latents, width_latents),
)
block_state.mask = block_state.mask.to(dtype=dtype, device=device)
self.set_block_state(state, block_state)
return components, state
@@ -23,7 +23,6 @@ from .before_denoise import (
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLInputStep,
StableDiffusionXLLCMStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
@@ -80,16 +79,6 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
class StableDiffusionXLAutoLCMStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLLCMStep]
block_names = ["lcm"]
block_trigger_inputs = ["embedded_guidance_scale"]
@property
def description(self):
return "Run LCM step if `latents` is provided. This step should be placed before the 'input' step.\n"
# before_denoise: text2img
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
@@ -273,7 +262,6 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLAutoLCMStep,
StableDiffusionXLAutoControlNetInputStep,
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLAutoDecodeStep,
@@ -283,7 +271,6 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
"ip_adapter",
"image_encoder",
"before_denoise",
"lcm",
"controlnet_input",
"denoise",
"decoder",
@@ -299,7 +286,6 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
+ "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
+ "- to run the latent consistency models workflow, you need to provide `embedded_guidance_scale`"
)
@@ -371,12 +357,6 @@ IP_ADAPTER_BLOCKS = InsertableDict(
]
)
LCM_BLOCKS = InsertableDict(
[
("lcm", StableDiffusionXLAutoLCMStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
@@ -396,6 +376,5 @@ ALL_BLOCKS = {
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"lcm": LCM_BLOCKS,
"auto": AUTO_BLOCKS,
}
@@ -90,23 +90,6 @@ class StableDiffusionXLModularPipeline(
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
@property
def requires_unconditional_embeds(self):
# by default, always prepare unconditional embeddings
requires_unconditional_embeds = True
if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
# LCM
requires_unconditional_embeds = False
elif hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider.num_conditions > 1
elif not hasattr(self, "guider") or self.guider is None:
requires_unconditional_embeds = False
return requires_unconditional_embeds
# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
# auto_docstring
@@ -264,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
@@ -288,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
@@ -20,7 +20,7 @@ import torch
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -94,7 +94,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class WanInputStep(PipelineBlock):
class WanInputStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
return components, state
class WanSetTimestepsStep(PipelineBlock):
class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
return components, state
class WanPrepareLatentsStep(PipelineBlock):
class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLWan
from ...utils import logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanDecodeStep(PipelineBlock):
class WanDecodeStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -24,7 +24,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanLoopDenoiser(PipelineBlock):
class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
return components, block_state
class WanLoopAfterDenoiser(PipelineBlock):
class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -51,7 +51,7 @@ def prompt_clean(text):
return text
class WanTextEncoderStep(PipelineBlock):
class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -312,15 +312,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.0.dev0"):
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
cache_position_kwargs["model_kwargs"] = model_kwargs
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
@@ -310,7 +310,7 @@ class FluxPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -324,7 +324,7 @@ class FluxControlPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -374,7 +374,7 @@ class FluxControlInpaintPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -419,7 +419,7 @@ class FluxFillPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -333,7 +333,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -358,7 +358,7 @@ class FluxKontextPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -391,7 +391,7 @@ class FluxKontextInpaintPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -20,6 +20,7 @@ import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -44,12 +45,12 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import QwenImagePipeline
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16)
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
>>> image = pipe(prompt, num_inference_steps=50).images[0]
>>> image.save("qwenimage.png")
```
"""
@@ -128,7 +129,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class QwenImagePipeline(DiffusionPipeline):
class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
r"""
The QwenImage pipeline for text-to-image generation.
@@ -200,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline):
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
@@ -635,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -653,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -667,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
+92 -3
View File
@@ -12,15 +12,15 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
import inspect
import os
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available
if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
can_use_cuda_kernels = (
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 7
)
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,
# so we disabled it now.
# elif qweight_type in MMVQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
if qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.to(x.dtype).T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y.as_tensor()
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs):
def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output += self.bias.to(self.compute_dtype)
return output
+2
View File
@@ -81,6 +81,8 @@ from .import_utils import (
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
+15
View File
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FrequencyDecoupledGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PerturbedAttentionGuidance(metaclass=DummyObject):
_backends = ["torch"]
+10
View File
@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -223,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_kornia_available, _kornia_version = _is_package_available("kornia")
def is_torch_available():
@@ -277,6 +279,10 @@ def is_accelerate_available():
return _accelerate_available
def is_kernels_available():
return _kernels_available
def is_k_diffusion_available():
return _k_diffusion_available
@@ -393,6 +399,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available
def is_kornia_available():
return _kornia_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
-38
View File
@@ -197,20 +197,6 @@ def get_peft_kwargs(
"lora_bias": lora_bias,
}
# Example: try load FusionX LoRA into Wan VACE
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
if exclude_modules:
if not is_peft_version(">=", "0.14.0"):
msg = """
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
https://github.com/huggingface/diffusers/issues/new
"""
logger.debug(msg)
else:
lora_config_kwargs.update({"exclude_modules": exclude_modules})
return lora_config_kwargs
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
"""
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
doesn't exist in `peft_state_dict`.
"""
if model_state_dict is None:
return
all_modules = set()
string_to_replace = f"{adapter_name}." if adapter_name else ""
for name in model_state_dict.keys():
if string_to_replace:
name = name.replace(string_to_replace, "")
if "." in name:
module_name = name.rsplit(".", 1)[0]
all_modules.add(module_name)
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
exclude_modules = list(all_modules - target_modules_set)
return exclude_modules
+13
View File
@@ -36,6 +36,7 @@ from .import_utils import (
is_compel_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
+87
View File
@@ -17,7 +17,9 @@ import gc
import unittest
import torch
from parameterized import parameterized
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
return x
# Test for https://github.com/huggingface/diffusers/pull/12077
class DummyModelWithLayerNorm(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.layer_norm(x)
x = self.linear_2(x)
return x
class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
return x
class LayerOutputTrackerHook(ModelHook):
def __init__(self):
super().__init__()
self.outputs = []
def post_forward(self, module, output):
self.outputs.append(output)
return output
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithMultipleBlocks(
in_features=self.in_features,
hidden_features=self.hidden_features,
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):
with context:
model(self.input)
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
for name, module in model.named_modules():
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerOutputTrackerHook()
registry.register_hook(hook, "layer_output_tracker")
model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
model = DummyModelWithLayerNorm(128, 256, 128, 2)
model.load_state_dict(model_ref.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model)
x = torch.randn(2, 128).to(torch_device)
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
num_repeats = 4
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
ref_outputs = (
HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
)
outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
cumulated_absmax = 0.0
for i in range(len(outputs)):
diff = ref_outputs[0] - outputs[i]
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
+129
View File
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLQwenImage,
FlowMatchEulerDiscreteScheduler,
QwenImagePipeline,
QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
transformer_cls = QwenImageTransformer2DModel
z_dim = 4
vae_kwargs = {
"base_dim": z_dim * 6,
"z_dim": z_dim,
"dim_mult": [1, 2, 4],
"num_res_blocks": 1,
"temperal_downsample": [False, True],
"latents_mean": [0.0] * 4,
"latents_std": [1.0] * 4,
}
vae_cls = AutoencoderKLQwenImage
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
text_encoder_cls, text_encoder_id = (
Qwen2_5_VLForConditionalGeneration,
"hf-internal-testing/tiny-random-Qwen25VLForCondGen",
)
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
@property
def output_shape(self):
return (1, 8, 8, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+4 -68
View File
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import re
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
def _get_exclude_modules(self, pipe):
from diffusers.utils.peft_utils import _derive_exclude_modules
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
pipe.unload_lora_weights()
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
exclude_modules = _derive_exclude_modules(
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
)
return exclude_modules
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules(self):
"""
Test to check if `exclude_modules` works or not. It works in the following way:
we first create a pipeline and insert LoRA config into it. We then derive a `set`
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
state dict.
We then create a new LoRA config to include the `exclude_modules` and perform tests.
"""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
# only supported for `denoiser` now
pipe_cp = copy.deepcopy(pipe)
pipe_cp, _ = self.add_adapters_to_pipeline(
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
pipe_cp.to("cpu")
del pipe_cp
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
"LoRA should change outputs.",
)
self.assertTrue(
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.",
)
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
@@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests:
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
num_blocks_per_group=1,
use_stream=use_stream,
)
# Place other model-level components on `torch_device`.
for _, component in pipe.components.items():
if isinstance(component, torch.nn.Module):
component.to(torch_device)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -20,7 +20,7 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
View File
@@ -0,0 +1,462 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
from diffusers import (
ClassifierFreeGuidance,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers.loaders import ModularIPAdapterMixin
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ...models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ..test_modular_pipelines_common import (
ModularPipelineTesterMixin,
)
enable_full_determinism()
class SDXLModularTests:
"""
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
"""
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
repo = "hf-internal-testing/tiny-sdxl-modular"
params = frozenset(
[
"prompt",
"height",
"width",
"negative_prompt",
"cross_attention_kwargs",
"image",
"mask_image",
]
)
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_default_components(torch_dtype=torch_dtype)
return pipeline
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = self.get_pipeline()
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
assert image.shape == expected_image_shape
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
"Image Slice does not match expected slice"
)
class SDXLModularIPAdapterTests:
"""
This mixin is designed to test IP Adapter.
"""
def test_pipeline_inputs_and_blocks(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
assert "ip_adapter_image" in parameters, (
"`ip_adapter_image` argument must be supported by the `__call__` method"
)
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
assert "ip_adapter_image" not in parameters, (
"`ip_adapter_image` argument must be removed from the `__call__` method"
)
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
if "image" in parameters and "strength" in parameters:
inputs["num_inference_steps"] = 4
inputs["output_type"] = "np"
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs, output="images")
else:
output_without_adapter = expected_pipe_slice
# 1. Single IP-Adapter test cases
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
assert max_diff_without_adapter_scale < expected_max_diff, (
"Output without ip-adapter must be same as normal inference"
)
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
# 2. Multi IP-Adapter test cases
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
# forward pass with multi ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
output_without_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
output_with_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_multi_adapter_scale = np.abs(
output_without_multi_adapter_scale - output_without_adapter
).max()
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
"Output without multi-ip-adapter must be same as normal inference"
)
assert max_diff_with_multi_adapter_scale > 1e-2, (
"Output with multi-ip-adapter scale must be different from normal inference"
)
class SDXLModularControlNetTests:
"""
This mixin is designed to test ControlNet.
"""
def test_pipeline_inputs(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
assert "controlnet_conditioning_scale" in parameters, (
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
)
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
controlnet_embedder_scale_factor = 2
image = torch.randn(
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
device=torch_device,
)
inputs["control_image"] = image
return inputs
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for ControlNet.
The following scenarios are tested:
- Single ControlNet with scale=0 should produce same output as no ControlNet.
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass without controlnet
inputs = self.get_dummy_inputs(torch_device)
output_without_controlnet = pipe(**inputs, output="images")
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but scale=0 which should have no effect
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 0.0
output_without_controlnet_scale = pipe(**inputs, output="images")
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but with scale of adapter weights
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
inputs["controlnet_conditioning_scale"] = 42.0
output_with_controlnet_scale = pipe(**inputs, output="images")
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
assert max_diff_without_controlnet_scale < expected_max_diff, (
"Output without controlnet must be same as normal inference"
)
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
def test_controlnet_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularGuiderTests:
def test_guider_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
class SDXLModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.5966781,
0.62939394,
0.48465094,
0.51573336,
0.57593524,
0.47035995,
0.53410417,
0.51436996,
0.47313565,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
class SDXLImg2ImgModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
inputs["image"] = image
inputs["strength"] = 0.8
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.56943184,
0.4702148,
0.48048905,
0.6235963,
0.551138,
0.49629188,
0.60031277,
0.5688907,
0.43996853,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
class SDXLInpaintingModularPipelineFastTests(
SDXLModularTests,
SDXLModularIPAdapterTests,
SDXLModularControlNetTests,
SDXLModularGuiderTests,
ModularPipelineTesterMixin,
unittest.TestCase,
):
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
inputs["image"] = init_image
inputs["mask_image"] = mask_image
inputs["strength"] = 1.0
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
expected_image_shape=(1, 64, 64, 3),
expected_slice=[
0.40872607,
0.38842705,
0.34893104,
0.47837183,
0.43792963,
0.5332134,
0.3716843,
0.47274873,
0.45000193,
],
expected_max_diff=1e-2,
)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -0,0 +1,358 @@
import gc
import tempfile
import unittest
from typing import Callable, Union
import numpy as np
import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch,
torch_device,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
@require_torch
class ModularPipelineTesterMixin:
"""
This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each modular pipeline,
including:
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
- test_to_device: check if the pipeline's __call__ method can handle different devices
"""
# Canonical parameters that are passed to `__call__` regardless
# of the type of pipeline. They are always optional and have common
# sense default values.
optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"latents",
"output_type",
]
)
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(
[
"generator",
]
)
def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device).manual_seed(seed)
return generator
@property
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
raise NotImplementedError(
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def repo(self) -> str:
raise NotImplementedError(
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
raise NotImplementedError(
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
"See existing pipeline tests for reference."
)
def get_pipeline(self):
raise NotImplementedError(
"You need to implement `get_pipeline(self)` in the child test class. "
"See existing pipeline tests for reference."
)
def get_dummy_inputs(self, device, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `params` in the child test class. "
"`params` are checked for if all values are present in `__call__`'s signature."
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
"image pipelines, including prompts and prompt embedding overrides."
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
"with non-configurable height and width arguments should set the attribute as "
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
"See existing pipeline tests for reference."
)
@property
def batch_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `batch_params` in the child test class. "
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
"image pipeline `negative_prompt` is not batched should set the attribute as "
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
"See existing pipeline tests for reference."
)
def setUp(self):
# clean up the VRAM before each test
super().setUp()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_pipeline_call_signature(self):
pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names
optional_parameters = pipe.default_call_parameters
def _check_for_parameters(parameters, expected_parameters, param_type):
remaining_parameters = {param for param in parameters if param not in expected_parameters}
assert len(remaining_parameters) == 0, (
f"Required {param_type} parameters not present: {remaining_parameters}"
)
_check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# prepare batched inputs
batched_inputs = []
for batch_size in batch_sizes:
batched_input = {}
batched_input.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_input[name] = batch_size * [value]
if batch_generator and "generator" in inputs:
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_input["batch_size"] = batch_size
batched_inputs.append(batched_input)
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
pipe.set_progress_bar_config(disable=None)
pipe_fp16 = self.get_pipeline()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
if isinstance(output, torch.Tensor):
output = output.cpu()
output_fp16 = output_fp16.cpu()
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
@require_accelerator
def test_to_device(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
assert all(device == torch_device for device in model_devices), (
"All pipeline components are not on accelerator device"
)
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
pipe = self.get_pipeline()
if "num_images_per_prompt" not in pipe.blocks.input_names:
return
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
@require_accelerator
def test_components_auto_cpu_offload_inference_consistent(self):
base_pipe = self.get_pipeline().to(torch_device)
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
offload_pipe = self.get_pipeline(components_manager=cm)
image_slices = []
for pipe in [base_pipe, offload_pipe]:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_default_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipes.append(pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+11 -1
View File
@@ -45,6 +45,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_transformers_version
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -220,6 +221,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -312,7 +318,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -371,6 +376,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(audio_1 - audio_2).max() < 1e-2
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
+31 -26
View File
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
]
)
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
IMAGE_VARIATION_PARAMS = frozenset(
[
"image",
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
]
)
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[
"prompt",
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
]
)
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
# Text guided image variation with an image mask
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
]
)
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_PARAMS = frozenset(
[
# image variation with an image mask
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
]
)
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
"example_image",
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
]
)
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
]
)
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
# image params
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
# batch params
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
# callback params
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
+1 -1
View File
@@ -160,7 +160,7 @@ class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(generated_image.shape, (3, 32, 32))
# fmt: off
expected_slice = torch.tensor([0.563, 0.6358, 0.6028, 0.5656, 0.5806, 0.5512, 0.5712, 0.6331, 0.4147, 0.3558, 0.5625, 0.4831, 0.4957, 0.5258, 0.4075, 0.5018])
expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
# fmt: on
generated_slice = generated_image.flatten()
+1
View File
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
components_to_quantize=["transformer", "text_encoder_2"],
)
@require_bitsandbytes_version_greater("0.46.1")
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super().test_torch_compile()
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
components_to_quantize=["transformer", "text_encoder_2"],
)
@pytest.mark.xfail(
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
)
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(torch_dtype=torch.float16)
+69 -1
View File
@@ -30,8 +30,10 @@ from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_accelerator,
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_kernels_version_greater_or_equal,
require_peft_backend,
require_torch_version_greater,
torch_device,
@@ -41,11 +43,66 @@ from ..test_torch_compile_utils import QuantCompileTests
if is_gguf_available():
import gguf
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
enable_full_determinism()
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
def setUp(self):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
def test_cuda_kernels_vs_native(self):
if torch_device != "cuda":
self.skipTest("CUDA kernels test requires CUDA device")
from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
if not can_use_cuda_kernels:
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
test_quant_types = ["Q4_0", "Q4_K"]
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
compute_dtype = torch.bfloat16
for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
in_features, out_features = 512, 512
torch.manual_seed(42)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
linear.weight = weight
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
linear = linear.to(torch_device)
with torch.no_grad():
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly
@require_big_accelerator
@require_accelerate
@@ -155,6 +212,7 @@ class GGUFSingleFileTesterMixin:
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
torch_dtype = torch.bfloat16
model_cls = FluxTransformer2DModel
expected_memory_use_in_gb = 5
@@ -239,6 +297,16 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
def test_loading_gguf_diffusers_format(self):
model = self.model_cls.from_single_file(
self.diffusers_ckpt_path,
subfolder="transformer",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
config="black-forest-labs/FLUX.1-dev",
)
model.to("cuda")
model(**self.get_dummy_inputs())
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
@@ -593,7 +661,7 @@ class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.Test
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
@@ -56,12 +56,18 @@ class QuantCompileTests:
pipe.transformer.compile(fullgraph=True)
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
with torch._dynamo.config.patch(error_on_recompile=True):
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
pipe.transformer.compile()
# regional compilation is better for offloading.
# see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
if getattr(pipe.transformer, "_repeated_blocks"):
pipe.transformer.compile_repeated_blocks(fullgraph=True)
else:
pipe.transformer.compile()
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)