Compare commits

...

18 Commits

Author SHA1 Message Date
DN6 e6b52a2c2b update 2025-08-12 09:33:30 +05:30
Steven Liu f8ba5cd77a [docs] Cache link (#12105)
cache
2025-08-11 11:03:59 -07:00
Sayak Paul c9c8217306 [chore] complete the licensing statement. (#12001)
complete the licensing statement.
2025-08-11 22:15:15 +05:30
Aryan 135df5be9d [tests] Add inference test slices for SD3 and remove unnecessary tests (#12106)
* update

* nuke LoC for inference slices
2025-08-11 18:36:09 +05:30
Sayak Paul 4a9dbd56f6 enable compilation in qwen image. (#12061)
* update

* update

* update

* enable compilation in qwen image.

* add tests

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-11 14:37:37 +05:30
Dhruv Nair 630d27fe5b [Modular] More Updates for Custom Code Loading (#11969)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-08-11 13:26:58 +05:30
Sayak Paul f442955c6e [lora] support loading loras from lightx2v/Qwen-Image-Lightning (#12119)
* feat: support qwen lightning lora.

* add docs.

* fix
2025-08-11 09:27:10 +05:30
Sayak Paul ff9a387618 [core] add modular support for Flux I2I (#12086)
* start

* encoder.

* up

* up

* up

* up

* up

* up
2025-08-11 07:23:23 +05:30
Sayak Paul 03c3f69aa5 [docs] diffusers gguf checkpoints (#12092)
* feat: support loading diffusers format gguf checkpoints.

* update

* update

* qwen

* up

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-08-09 08:49:49 +05:30
Sayak Paul f20aba3e87 [GGUF] feat: support loading diffusers format gguf checkpoints. (#11684)
* feat: support loading diffusers format gguf checkpoints.

* update

* update

* qwen

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-08 22:27:15 +05:30
YiYi Xu ccf2c31188 [Modular] Fast Tests (#11937)
* rearrage the params to groups: default params /image params /batch params / callback params

* make style

* add names property to pipeline blocks

* style

* remove more unused func

* prepare_latents_inpaint always return noise and image_latents

* up

* up

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-08 19:42:13 +05:30
Sayak Paul 7b10e4ae65 [tests] device placement for non-denoiser components in group offloading LoRA tests (#12103)
up
2025-08-08 13:34:29 +05:30
Beinsezii 3c0531bc50 lora_conversion_utils: replace lora up/down with a/b even if transformer. in key (#12101)
lora_conversion_utils: replace lora up/down with a/b even if transformer. in key

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-08-08 11:21:47 +05:30
Sayak Paul a8e47978c6 [lora] adapt new LoRA config injection method (#11999)
* use state dict when setting up LoRA.

* up

* up

* up

* comment

* up

* up
2025-08-08 09:22:48 +05:30
YiYi Xu 50e18ee698 [qwen] device typo (#12099)
up
2025-08-07 12:27:39 -10:00
DefTruth 4b17fa2a2e fix flux type hint (#12089)
fix-flux-type-hint
2025-08-07 13:00:15 +05:30
dg845 d45199a2f1 Implement Frequency-Decoupled Guidance (FDG) as a Guider (#11976)
* Initial commit implementing frequency-decoupled guidance (FDG) as a guider

* Update FrequencyDecoupledGuidance docstring to describe FDG

* Update project so that it accepts any number of non-batch dims

* Change guidance_scale and other params to accept a list of params for each freq level

* Add comment with Laplacian pyramid shapes

* Add function to import_utils to check if the kornia package is available

* Only import from kornia if package is available

* Fix bug: use pred_cond/uncond in freq space rather than data space

* Allow guidance rescaling to be done in data space or frequency space (speculative)

* Add kornia install instructions to kornia import error message

* Add config to control whether operations are upcast to fp64

* Add parallel_weights recommended values to docstring

* Apply style fixes

* make fix-copies

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-08-07 11:21:02 +05:30
Sayak Paul 061163142d [tests] tighten compilation tests for quantization (#12002)
* tighten compilation tests for quantization

* up

* up
2025-08-07 10:13:14 +05:30
110 changed files with 2663 additions and 1332 deletions
+141
View File
@@ -0,0 +1,141 @@
name: Fast PR tests for Modular
on:
pull_request:
branches: [main]
paths:
- "src/diffusers/modular_pipelines/**.py"
- "src/diffusers/models/modeling_utils.py"
- "src/diffusers/models/model_loading_utils.py"
- "src/diffusers/pipelines/pipeline_utils.py"
- "src/diffusers/pipeline_loading_utils.py"
- "src/diffusers/loaders/lora_base.py"
- "src/diffusers/loaders/lora_pipeline.py"
- "src/diffusers/loaders/peft.py"
- "tests/modular_pipelines/**.py"
- ".github/**.yml"
- "utils/**.py"
- "setup.py"
push:
branches:
- ci-*
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
+2
View File
@@ -25,6 +25,8 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
</Tip>
Flux comes in the following variants:
+1 -1
View File
@@ -18,7 +18,7 @@
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
</Tip>
+1 -1
View File
@@ -88,7 +88,7 @@ export_to_video(video, "output.mp4", fps=24)
</hfoption>
<hfoption id="inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
```py
import torch
+58 -1
View File
@@ -20,10 +20,67 @@ Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
</Tip>
## LoRA for faster inference
Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
number of steps. Refer to the code snippet below:
<details>
<summary>Code</summary>
```py
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import torch
import math
ckpt_id = "Qwen/Qwen-Image"
# From
# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3), # We use shift=3 in distillation
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3), # We use shift=3 in distillation
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None, # set shift_terminal to None
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
pipe = DiffusionPipeline.from_pretrained(
ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights(
"lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
)
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
negative_prompt = " "
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=1024,
height=1024,
num_inference_steps=8,
true_cfg_scale=1.0,
generator=torch.manual_seed(0),
).images[0]
image.save("qwen_fewsteps.png")
```
</details>
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
+1 -1
View File
@@ -119,7 +119,7 @@ export_to_video(output, "output.mp4", fps=16)
</hfoption>
<hfoption id="T2V inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
```py
# pip install ftfy
+41
View File
@@ -77,3 +77,44 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
- Q5_K
- Q6_K
## Convert to GGUF
Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
run conversion:
<iframe
src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"
frameborder="0"
width="850"
height="450"
></iframe>
```py
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
config="black-forest-labs/FLUX.1-dev",
subfolder="transformer",
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
When using Diffusers format GGUF checkpoints, it's a must to provide the model `config` path. If the
model config resides in a `subfolder`, that needs to be specified, too.
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
+1
View File
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
+1
View File
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -13,6 +13,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import io
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import typing
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
+1 -1
View File
@@ -116,7 +116,7 @@ _deps = [
"librosa",
"numpy",
"parameterized",
"peft>=0.15.0",
"peft>=0.17.0",
"protobuf>=3.20.3,<4",
"pytest",
"pytest-timeout",
+2
View File
@@ -139,6 +139,7 @@ else:
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
+1 -1
View File
@@ -23,7 +23,7 @@ deps = {
"librosa": "librosa",
"numpy": "numpy",
"parameterized": "parameterized",
"peft": "peft>=0.15.0",
"peft": "peft>=0.17.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
+2
View File
@@ -22,6 +22,7 @@ if is_torch_available():
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@ if is_torch_available():
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
@@ -0,0 +1,327 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from ..utils import is_kornia_available
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
_CAN_USE_KORNIA = is_kornia_available()
if _CAN_USE_KORNIA:
from kornia.geometry import pyrup as upsample_and_blur_func
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
else:
upsample_and_blur_func = None
build_laplacian_pyramid_func = None
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
(Algorithm 2).
"""
# v0 shape: [B, ...]
# v1 shape: [B, ...]
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
all_dims_but_first = list(range(1, len(v0.shape)))
if upcast_to_double:
dtype = v0.dtype
v0, v1 = v0.double(), v1.double()
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
if upcast_to_double:
v0_parallel = v0_parallel.to(dtype)
v0_orthogonal = v0_orthogonal.to(dtype)
return v0_parallel, v0_orthogonal
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
"""
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
(Algorihtm 2).
"""
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
img = pyramid[-1]
for i in range(len(pyramid) - 2, -1, -1):
img = upsample_and_blur_func(img) + pyramid[i]
return img
class FrequencyDecoupledGuidance(BaseGuidance):
"""
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
how CFG works, you can check out the CFG guider.)
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
to form the final FDG prediction.
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
descending order).
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
`guidance_scales`.
parallel_weights (`float` or `List[float]`, *optional*):
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float` or `List[float]`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
should be the same length as `guidance_scales`.
stop (`float` or `List[float]`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
should be the same length as `guidance_scales`.
guidance_rescale_space (`str`, defaults to `"data"`):
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
upcast_to_double (`bool`, defaults to `True`):
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
use_original_formulation: bool = False,
start: Union[float, List[float], Tuple[float]] = 0.0,
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
"it depends is not available in the current environment. You can install `kornia` with `pip install "
"kornia`."
)
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
super().__init__(min_start, max_stop)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
if isinstance(guidance_rescale, float):
self.guidance_rescale = [guidance_rescale] * self.levels
elif len(guidance_rescale) == self.levels:
self.guidance_rescale = guidance_rescale
else:
raise ValueError(
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
f"`guidance_scales` ({len(self.guidance_scales)})"
)
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
# transforming from frequency space back to data space)
if guidance_rescale_space not in ["data", "freq"]:
raise ValueError(
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
)
self.guidance_rescale_space = guidance_rescale_space
if parallel_weights is None:
# Use normal CFG shift (equal weights for parallel and orthogonal components)
self.parallel_weights = [1.0] * self.levels
elif isinstance(parallel_weights, float):
self.parallel_weights = [parallel_weights] * self.levels
elif len(parallel_weights) == self.levels:
self.parallel_weights = parallel_weights
else:
raise ValueError(
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
f"`guidance_scales` ({len(self.guidance_scales)})"
)
self.use_original_formulation = use_original_formulation
self.upcast_to_double = upcast_to_double
if isinstance(start, float):
self.guidance_start = [start] * self.levels
elif len(start) == self.levels:
self.guidance_start = start
else:
raise ValueError(
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
f"({len(self.guidance_scales)})"
)
if isinstance(stop, float):
self.guidance_stop = [stop] * self.levels
elif len(stop) == self.levels:
self.guidance_stop = stop
else:
raise ValueError(
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
f"({len(self.guidance_scales)})"
)
def prepare_inputs(
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_fdg_enabled():
pred = pred_cond
else:
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
# From high frequencies to low frequencies, following the paper implementation
pred_guided_pyramid = []
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
if self._is_fdg_enabled_for_level(level):
# Get the cond/uncond preds (in freq space) at the current frequency level
pred_cond_freq = pred_cond_pyramid[level]
pred_uncond_freq = pred_uncond_pyramid[level]
shift = pred_cond_freq - pred_uncond_freq
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
if not math.isclose(parallel_weight, 1.0):
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
shift = parallel_weight * shift_parallel + shift_orthogonal
# Apply CFG update for the current frequency level
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
pred = pred + guidance_scale * shift
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
# Add the current FDG guided level to the FDG prediction pyramid
pred_guided_pyramid.append(pred)
else:
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
pred_guided_pyramid.append(pred_cond_freq)
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
pred = build_image_from_pyramid(pred_guided_pyramid)
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
# across all freq levels
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_fdg_enabled():
num_conditions += 1
return num_conditions
def _is_fdg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
else:
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
return is_within_range and not is_close
def _is_fdg_enabled_for_level(self, level: int) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scales[level], 0.0)
else:
is_close = math.isclose(self.guidance_scales[level], 1.0)
return is_within_range and not is_close
+41 -1
View File
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
state_dict = {
k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
for k, v in state_dict.items()
if k.startswith("transformer.")
}
return state_dict
# Another weird one.
@@ -2073,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
up_key = ".lora_up.weight"
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
+5 -1
View File
@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
@@ -6548,7 +6549,6 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -6642,6 +6642,10 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
if has_alphas_in_sd:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
+3 -1
View File
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
inject_adapter_in_model(
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if self._prepare_lora_hotswap_kwargs is not None:
+21 -9
View File
@@ -153,9 +153,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
}
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +389,23 @@ class FromOriginalModelMixin:
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
else:
diffusers_format_checkpoint = checkpoint
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -60,6 +60,7 @@ if is_accelerate_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CHECKPOINT_KEY_NAMES = {
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -13,6 +13,7 @@
# limitations under the License.
import functools
import math
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -162,7 +163,7 @@ class QwenEmbedRope(nn.Module):
self.axes_dim = axes_dim
pos_index = torch.arange(1024)
neg_index = torch.arange(1024).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
@@ -170,7 +171,7 @@ class QwenEmbedRope(nn.Module):
],
dim=1,
)
self.neg_freqs = torch.cat(
neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ class QwenEmbedRope(nn.Module):
dim=1,
)
self.rope_cache = {}
self.register_buffer("pos_freqs", pos_freqs, persistent=False)
self.register_buffer("neg_freqs", neg_freqs, persistent=False)
# 是否使用 scale rope
self.scale_rope = scale_rope
@@ -198,33 +201,17 @@ class QwenEmbedRope(nn.Module):
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
frame, height, width = video_fhw
rope_key = f"{frame}_{height}_{width}"
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs = self.rope_cache[rope_key]
if not torch.compiler.is_compiling():
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
vid_freqs = self.rope_cache[rope_key]
else:
vid_freqs = self._compute_video_freqs(frame, height, width)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
@@ -236,6 +223,25 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
class QwenDoubleStreamAttnProcessor2_0:
"""
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
_supports_gradient_checkpointing = True
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
@register_to_config
def __init__(
+8 -13
View File
@@ -7,9 +7,15 @@ from ..utils import (
get_objects_from_module,
is_torch_available,
is_transformers_available,
logging,
)
logger = logging.get_logger(__name__)
logger.warning(
"Modular Diffusers is currently an experimental feature under active development. The API is subject to breaking changes in future releases."
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
@@ -25,7 +31,6 @@ else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
@@ -59,21 +64,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LoopSequentialPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
InsertableDict,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import WanAutoBlocks, WanModularPipeline
else:
import sys
@@ -13,15 +13,16 @@
# limitations under the License.
import inspect
from typing import List, Optional, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from ...models import AutoencoderKL
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -103,6 +104,62 @@ def calculate_shift(
return mu
# Adapted from the original implementation.
def prepare_latents_img2img(
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
latent_channels = vae.config.latent_channels
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
image = image.to(device=device, dtype=dtype)
if image.shape[1] != latent_channels:
image_latents = _encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = scheduler.scale_noise(image_latents, timestep, noise)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -125,7 +182,56 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
return latent_image_ids.to(device=device, dtype=dtype)
class FluxInputStep(PipelineBlock):
# Cannot use "# Copied from" because it introduces weird indentation errors.
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
return image_latents
def _get_initial_timesteps_and_optionals(
transformer,
scheduler,
batch_size,
height,
width,
vae_scale_factor,
num_inference_steps,
guidance_scale,
sigmas,
device,
):
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
mu = calculate_shift(
image_seq_len,
scheduler.config.get("base_image_seq_len", 256),
scheduler.config.get("max_image_seq_len", 4096),
scheduler.config.get("base_shift", 0.5),
scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
if transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
else:
guidance = None
return timesteps, num_inference_steps, sigmas, guidance
class FluxInputStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -143,11 +249,6 @@ class FluxInputStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
@@ -216,7 +317,7 @@ class FluxInputStep(PipelineBlock):
return components, state
class FluxSetTimestepsStep(PipelineBlock):
class FluxSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -235,17 +336,15 @@ class FluxSetTimestepsStep(PipelineBlock):
InputParam("sigmas"),
InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"latents",
"batch_size",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
)
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
),
]
@property
@@ -264,39 +363,127 @@ class FluxSetTimestepsStep(PipelineBlock):
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
transformer = components.transformer
latents = block_state.latents
image_seq_len = latents.shape[1]
num_inference_steps = block_state.num_inference_steps
sigmas = block_state.sigmas
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
batch_size = block_state.batch_size * block_state.num_images_per_prompt
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
transformer,
scheduler,
batch_size,
block_state.height,
block_state.width,
components.vae_scale_factor,
block_state.num_inference_steps,
block_state.guidance_scale,
block_state.sigmas,
block_state.device,
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
block_state.sigmas = sigmas
mu = calculate_shift(
image_seq_len,
scheduler.config.get("base_image_seq_len", 256),
scheduler.config.get("max_image_seq_len", 4096),
scheduler.config.get("base_shift", 0.5),
scheduler.config.get("max_shift", 1.15),
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
)
if components.transformer.config.guidance_embeds:
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state
class FluxPrepareLatentsStep(PipelineBlock):
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("strength", default=0.6),
InputParam("guidance_scale", default=3.5),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam(
"latent_timestep",
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image generation",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
def get_timesteps(scheduler, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
if hasattr(scheduler, "set_begin_index"):
scheduler.set_begin_index(t_start * scheduler.order)
return timesteps, num_inference_steps - t_start
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
transformer = components.transformer
batch_size = block_state.batch_size * block_state.num_images_per_prompt
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
transformer,
scheduler,
batch_size,
block_state.height,
block_state.width,
components.vae_scale_factor,
block_state.num_inference_steps,
block_state.guidance_scale,
block_state.sigmas,
block_state.device,
)
timesteps, num_inference_steps = self.get_timesteps(
scheduler, num_inference_steps, block_state.strength, block_state.device
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
block_state.sigmas = sigmas
block_state.guidance = guidance
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
self.set_block_state(state, block_state)
return components, state
class FluxPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -305,7 +492,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
return "Prepare latents step that prepares the latents for the text-to-image generation process"
@property
def inputs(self) -> List[InputParam]:
@@ -314,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -402,10 +584,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
batch_size = block_state.batch_size * block_state.num_images_per_prompt
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
components,
block_state.batch_size * block_state.num_images_per_prompt,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
@@ -418,3 +600,90 @@ class FluxPrepareLatentsStep(PipelineBlock):
self.set_block_state(state, block_state)
return components, state
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return "Step that prepares the latents for the image-to-image generation process"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
InputParam("generator"),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
),
InputParam(
"latent_timestep",
required=True,
type_hint=torch.Tensor,
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam(
"latent_image_ids",
type_hint=torch.Tensor,
description="IDs computed from the image sequence needed for RoPE",
),
]
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
block_state.num_channels_latents = components.num_channels_latents
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
# TODO: implement `check_inputs`
batch_size = block_state.batch_size * block_state.num_images_per_prompt
if block_state.latents is None:
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
components.vae,
components.scheduler,
block_state.image_latents,
block_state.latent_timestep,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.set_block_state(state, block_state)
return components, state
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL
from ...utils import logging
from ...video_processor import VaeImageProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
return latents
class FluxDecodeStep(PipelineBlock):
class FluxDecodeStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
InputParam("output_type", default="pil"),
InputParam("height", default=1024),
InputParam("width", default=1024),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
]
@property
@@ -22,7 +22,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxLoopDenoiser(PipelineBlock):
class FluxLoopDenoiser(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [InputParam("joint_attention_kwargs")]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
return components, block_state
class FluxLoopAfterDenoiser(PipelineBlock):
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -226,5 +223,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `FluxLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
"This block supports text2image tasks."
"This block supports both text2image and img2img tasks."
)
@@ -19,9 +19,12 @@ import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -50,7 +53,110 @@ def prompt_clean(text):
return text
class FluxTextEncoderStep(PipelineBlock):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class FluxVaeEncoderStep(ModularPipelineBlocks):
model_name = "flux"
@property
def description(self) -> str:
return "Vae Encoder step that encode the input image into a latent representation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
"preprocess_kwargs",
type_hint=Optional[dict],
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="The latents representing the reference image for image-to-image/inpainting generation",
)
]
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(vae.encode(image), generator=generator)
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
return image_latents
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(
components.vae, image=block_state.image, generator=block_state.generator
)
self.set_block_state(state, block_state)
return components, state
class FluxTextEncoderStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -297,7 +403,7 @@ class FluxTextEncoderStep(PipelineBlock):
prompt_embeds=None,
pooled_prompt_embeds=None,
device=block_state.device,
num_images_per_prompt=1, # hardcoded for now.
num_images_per_prompt=1, # TODO: hardcoded for now.
lora_scale=block_state.text_encoder_lora_scale,
)
@@ -15,16 +15,38 @@
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
from .before_denoise import (
FluxImg2ImgPrepareLatentsStep,
FluxImg2ImgSetTimestepsStep,
FluxInputStep,
FluxPrepareLatentsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import FluxTextEncoderStep
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# before_denoise: text2vid
# vae encoder (run before before_denoise)
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [FluxVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if `image` is provided, step will be skipped."
)
# before_denoise: text2img, img2img
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
FluxInputStep,
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
)
# before_denoise: all task (text2vid,)
# before_denoise: img2img
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
block_names = ["input", "set_timesteps", "prepare_latents"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
)
# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxBeforeDenoiseStep]
block_names = ["text2image"]
block_trigger_inputs = [None]
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
block_names = ["text2image", "img2img"]
block_trigger_inputs = [None, "image_latents"]
@property
def description(self):
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image tasks."
" - `FluxDenoiseStep` (denoise) for text2image tasks."
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
@property
def description(self):
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
# text2image
class FluxAutoBlocks(SequentialPipelineBlocks):
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
block_classes = [
FluxTextEncoderStep,
FluxAutoVaeEncoderStep,
FluxAutoBeforeDenoiseStep,
FluxAutoDenoiseStep,
FluxAutoDecodeStep,
]
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image using Flux.\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
)
@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("input", FluxInputStep),
("prepare_latents", FluxPrepareLatentsStep),
# Setting it after preparation of latents because we rely on `latents`
# to calculate `img_seq_len` for `shift`.
("set_timesteps", FluxSetTimestepsStep),
("prepare_latents", FluxPrepareLatentsStep),
("denoise", FluxDenoiseStep),
("decode", FluxDecodeStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("image_encoder", FluxVaeEncoderStep),
("input", FluxInputStep),
("set_timesteps", FluxImg2ImgSetTimestepsStep),
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
("denoise", FluxDenoiseStep),
("decode", FluxDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep),
("image_encoder", FluxAutoVaeEncoderStep),
("before_denoise", FluxAutoBeforeDenoiseStep),
("denoise", FluxAutoDenoiseStep),
("decode", FluxAutoDecodeStep),
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
)
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
@@ -13,7 +13,7 @@
# limitations under the License.
from ...loaders import FluxLoraLoaderMixin
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
"""
A ModularPipeline for Flux.
File diff suppressed because it is too large Load Diff
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
def make_doc_string(
inputs,
intermediate_inputs,
outputs,
description="",
class_name=None,
@@ -664,7 +663,7 @@ def make_doc_string(
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
output += format_input_params(inputs, indent_level=2)
# Add outputs section
output += "\n\n"
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import (
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
return latents
class StableDiffusionXLInputStep(PipelineBlock):
class StableDiffusionXLInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"prompt_embeds",
required=True,
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
return components, state
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
InputParam("denoising_start"),
# YiYi TODO: do we need num_images_per_prompt here?
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"batch_size",
required=True,
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
return components, state
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
return components, state
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=None,
is_strength_max=True,
add_noise=True,
return_noise=False,
return_image_latents=False,
):
shape = (
batch_size,
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
elif latents is None and not is_strength_max:
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
outputs = (latents, noise, image_latents)
return outputs
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
components,
block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents,
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise,
return_noise=True,
return_image_latents=False,
)
# 7. Prepare mask latent variables
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
return components, state
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
InputParam("latents"),
InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"latent_timestep",
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
return components, state
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
InputParam("width"),
InputParam("latents"),
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
return components, state
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1),
InputParam("aesthetic_score", default=6.0),
InputParam("negative_aesthetic_score", default=2.0),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
return components, state
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("crops_coords_top_left", default=(0, 0)),
InputParam("negative_crops_coords_top_left", default=(0, 0)),
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
return components, state
class StableDiffusionXLControlNetInputStep(PipelineBlock):
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
return components, state
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1747,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ..modular_pipeline import (
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
]
@property
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
InputParam("image"),
InputParam("mask_image"),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
@@ -25,7 +25,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
)
@property
def intermediate_inputs(self) -> List[str]:
def inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
# loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
)
@property
def intermediate_inputs(self) -> List[str]:
def inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
# loop step (2): denoise the latents with guidance
class StableDiffusionXLLoopDenoiser(PipelineBlock):
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"num_inference_steps",
required=True,
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
# loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"controlnet_cond",
required=True,
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents (with inpainting)
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
InputParam(
"timesteps",
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -35,7 +35,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
@@ -57,7 +57,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionXLIPAdapterStep(PipelineBlock):
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
return components, state
class StableDiffusionXLTextEncoderStep(PipelineBlock):
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
return components, state
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(
image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
image = image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
)
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
self.set_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(
image = components.image_processor.preprocess(
block_state.image,
height=block_state.height,
width=block_state.width,
crops_coords=block_state.crops_coords,
resize_mode=block_state.resize_mode,
)
block_state.image = block_state.image.to(dtype=torch.float32)
image = image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(
mask = components.mask_processor.preprocess(
block_state.mask_image,
height=block_state.height,
width=block_state.width,
resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords,
)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
block_state.masked_image = image * (mask < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(
components, image=block_state.image, generator=block_state.generator
)
block_state.batch_size = image.shape[0]
image = image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
@@ -20,7 +20,7 @@ import torch
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -94,7 +94,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class WanInputStep(PipelineBlock):
class WanInputStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
return components, state
class WanSetTimestepsStep(PipelineBlock):
class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
return components, state
class WanPrepareLatentsStep(PipelineBlock):
class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLWan
from ...utils import logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanDecodeStep(PipelineBlock):
class WanDecodeStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -24,7 +24,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
PipelineBlock,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanLoopDenoiser(PipelineBlock):
class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
return components, block_state
class WanLoopAfterDenoiser(PipelineBlock):
class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -51,7 +51,7 @@ def prompt_clean(text):
return text
class WanTextEncoderStep(PipelineBlock):
class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -201,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
+1
View File
@@ -82,6 +82,7 @@ from .import_utils import (
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
+15
View File
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FrequencyDecoupledGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PerturbedAttentionGuidance(metaclass=DummyObject):
_backends = ["torch"]
+5
View File
@@ -224,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_kornia_available, _kornia_version = _is_package_available("kornia")
def is_torch_available():
@@ -398,6 +399,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available
def is_kornia_available():
return _kornia_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
-38
View File
@@ -197,20 +197,6 @@ def get_peft_kwargs(
"lora_bias": lora_bias,
}
# Example: try load FusionX LoRA into Wan VACE
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
if exclude_modules:
if not is_peft_version(">=", "0.14.0"):
msg = """
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
https://github.com/huggingface/diffusers/issues/new
"""
logger.debug(msg)
else:
lora_config_kwargs.update({"exclude_modules": exclude_modules})
return lora_config_kwargs
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
"""
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
doesn't exist in `peft_state_dict`.
"""
if model_state_dict is None:
return
all_modules = set()
string_to_replace = f"{adapter_name}." if adapter_name else ""
for name in model_state_dict.keys():
if string_to_replace:
name = name.replace(string_to_replace, "")
if "." in name:
module_name = name.rsplit(".", 1)[0]
all_modules.add(module_name)
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
exclude_modules = list(all_modules - target_modules_set)
return exclude_modules
+1
View File
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
+1
View File
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
+4 -68
View File
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import re
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
def _get_exclude_modules(self, pipe):
from diffusers.utils.peft_utils import _derive_exclude_modules
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
pipe.unload_lora_weights()
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
exclude_modules = _derive_exclude_modules(
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
)
return exclude_modules
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules(self):
"""
Test to check if `exclude_modules` works or not. It works in the following way:
we first create a pipeline and insert LoRA config into it. We then derive a `set`
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
state dict.
We then create a new LoRA config to include the `exclude_modules` and perform tests.
"""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
# only supported for `denoiser` now
pipe_cp = copy.deepcopy(pipe)
pipe_cp, _ = self.add_adapters_to_pipeline(
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
pipe_cp.to("cpu")
del pipe_cp
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
"LoRA should change outputs.",
)
self.assertTrue(
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.",
)
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
@@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests:
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
num_blocks_per_group=1,
use_stream=use_stream,
)
# Place other model-level components on `torch_device`.
for _, component in pipe.components.items():
if isinstance(component, torch.nn.Module):
component.to(torch_device)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+5
View File
@@ -1711,6 +1711,11 @@ class ModelTesterMixin:
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
if self.model_class.__name__ == "QwenImageTransformer2DModel":
pytest.skip(
"QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
)
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
@@ -20,7 +20,7 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
@@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16)
@property
def output_shape(self):
return (16, 16)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 7
vae_scale_factor = 4
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
"txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
View File

Some files were not shown because too many files have changed in this diff Show More