Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e3f6ba3fc1 | |||
| f442955c6e | |||
| ff9a387618 | |||
| fb8722e9ab | |||
| 512044c5ea | |||
| 03c3f69aa5 | |||
| f20aba3e87 | |||
| 6c85fcd899 | |||
| 085e9cba36 | |||
| 919ee1aee3 | |||
| 9cda45701c | |||
| c678e8a445 | |||
| ccf2c31188 | |||
| d1342d7464 | |||
| 7b10e4ae65 | |||
| 3c0531bc50 | |||
| a8e47978c6 | |||
| 50e18ee698 | |||
| 4b17fa2a2e | |||
| d45199a2f1 | |||
| 061163142d | |||
| 9a0cc463ee | |||
| ef4e373a65 | |||
| 1b4af6b7ef | |||
| ea77fdc4b4 | |||
| 255c5742aa | |||
| 4524d43279 | |||
| b6dc0b75f4 | |||
| 966a2ff8df | |||
| 201da97dd0 | |||
| 4423097b23 | |||
| 60d1b81023 |
@@ -0,0 +1,141 @@
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "src/diffusers/modular_pipelines/**.py"
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/modular_pipelines/**.py"
|
||||
- ".github/**.yml"
|
||||
- "utils/**.py"
|
||||
- "setup.py"
|
||||
push:
|
||||
branches:
|
||||
- ci-*
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -24,6 +24,63 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
|
||||
</Tip>
|
||||
|
||||
## LoRA for faster inference
|
||||
|
||||
Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
|
||||
number of steps. Refer to the code snippet below:
|
||||
|
||||
<details>
|
||||
<summary>Code</summary>
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
|
||||
import torch
|
||||
import math
|
||||
|
||||
ckpt_id = "Qwen/Qwen-Image"
|
||||
|
||||
# From
|
||||
# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
|
||||
scheduler_config = {
|
||||
"base_image_seq_len": 256,
|
||||
"base_shift": math.log(3), # We use shift=3 in distillation
|
||||
"invert_sigmas": False,
|
||||
"max_image_seq_len": 8192,
|
||||
"max_shift": math.log(3), # We use shift=3 in distillation
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 1.0,
|
||||
"shift_terminal": None, # set shift_terminal to None
|
||||
"stochastic_sampling": False,
|
||||
"time_shift_type": "exponential",
|
||||
"use_beta_sigmas": False,
|
||||
"use_dynamic_shifting": True,
|
||||
"use_exponential_sigmas": False,
|
||||
"use_karras_sigmas": False,
|
||||
}
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
pipe.load_lora_weights(
|
||||
"lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
|
||||
negative_prompt = " "
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_inference_steps=8,
|
||||
true_cfg_scale=1.0,
|
||||
generator=torch.manual_seed(0),
|
||||
).images[0]
|
||||
image.save("qwen_fewsteps.png")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
|
||||
@@ -77,3 +77,44 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
|
||||
- Q5_K
|
||||
- Q6_K
|
||||
|
||||
## Convert to GGUF
|
||||
|
||||
Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
|
||||
run conversion:
|
||||
|
||||
<iframe
|
||||
src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
config="black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
When using Diffusers format GGUF checkpoints, it's a must to provide the model `config` path. If the
|
||||
model config resides in a `subfolder`, that needs to be specified, too.
|
||||
@@ -116,7 +116,7 @@ _deps = [
|
||||
"librosa",
|
||||
"numpy",
|
||||
"parameterized",
|
||||
"peft>=0.15.0",
|
||||
"peft>=0.17.0",
|
||||
"protobuf>=3.20.3,<4",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
|
||||
@@ -139,6 +139,7 @@ else:
|
||||
"AutoGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"FrequencyDecoupledGuidance",
|
||||
"PerturbedAttentionGuidance",
|
||||
"SkipLayerGuidance",
|
||||
"SmoothedEnergyGuidance",
|
||||
@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
|
||||
@@ -23,7 +23,7 @@ deps = {
|
||||
"librosa": "librosa",
|
||||
"numpy": "numpy",
|
||||
"parameterized": "parameterized",
|
||||
"peft": "peft>=0.15.0",
|
||||
"peft": "peft>=0.17.0",
|
||||
"protobuf": "protobuf>=3.20.3,<4",
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
|
||||
@@ -22,6 +22,7 @@ if is_torch_available():
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
@@ -32,6 +33,7 @@ if is_torch_available():
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..utils import is_kornia_available
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
_CAN_USE_KORNIA = is_kornia_available()
|
||||
|
||||
|
||||
if _CAN_USE_KORNIA:
|
||||
from kornia.geometry import pyrup as upsample_and_blur_func
|
||||
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
|
||||
else:
|
||||
upsample_and_blur_func = None
|
||||
build_laplacian_pyramid_func = None
|
||||
|
||||
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
|
||||
(Algorithm 2).
|
||||
"""
|
||||
# v0 shape: [B, ...]
|
||||
# v1 shape: [B, ...]
|
||||
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
|
||||
all_dims_but_first = list(range(1, len(v0.shape)))
|
||||
if upcast_to_double:
|
||||
dtype = v0.dtype
|
||||
v0, v1 = v0.double(), v1.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
|
||||
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
if upcast_to_double:
|
||||
v0_parallel = v0_parallel.to(dtype)
|
||||
v0_orthogonal = v0_orthogonal.to(dtype)
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
|
||||
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
|
||||
(Algorihtm 2).
|
||||
"""
|
||||
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
|
||||
img = pyramid[-1]
|
||||
for i in range(len(pyramid) - 2, -1, -1):
|
||||
img = upsample_and_blur_func(img) + pyramid[i]
|
||||
return img
|
||||
|
||||
|
||||
class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
"""
|
||||
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
|
||||
|
||||
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
|
||||
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
|
||||
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
|
||||
how CFG works, you can check out the CFG guider.)
|
||||
|
||||
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
|
||||
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
|
||||
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
|
||||
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
|
||||
to form the final FDG prediction.
|
||||
|
||||
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
|
||||
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
|
||||
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
|
||||
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
|
||||
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
|
||||
|
||||
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
|
||||
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
|
||||
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
|
||||
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
|
||||
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
|
||||
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
|
||||
descending order).
|
||||
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
|
||||
`guidance_scales`.
|
||||
parallel_weights (`float` or `List[float]`, *optional*):
|
||||
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
|
||||
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
|
||||
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
|
||||
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float` or `List[float]`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
stop (`float` or `List[float]`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
guidance_rescale_space (`str`, defaults to `"data"`):
|
||||
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
|
||||
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
|
||||
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
|
||||
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
|
||||
upcast_to_double (`bool`, defaults to `True`):
|
||||
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
|
||||
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
|
||||
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
|
||||
use_original_formulation: bool = False,
|
||||
start: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
):
|
||||
if not _CAN_USE_KORNIA:
|
||||
raise ImportError(
|
||||
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
|
||||
"it depends is not available in the current environment. You can install `kornia` with `pip install "
|
||||
"kornia`."
|
||||
)
|
||||
|
||||
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
||||
min_start = start if isinstance(start, float) else min(start)
|
||||
max_stop = stop if isinstance(stop, float) else max(stop)
|
||||
super().__init__(min_start, max_stop)
|
||||
|
||||
self.guidance_scales = guidance_scales
|
||||
self.levels = len(guidance_scales)
|
||||
|
||||
if isinstance(guidance_rescale, float):
|
||||
self.guidance_rescale = [guidance_rescale] * self.levels
|
||||
elif len(guidance_rescale) == self.levels:
|
||||
self.guidance_rescale = guidance_rescale
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
|
||||
f"`guidance_scales` ({len(self.guidance_scales)})"
|
||||
)
|
||||
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
|
||||
# transforming from frequency space back to data space)
|
||||
if guidance_rescale_space not in ["data", "freq"]:
|
||||
raise ValueError(
|
||||
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
|
||||
)
|
||||
self.guidance_rescale_space = guidance_rescale_space
|
||||
|
||||
if parallel_weights is None:
|
||||
# Use normal CFG shift (equal weights for parallel and orthogonal components)
|
||||
self.parallel_weights = [1.0] * self.levels
|
||||
elif isinstance(parallel_weights, float):
|
||||
self.parallel_weights = [parallel_weights] * self.levels
|
||||
elif len(parallel_weights) == self.levels:
|
||||
self.parallel_weights = parallel_weights
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
|
||||
f"`guidance_scales` ({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.upcast_to_double = upcast_to_double
|
||||
|
||||
if isinstance(start, float):
|
||||
self.guidance_start = [start] * self.levels
|
||||
elif len(start) == self.levels:
|
||||
self.guidance_start = start
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
if isinstance(stop, float):
|
||||
self.guidance_stop = [stop] * self.levels
|
||||
elif len(stop) == self.levels:
|
||||
self.guidance_stop = stop
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_fdg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
|
||||
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
|
||||
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
|
||||
|
||||
# From high frequencies to low frequencies, following the paper implementation
|
||||
pred_guided_pyramid = []
|
||||
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
|
||||
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
|
||||
if self._is_fdg_enabled_for_level(level):
|
||||
# Get the cond/uncond preds (in freq space) at the current frequency level
|
||||
pred_cond_freq = pred_cond_pyramid[level]
|
||||
pred_uncond_freq = pred_uncond_pyramid[level]
|
||||
|
||||
shift = pred_cond_freq - pred_uncond_freq
|
||||
|
||||
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
|
||||
if not math.isclose(parallel_weight, 1.0):
|
||||
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
|
||||
shift = parallel_weight * shift_parallel + shift_orthogonal
|
||||
|
||||
# Apply CFG update for the current frequency level
|
||||
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
|
||||
pred = pred + guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
|
||||
|
||||
# Add the current FDG guided level to the FDG prediction pyramid
|
||||
pred_guided_pyramid.append(pred)
|
||||
else:
|
||||
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
|
||||
pred_guided_pyramid.append(pred_cond_freq)
|
||||
|
||||
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
|
||||
pred = build_image_from_pyramid(pred_guided_pyramid)
|
||||
|
||||
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
|
||||
# across all freq levels
|
||||
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_fdg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_fdg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
|
||||
else:
|
||||
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_fdg_enabled_for_level(self, level: int) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
|
||||
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scales[level], 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scales[level], 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
# has both `peft` and non-peft state dict.
|
||||
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
|
||||
if has_peft_state_dict:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
|
||||
state_dict = {
|
||||
k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith("transformer.")
|
||||
}
|
||||
return state_dict
|
||||
|
||||
# Another weird one.
|
||||
@@ -2073,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
up_key = ".lora_up.weight"
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for k in all_keys:
|
||||
if k.endswith(down_key):
|
||||
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
|
||||
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
|
||||
alpha_key = k.replace(down_key, ".alpha")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(k.replace(down_key, up_key))
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[diffusers_down_key] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_qwen_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
@@ -6548,7 +6549,6 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -6642,6 +6642,10 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
if has_alphas_in_sd:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
|
||||
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
|
||||
# it to None
|
||||
incompatible_keys = None
|
||||
else:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
inject_adapter_in_model(
|
||||
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
|
||||
)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
if self._prepare_lora_hotswap_kwargs is not None:
|
||||
|
||||
@@ -153,9 +153,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"QwenImageTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
||||
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
@@ -381,19 +389,23 @@ class FromOriginalModelMixin:
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
|
||||
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
diffusers_format_checkpoint = checkpoint
|
||||
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
|
||||
@@ -60,6 +60,7 @@ if is_accelerate_available():
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CHECKPOINT_KEY_NAMES = {
|
||||
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
|
||||
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
|
||||
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ else:
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"ModularPipelineBlocks",
|
||||
"ModularPipeline",
|
||||
"PipelineBlock",
|
||||
"AutoPipelineBlocks",
|
||||
"SequentialPipelineBlocks",
|
||||
"LoopSequentialPipelineBlocks",
|
||||
@@ -59,7 +58,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -103,6 +104,62 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Adapted from the original implementation.
|
||||
def prepare_latents_img2img(
|
||||
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
latent_channels = vae.config.latent_channels
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != latent_channels:
|
||||
image_latents = _encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = scheduler.scale_noise(image_latents, timestep, noise)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
@@ -125,7 +182,56 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FluxInputStep(PipelineBlock):
|
||||
# Cannot use "# Copied from" because it introduces weird indentation errors.
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
def _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
vae_scale_factor,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
sigmas,
|
||||
device,
|
||||
):
|
||||
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
return timesteps, num_inference_steps, sigmas, guidance
|
||||
|
||||
|
||||
class FluxInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -143,11 +249,6 @@ class FluxInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -216,7 +317,7 @@ class FluxInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxSetTimestepsStep(PipelineBlock):
|
||||
class FluxSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -235,17 +336,15 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"latents",
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
)
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -264,39 +363,127 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
|
||||
latents = block_state.latents
|
||||
image_seq_len = latents.shape[1]
|
||||
|
||||
num_inference_steps = block_state.num_inference_steps
|
||||
sigmas = block_state.sigmas
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
|
||||
)
|
||||
if components.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
block_state.guidance = guidance
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(PipelineBlock):
|
||||
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for inference"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("strength", default=0.6),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam(
|
||||
"latent_timestep",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image generation",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
|
||||
def get_timesteps(scheduler, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
if hasattr(scheduler, "set_begin_index"):
|
||||
scheduler.set_begin_index(t_start * scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
scheduler, num_inference_steps, block_state.strength, block_state.device
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
block_state.guidance = guidance
|
||||
|
||||
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -305,7 +492,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
return "Prepare latents step that prepares the latents for the text-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
@@ -314,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -402,10 +584,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
@@ -418,3 +600,90 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the latents for the image-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
|
||||
),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam(
|
||||
"latent_image_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from the image sequence needed for RoPE",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
block_state.device = components._execution_device
|
||||
|
||||
# TODO: implement `check_inputs`
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
if block_state.latents is None:
|
||||
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
|
||||
components.vae,
|
||||
components.scheduler,
|
||||
block_state.image_latents,
|
||||
block_state.latent_timestep,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import logging
|
||||
from ...video_processor import VaeImageProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
return latents
|
||||
|
||||
|
||||
class FluxDecodeStep(PipelineBlock):
|
||||
class FluxDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam("height", default=1024),
|
||||
InputParam("width", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxLoopDenoiser(PipelineBlock):
|
||||
class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [InputParam("joint_attention_kwargs")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class FluxLoopAfterDenoiser(PipelineBlock):
|
||||
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -226,5 +223,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `FluxLoopDenoiser`\n"
|
||||
" - `FluxLoopAfterDenoiser`\n"
|
||||
"This block supports text2image tasks."
|
||||
"This block supports both text2image and img2img tasks."
|
||||
)
|
||||
|
||||
@@ -19,9 +19,12 @@ import regex as re
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -50,7 +53,110 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class FluxTextEncoderStep(PipelineBlock):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class FluxVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Encoder step that encode the input image into a latent representation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
"preprocess_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components.vae, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -297,7 +403,7 @@ class FluxTextEncoderStep(PipelineBlock):
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
device=block_state.device,
|
||||
num_images_per_prompt=1, # hardcoded for now.
|
||||
num_images_per_prompt=1, # TODO: hardcoded for now.
|
||||
lora_scale=block_state.text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,16 +15,38 @@
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxInputStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep
|
||||
from .encoders import FluxTextEncoderStep
|
||||
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# before_denoise: text2vid
|
||||
# vae encoder (run before before_denoise)
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img, img2img
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
FluxInputStep,
|
||||
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2vid,)
|
||||
# before_denoise: img2img
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxBeforeDenoiseStep]
|
||||
block_names = ["text2image"]
|
||||
block_trigger_inputs = [None]
|
||||
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
|
||||
block_names = ["text2image", "img2img"]
|
||||
block_trigger_inputs = [None, "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image tasks."
|
||||
"This is a auto pipeline block that works for text2image and img2img tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
|
||||
)
|
||||
|
||||
|
||||
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
|
||||
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
|
||||
|
||||
|
||||
# text2image
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
|
||||
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
|
||||
block_classes = [
|
||||
FluxTextEncoderStep,
|
||||
FluxAutoVaeEncoderStep,
|
||||
FluxAutoBeforeDenoiseStep,
|
||||
FluxAutoDenoiseStep,
|
||||
FluxAutoDecodeStep,
|
||||
]
|
||||
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
|
||||
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
|
||||
)
|
||||
|
||||
|
||||
@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
# Setting it after preparation of latents because we rely on `latents`
|
||||
# to calculate `img_seq_len` for `shift`.
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxVaeEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxAutoVaeEncoderStep),
|
||||
("before_denoise", FluxAutoBeforeDenoiseStep),
|
||||
("denoise", FluxAutoDenoiseStep),
|
||||
("decode", FluxAutoDecodeStep),
|
||||
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for Flux.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
intermediate_inputs,
|
||||
outputs,
|
||||
description="",
|
||||
class_name=None,
|
||||
@@ -664,7 +663,7 @@ def make_doc_string(
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
|
||||
output += format_input_params(inputs, indent_level=2)
|
||||
|
||||
# Add outputs section
|
||||
output += "\n\n"
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("denoising_start"),
|
||||
# YiYi TODO: do we need num_images_per_prompt here?
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
|
||||
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
add_noise=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
elif latents is None and not is_strength_max:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(components, image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
outputs = (latents, noise, image_latents)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
|
||||
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
|
||||
|
||||
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
|
||||
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
components.num_channels_latents,
|
||||
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=block_state.latent_timestep,
|
||||
is_strength_max=block_state.is_strength_max,
|
||||
add_noise=block_state.add_noise,
|
||||
return_noise=True,
|
||||
return_image_latents=False,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width"),
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("aesthetic_score", default=6.0),
|
||||
InputParam("negative_aesthetic_score", default=2.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("negative_crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1747,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
InputParam("image"),
|
||||
InputParam("mask_image"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
required=True,
|
||||
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.batch_size = image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
block_state.crops_coords = None
|
||||
block_state.resize_mode = "default"
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
crops_coords=block_state.crops_coords,
|
||||
resize_mode=block_state.resize_mode,
|
||||
)
|
||||
block_state.image = block_state.image.to(dtype=torch.float32)
|
||||
image = image.to(dtype=torch.float32)
|
||||
|
||||
block_state.mask = components.mask_processor.preprocess(
|
||||
mask = components.mask_processor.preprocess(
|
||||
block_state.mask_image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
resize_mode=block_state.resize_mode,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
|
||||
block_state.masked_image = image * (mask < 0.5)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.batch_size = image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
|
||||
components,
|
||||
block_state.mask,
|
||||
mask,
|
||||
block_state.masked_image,
|
||||
block_state.batch_size,
|
||||
block_state.height,
|
||||
|
||||
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -94,7 +94,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanInputStep(PipelineBlock):
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(PipelineBlock):
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareLatentsStep(PipelineBlock):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanDecodeStep(PipelineBlock):
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopDenoiser(PipelineBlock):
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(PipelineBlock):
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -51,7 +51,7 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class WanTextEncoderStep(PipelineBlock):
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -201,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = self.tokenizer(
|
||||
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
).to(device)
|
||||
encoder_hidden_states = self.text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
|
||||
@@ -82,6 +82,7 @@ from .import_utils import (
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_kernels_available,
|
||||
is_kornia_available,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
|
||||
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FrequencyDecoupledGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
|
||||
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
|
||||
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
|
||||
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -398,6 +399,10 @@ def is_flash_attn_3_available():
|
||||
return _flash_attn_3_available
|
||||
|
||||
|
||||
def is_kornia_available():
|
||||
return _kornia_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -197,20 +197,6 @@ def get_peft_kwargs(
|
||||
"lora_bias": lora_bias,
|
||||
}
|
||||
|
||||
# Example: try load FusionX LoRA into Wan VACE
|
||||
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
|
||||
if exclude_modules:
|
||||
if not is_peft_version(">=", "0.14.0"):
|
||||
msg = """
|
||||
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
|
||||
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
|
||||
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
|
||||
https://github.com/huggingface/diffusers/issues/new
|
||||
"""
|
||||
logger.debug(msg)
|
||||
else:
|
||||
lora_config_kwargs.update({"exclude_modules": exclude_modules})
|
||||
|
||||
return lora_config_kwargs
|
||||
|
||||
|
||||
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
|
||||
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
|
||||
"""
|
||||
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
|
||||
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
|
||||
doesn't exist in `peft_state_dict`.
|
||||
"""
|
||||
if model_state_dict is None:
|
||||
return
|
||||
all_modules = set()
|
||||
string_to_replace = f"{adapter_name}." if adapter_name else ""
|
||||
|
||||
for name in model_state_dict.keys():
|
||||
if string_to_replace:
|
||||
name = name.replace(string_to_replace, "")
|
||||
if "." in name:
|
||||
module_name = name.rsplit(".", 1)[0]
|
||||
all_modules.add(module_name)
|
||||
|
||||
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
|
||||
exclude_modules = list(all_modules - target_modules_set)
|
||||
|
||||
return exclude_modules
|
||||
|
||||
+4
-68
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def _get_exclude_modules(self, pipe):
|
||||
from diffusers.utils.peft_utils import _derive_exclude_modules
|
||||
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
|
||||
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
|
||||
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
|
||||
pipe.unload_lora_weights()
|
||||
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
|
||||
exclude_modules = _derive_exclude_modules(
|
||||
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
|
||||
)
|
||||
return exclude_modules
|
||||
|
||||
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules(self):
|
||||
"""
|
||||
Test to check if `exclude_modules` works or not. It works in the following way:
|
||||
we first create a pipeline and insert LoRA config into it. We then derive a `set`
|
||||
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
|
||||
state dict.
|
||||
|
||||
We then create a new LoRA config to include the `exclude_modules` and perform tests.
|
||||
"""
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# only supported for `denoiser` now
|
||||
pipe_cp = copy.deepcopy(pipe)
|
||||
pipe_cp, _ = self.add_adapters_to_pipeline(
|
||||
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
|
||||
pipe_cp.to("cpu")
|
||||
del pipe_cp
|
||||
|
||||
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should change outputs.",
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Lora outputs should match.",
|
||||
)
|
||||
|
||||
def test_inference_load_delete_load_adapters(self):
|
||||
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
|
||||
@@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
|
||||
num_blocks_per_group=1,
|
||||
use_stream=use_stream,
|
||||
)
|
||||
# Place other model-level components on `torch_device`.
|
||||
for _, component in pipe.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to(torch_device)
|
||||
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
|
||||
self.assertTrue(group_offload_hook_1 is not None)
|
||||
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# The test exists for cases like
|
||||
# https://github.com/huggingface/diffusers/issues/11874
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_exclude_modules(self):
|
||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
lora_rank = 4
|
||||
target_module = "single_transformer_blocks.0.proj_out"
|
||||
adapter_name = "foo"
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
||||
lora_state_dict = {
|
||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
||||
}
|
||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
||||
config = LoraConfig(
|
||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
||||
)
|
||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
||||
|
||||
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
|
||||
+462
@@ -0,0 +1,462 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SDXLModularTests:
|
||||
"""
|
||||
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
|
||||
"""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
"mask_image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_default_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
sd_pipe = self.get_pipeline()
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs, output="images")
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == expected_image_shape
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
|
||||
"Image Slice does not match expected slice"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularIPAdapterTests:
|
||||
"""
|
||||
This mixin is designed to test IP Adapter.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs_and_blocks(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
|
||||
assert "ip_adapter_image" in parameters, (
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
|
||||
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
assert "ip_adapter_image" not in parameters, (
|
||||
"`ip_adapter_image` argument must be removed from the `__call__` method"
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
return _masks
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
if "image" in parameters and "strength" in parameters:
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "np"
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for IP-Adapter.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs, output="images")
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
# 1. Single IP-Adapter test cases
|
||||
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
assert max_diff_without_adapter_scale < expected_max_diff, (
|
||||
"Output without ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
|
||||
# 2. Multi IP-Adapter test cases
|
||||
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
|
||||
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
|
||||
"Output without multi-ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_multi_adapter_scale > 1e-2, (
|
||||
"Output with multi-ip-adapter scale must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularControlNetTests:
|
||||
"""
|
||||
This mixin is designed to test ControlNet.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
|
||||
assert "controlnet_conditioning_scale" in parameters, (
|
||||
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = torch.randn(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
device=torch_device,
|
||||
)
|
||||
inputs["control_image"] = image
|
||||
return inputs
|
||||
|
||||
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for ControlNet.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single ControlNet with scale=0 should produce same output as no ControlNet.
|
||||
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass without controlnet
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_controlnet = pipe(**inputs, output="images")
|
||||
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 0.0
|
||||
output_without_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 42.0
|
||||
output_with_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
|
||||
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
|
||||
|
||||
assert max_diff_without_controlnet_scale < expected_max_diff, (
|
||||
"Output without controlnet must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
|
||||
|
||||
def test_controlnet_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularGuiderTests:
|
||||
def test_guider_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.5966781,
|
||||
0.62939394,
|
||||
0.48465094,
|
||||
0.51573336,
|
||||
0.57593524,
|
||||
0.47035995,
|
||||
0.53410417,
|
||||
0.51436996,
|
||||
0.47313565,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLImg2ImgModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
inputs["image"] = image
|
||||
inputs["strength"] = 0.8
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.56943184,
|
||||
0.4702148,
|
||||
0.48048905,
|
||||
0.6235963,
|
||||
0.551138,
|
||||
0.49629188,
|
||||
0.60031277,
|
||||
0.5688907,
|
||||
0.43996853,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLInpaintingModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
# create mask
|
||||
image[8:, 8:, :] = 255
|
||||
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
|
||||
|
||||
inputs["image"] = init_image
|
||||
inputs["mask_image"] = mask_image
|
||||
inputs["strength"] = 1.0
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.40872607,
|
||||
0.38842705,
|
||||
0.34893104,
|
||||
0.47837183,
|
||||
0.43792963,
|
||||
0.5332134,
|
||||
0.3716843,
|
||||
0.47274873,
|
||||
0.45000193,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
@@ -0,0 +1,358 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
require_torch,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
including:
|
||||
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
|
||||
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
|
||||
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
|
||||
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
|
||||
- test_to_device: check if the pipeline's __call__ method can handle different devices
|
||||
"""
|
||||
|
||||
# Canonical parameters that are passed to `__call__` regardless
|
||||
# of the type of pipeline. They are always optional and have common
|
||||
# sense default values.
|
||||
optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_images_per_prompt",
|
||||
"latents",
|
||||
"output_type",
|
||||
]
|
||||
)
|
||||
# this is modular specific: generator needs to be a intermediate input because it's mutable
|
||||
intermediate_params = frozenset(
|
||||
[
|
||||
"generator",
|
||||
]
|
||||
)
|
||||
|
||||
def get_generator(self, seed):
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
return generator
|
||||
|
||||
@property
|
||||
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def repo(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_pipeline(self):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_pipeline(self)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `params` in the child test class. "
|
||||
"`params` are checked for if all values are present in `__call__`'s signature."
|
||||
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
|
||||
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
|
||||
"image pipelines, including prompts and prompt embedding overrides."
|
||||
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
|
||||
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
|
||||
"with non-configurable height and width arguments should set the attribute as "
|
||||
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `batch_params` in the child test class. "
|
||||
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
|
||||
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
|
||||
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
|
||||
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
|
||||
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
|
||||
"image pipeline `negative_prompt` is not batched should set the attribute as "
|
||||
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
assert len(remaining_parameters) == 0, (
|
||||
f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
)
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# prepare batched inputs
|
||||
batched_inputs = []
|
||||
for batch_size in batch_sizes:
|
||||
batched_input = {}
|
||||
batched_input.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_input[name] = batch_size * [value]
|
||||
|
||||
if batch_generator and "generator" in inputs:
|
||||
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_input["batch_size"] = batch_size
|
||||
|
||||
batched_inputs.append(batched_input)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
|
||||
output = pipe(**batched_input, output="images")
|
||||
assert len(output) == batch_size, "Output is different from expected batch size"
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
expected_max_diff=1e-4,
|
||||
):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
output = pipe(**inputs, output="images")
|
||||
output_batch = pipe(**batched_inputs, output="images")
|
||||
|
||||
assert output_batch.shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
|
||||
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe_fp16 = self.get_pipeline()
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
output = pipe(**inputs, output="images")
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
|
||||
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.cpu()
|
||||
output_fp16 = output_fp16.cpu()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
|
||||
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
|
||||
|
||||
@require_accelerator
|
||||
def test_to_device(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == torch_device for device in model_devices), (
|
||||
"All pipeline components are not on accelerator device"
|
||||
)
|
||||
|
||||
def test_inference_is_not_nan_cpu(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to("cpu")
|
||||
|
||||
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
|
||||
|
||||
@require_accelerator
|
||||
def test_inference_is_not_nan(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
pipe = self.get_pipeline()
|
||||
|
||||
if "num_images_per_prompt" not in pipe.blocks.input_names:
|
||||
return
|
||||
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
@require_accelerator
|
||||
def test_components_auto_cpu_offload_inference_consistent(self):
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
|
||||
cm = ComponentsManager()
|
||||
cm.enable_auto_cpu_offload(device=torch_device)
|
||||
offload_pipe = self.get_pipeline(components_manager=cm)
|
||||
|
||||
image_slices = []
|
||||
for pipe in [base_pipe, offload_pipe]:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"image",
|
||||
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# Text guided image variation with an image mask
|
||||
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# image variation with an image mask
|
||||
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
"example_image",
|
||||
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
# image params
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
|
||||
# batch params
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
|
||||
|
||||
# callback params
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
@require_bitsandbytes_version_greater("0.46.1")
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
super().test_torch_compile()
|
||||
|
||||
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
|
||||
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
|
||||
)
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
super()._test_torch_compile(torch_dtype=torch.float16)
|
||||
|
||||
@@ -212,6 +212,7 @@ class GGUFSingleFileTesterMixin:
|
||||
|
||||
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
|
||||
torch_dtype = torch.bfloat16
|
||||
model_cls = FluxTransformer2DModel
|
||||
expected_memory_use_in_gb = 5
|
||||
@@ -296,6 +297,16 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_loading_gguf_diffusers_format(self):
|
||||
model = self.model_cls.from_single_file(
|
||||
self.diffusers_ckpt_path,
|
||||
subfolder="transformer",
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
config="black-forest-labs/FLUX.1-dev",
|
||||
)
|
||||
model.to("cuda")
|
||||
model(**self.get_dummy_inputs())
|
||||
|
||||
|
||||
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
|
||||
|
||||
@@ -56,12 +56,18 @@ class QuantCompileTests:
|
||||
pipe.transformer.compile(fullgraph=True)
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
|
||||
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.transformer.compile()
|
||||
# regional compilation is better for offloading.
|
||||
# see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
|
||||
if getattr(pipe.transformer, "_repeated_blocks"):
|
||||
pipe.transformer.compile_repeated_blocks(fullgraph=True)
|
||||
else:
|
||||
pipe.transformer.compile()
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
Reference in New Issue
Block a user