Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e3f6ba3fc1 | |||
| f442955c6e | |||
| ff9a387618 | |||
| fb8722e9ab | |||
| 512044c5ea | |||
| 03c3f69aa5 | |||
| f20aba3e87 | |||
| 6c85fcd899 | |||
| 085e9cba36 | |||
| 919ee1aee3 | |||
| 9cda45701c | |||
| c678e8a445 | |||
| ccf2c31188 | |||
| d1342d7464 | |||
| 7b10e4ae65 | |||
| 3c0531bc50 | |||
| a8e47978c6 | |||
| 50e18ee698 | |||
| 4b17fa2a2e | |||
| d45199a2f1 | |||
| 061163142d | |||
| 5780776c8a | |||
| f19421e27c | |||
| 9a0cc463ee | |||
| ef4e373a65 | |||
| 1b4af6b7ef | |||
| 69cdc25746 | |||
| ea77fdc4b4 | |||
| 255c5742aa | |||
| 4524d43279 | |||
| b6dc0b75f4 | |||
| 966a2ff8df | |||
| 201da97dd0 | |||
| 4423097b23 | |||
| 60d1b81023 |
@@ -0,0 +1,141 @@
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
paths:
|
||||
- "src/diffusers/modular_pipelines/**.py"
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/modular_pipelines/**.py"
|
||||
- ".github/**.yml"
|
||||
- "utils/**.py"
|
||||
- "setup.py"
|
||||
push:
|
||||
branches:
|
||||
- ci-*
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_support_list.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -24,6 +24,63 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
|
||||
</Tip>
|
||||
|
||||
## LoRA for faster inference
|
||||
|
||||
Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
|
||||
number of steps. Refer to the code snippet below:
|
||||
|
||||
<details>
|
||||
<summary>Code</summary>
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
|
||||
import torch
|
||||
import math
|
||||
|
||||
ckpt_id = "Qwen/Qwen-Image"
|
||||
|
||||
# From
|
||||
# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
|
||||
scheduler_config = {
|
||||
"base_image_seq_len": 256,
|
||||
"base_shift": math.log(3), # We use shift=3 in distillation
|
||||
"invert_sigmas": False,
|
||||
"max_image_seq_len": 8192,
|
||||
"max_shift": math.log(3), # We use shift=3 in distillation
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 1.0,
|
||||
"shift_terminal": None, # set shift_terminal to None
|
||||
"stochastic_sampling": False,
|
||||
"time_shift_type": "exponential",
|
||||
"use_beta_sigmas": False,
|
||||
"use_dynamic_shifting": True,
|
||||
"use_exponential_sigmas": False,
|
||||
"use_karras_sigmas": False,
|
||||
}
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
pipe.load_lora_weights(
|
||||
"lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
|
||||
)
|
||||
|
||||
prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
|
||||
negative_prompt = " "
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_inference_steps=8,
|
||||
true_cfg_scale=1.0,
|
||||
generator=torch.manual_seed(0),
|
||||
).images[0]
|
||||
image.save("qwen_fewsteps.png")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
|
||||
@@ -77,3 +77,44 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
|
||||
- Q5_K
|
||||
- Q6_K
|
||||
|
||||
## Convert to GGUF
|
||||
|
||||
Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
|
||||
run conversion:
|
||||
|
||||
<iframe
|
||||
src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
config="black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
When using Diffusers format GGUF checkpoints, it's a must to provide the model `config` path. If the
|
||||
model config resides in a `subfolder`, that needs to be specified, too.
|
||||
@@ -116,7 +116,7 @@ _deps = [
|
||||
"librosa",
|
||||
"numpy",
|
||||
"parameterized",
|
||||
"peft>=0.15.0",
|
||||
"peft>=0.17.0",
|
||||
"protobuf>=3.20.3,<4",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
|
||||
@@ -139,6 +139,7 @@ else:
|
||||
"AutoGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"FrequencyDecoupledGuidance",
|
||||
"PerturbedAttentionGuidance",
|
||||
"SkipLayerGuidance",
|
||||
"SmoothedEnergyGuidance",
|
||||
@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
|
||||
@@ -23,7 +23,7 @@ deps = {
|
||||
"librosa": "librosa",
|
||||
"numpy": "numpy",
|
||||
"parameterized": "parameterized",
|
||||
"peft": "peft>=0.15.0",
|
||||
"peft": "peft>=0.17.0",
|
||||
"protobuf": "protobuf>=3.20.3,<4",
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
|
||||
@@ -22,6 +22,7 @@ if is_torch_available():
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
@@ -32,6 +33,7 @@ if is_torch_available():
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..utils import is_kornia_available
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
_CAN_USE_KORNIA = is_kornia_available()
|
||||
|
||||
|
||||
if _CAN_USE_KORNIA:
|
||||
from kornia.geometry import pyrup as upsample_and_blur_func
|
||||
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
|
||||
else:
|
||||
upsample_and_blur_func = None
|
||||
build_laplacian_pyramid_func = None
|
||||
|
||||
|
||||
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
|
||||
(Algorithm 2).
|
||||
"""
|
||||
# v0 shape: [B, ...]
|
||||
# v1 shape: [B, ...]
|
||||
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
|
||||
all_dims_but_first = list(range(1, len(v0.shape)))
|
||||
if upcast_to_double:
|
||||
dtype = v0.dtype
|
||||
v0, v1 = v0.double(), v1.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
|
||||
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
if upcast_to_double:
|
||||
v0_parallel = v0_parallel.to(dtype)
|
||||
v0_orthogonal = v0_orthogonal.to(dtype)
|
||||
return v0_parallel, v0_orthogonal
|
||||
|
||||
|
||||
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
|
||||
(Algorihtm 2).
|
||||
"""
|
||||
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
|
||||
img = pyramid[-1]
|
||||
for i in range(len(pyramid) - 2, -1, -1):
|
||||
img = upsample_and_blur_func(img) + pyramid[i]
|
||||
return img
|
||||
|
||||
|
||||
class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
"""
|
||||
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
|
||||
|
||||
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
|
||||
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
|
||||
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
|
||||
how CFG works, you can check out the CFG guider.)
|
||||
|
||||
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
|
||||
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
|
||||
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
|
||||
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
|
||||
to form the final FDG prediction.
|
||||
|
||||
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
|
||||
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
|
||||
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
|
||||
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
|
||||
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
|
||||
|
||||
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
|
||||
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
|
||||
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
|
||||
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
|
||||
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
|
||||
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
|
||||
descending order).
|
||||
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
|
||||
`guidance_scales`.
|
||||
parallel_weights (`float` or `List[float]`, *optional*):
|
||||
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
|
||||
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
|
||||
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
|
||||
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float` or `List[float]`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
stop (`float` or `List[float]`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
|
||||
should be the same length as `guidance_scales`.
|
||||
guidance_rescale_space (`str`, defaults to `"data"`):
|
||||
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
|
||||
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
|
||||
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
|
||||
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
|
||||
upcast_to_double (`bool`, defaults to `True`):
|
||||
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
|
||||
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
|
||||
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
|
||||
use_original_formulation: bool = False,
|
||||
start: Union[float, List[float], Tuple[float]] = 0.0,
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
):
|
||||
if not _CAN_USE_KORNIA:
|
||||
raise ImportError(
|
||||
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
|
||||
"it depends is not available in the current environment. You can install `kornia` with `pip install "
|
||||
"kornia`."
|
||||
)
|
||||
|
||||
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
||||
min_start = start if isinstance(start, float) else min(start)
|
||||
max_stop = stop if isinstance(stop, float) else max(stop)
|
||||
super().__init__(min_start, max_stop)
|
||||
|
||||
self.guidance_scales = guidance_scales
|
||||
self.levels = len(guidance_scales)
|
||||
|
||||
if isinstance(guidance_rescale, float):
|
||||
self.guidance_rescale = [guidance_rescale] * self.levels
|
||||
elif len(guidance_rescale) == self.levels:
|
||||
self.guidance_rescale = guidance_rescale
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
|
||||
f"`guidance_scales` ({len(self.guidance_scales)})"
|
||||
)
|
||||
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
|
||||
# transforming from frequency space back to data space)
|
||||
if guidance_rescale_space not in ["data", "freq"]:
|
||||
raise ValueError(
|
||||
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
|
||||
)
|
||||
self.guidance_rescale_space = guidance_rescale_space
|
||||
|
||||
if parallel_weights is None:
|
||||
# Use normal CFG shift (equal weights for parallel and orthogonal components)
|
||||
self.parallel_weights = [1.0] * self.levels
|
||||
elif isinstance(parallel_weights, float):
|
||||
self.parallel_weights = [parallel_weights] * self.levels
|
||||
elif len(parallel_weights) == self.levels:
|
||||
self.parallel_weights = parallel_weights
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
|
||||
f"`guidance_scales` ({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.upcast_to_double = upcast_to_double
|
||||
|
||||
if isinstance(start, float):
|
||||
self.guidance_start = [start] * self.levels
|
||||
elif len(start) == self.levels:
|
||||
self.guidance_start = start
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
if isinstance(stop, float):
|
||||
self.guidance_stop = [stop] * self.levels
|
||||
elif len(stop) == self.levels:
|
||||
self.guidance_stop = stop
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_fdg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
|
||||
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
|
||||
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
|
||||
|
||||
# From high frequencies to low frequencies, following the paper implementation
|
||||
pred_guided_pyramid = []
|
||||
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
|
||||
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
|
||||
if self._is_fdg_enabled_for_level(level):
|
||||
# Get the cond/uncond preds (in freq space) at the current frequency level
|
||||
pred_cond_freq = pred_cond_pyramid[level]
|
||||
pred_uncond_freq = pred_uncond_pyramid[level]
|
||||
|
||||
shift = pred_cond_freq - pred_uncond_freq
|
||||
|
||||
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
|
||||
if not math.isclose(parallel_weight, 1.0):
|
||||
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
|
||||
shift = parallel_weight * shift_parallel + shift_orthogonal
|
||||
|
||||
# Apply CFG update for the current frequency level
|
||||
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
|
||||
pred = pred + guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
|
||||
|
||||
# Add the current FDG guided level to the FDG prediction pyramid
|
||||
pred_guided_pyramid.append(pred)
|
||||
else:
|
||||
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
|
||||
pred_guided_pyramid.append(pred_cond_freq)
|
||||
|
||||
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
|
||||
pred = build_image_from_pyramid(pred_guided_pyramid)
|
||||
|
||||
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
|
||||
# across all freq levels
|
||||
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_fdg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_fdg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
|
||||
else:
|
||||
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_fdg_enabled_for_level(self, level: int) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
|
||||
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scales[level], 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scales[level], 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
@@ -133,6 +133,7 @@ def _register_attention_processors_metadata():
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
|
||||
# FluxAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=FluxAttnProcessor,
|
||||
|
||||
@@ -245,7 +245,6 @@ class ModuleGroup:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
@@ -303,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
|
||||
if self.group.onload_leader == module:
|
||||
if self.group.onload_self:
|
||||
self.group.onload_()
|
||||
if self.next_group is not None and not self.next_group.onload_self:
|
||||
|
||||
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
||||
if should_onload_next_group:
|
||||
self.next_group.onload_()
|
||||
|
||||
should_synchronize = (
|
||||
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
|
||||
)
|
||||
if should_synchronize:
|
||||
# If this group didn't onload itself, it means it was asynchronously onloaded by the
|
||||
# previous group. We need to synchronize the side stream to ensure parameters
|
||||
# are completely loaded to proceed with forward pass. Without this, uninitialized
|
||||
# weights will be used in the computation, leading to incorrect results
|
||||
# Also, we should only do this synchronization if we don't already do it from the sync call in
|
||||
# self.next_group.onload_, hence the `not should_onload_next_group` check.
|
||||
self.group.stream.synchronize()
|
||||
|
||||
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
|
||||
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
|
||||
return args, kwargs
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
|
||||
|
||||
|
||||
def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
|
||||
module_list_with_transformer_blocks = []
|
||||
for name, submodule in module.named_modules():
|
||||
name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
|
||||
is_modulelist = isinstance(submodule, torch.nn.ModuleList)
|
||||
if name_endswith_identifier and is_modulelist:
|
||||
module_list_with_transformer_blocks.append((name, submodule))
|
||||
return module_list_with_transformer_blocks
|
||||
|
||||
|
||||
def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
|
||||
attention_layers = []
|
||||
for name, submodule in module.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES):
|
||||
attention_layers.append((name, submodule))
|
||||
return attention_layers
|
||||
|
||||
|
||||
def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
|
||||
feedforward_layers = []
|
||||
for name, submodule in module.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
feedforward_layers.append((name, submodule))
|
||||
return feedforward_layers
|
||||
@@ -817,7 +817,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
# has both `peft` and non-peft state dict.
|
||||
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
|
||||
if has_peft_state_dict:
|
||||
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
|
||||
state_dict = {
|
||||
k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith("transformer.")
|
||||
}
|
||||
return state_dict
|
||||
|
||||
# Another weird one.
|
||||
@@ -2073,3 +2077,39 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
up_key = ".lora_up.weight"
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for k in all_keys:
|
||||
if k.endswith(down_key):
|
||||
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
|
||||
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
|
||||
alpha_key = k.replace(down_key, ".alpha")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(k.replace(down_key, up_key))
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[diffusers_down_key] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_qwen_lora_to_diffusers,
|
||||
_convert_non_diffusers_wan_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
@@ -6548,7 +6549,6 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -6642,6 +6642,10 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
if has_alphas_in_sd:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
|
||||
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
|
||||
# it to None
|
||||
incompatible_keys = None
|
||||
else:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
inject_adapter_in_model(
|
||||
lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
|
||||
)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
if self._prepare_lora_hotswap_kwargs is not None:
|
||||
|
||||
@@ -153,9 +153,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"QwenImageTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
||||
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
@@ -381,19 +389,23 @@ class FromOriginalModelMixin:
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
|
||||
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
diffusers_format_checkpoint = checkpoint
|
||||
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
|
||||
@@ -60,6 +60,7 @@ if is_accelerate_available():
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CHECKPOINT_KEY_NAMES = {
|
||||
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
|
||||
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
|
||||
@@ -30,7 +30,6 @@ from huggingface_hub import DDUFEntry
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
|
||||
from ..quantizers import DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
GGUF_FILE_EXTENSION,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -232,7 +231,6 @@ def load_model_dict_into_meta(
|
||||
"""
|
||||
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_higgs = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HIGGS
|
||||
empty_state_dict = model.state_dict()
|
||||
|
||||
for param_name, param in state_dict.items():
|
||||
@@ -282,8 +280,7 @@ def load_model_dict_into_meta(
|
||||
|
||||
# bnb params are flattened.
|
||||
# gguf quants have a different shape based on the type of quantization applied
|
||||
# higgs quants repack the weights so they will have different shapes
|
||||
if empty_state_dict[param_name].shape != param.shape and not is_higgs:
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
if (
|
||||
is_quantized
|
||||
and hf_quantizer.pre_quantized
|
||||
@@ -307,7 +304,7 @@ def load_model_dict_into_meta(
|
||||
hf_quantizer.create_quantized_param(
|
||||
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
|
||||
)
|
||||
elif hf_quantizer is not None:
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
|
||||
|
||||
return offload_index, state_dict_index
|
||||
|
||||
@@ -384,7 +384,7 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ else:
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"ModularPipelineBlocks",
|
||||
"ModularPipeline",
|
||||
"PipelineBlock",
|
||||
"AutoPipelineBlocks",
|
||||
"SequentialPipelineBlocks",
|
||||
"LoopSequentialPipelineBlocks",
|
||||
@@ -59,7 +58,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
PipelineBlock,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
|
||||
@@ -13,15 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import AutoencoderKL
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -103,6 +104,62 @@ def calculate_shift(
|
||||
return mu
|
||||
|
||||
|
||||
# Adapted from the original implementation.
|
||||
def prepare_latents_img2img(
|
||||
vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
latent_channels = vae.config.latent_channels
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != latent_channels:
|
||||
image_latents = _encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = scheduler.scale_noise(image_latents, timestep, noise)
|
||||
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, latent_image_ids
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
@@ -125,7 +182,56 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FluxInputStep(PipelineBlock):
|
||||
# Cannot use "# Copied from" because it introduces weird indentation errors.
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
|
||||
def _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
vae_scale_factor,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
sigmas,
|
||||
device,
|
||||
):
|
||||
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
|
||||
if transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
return timesteps, num_inference_steps, sigmas, guidance
|
||||
|
||||
|
||||
class FluxInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -143,11 +249,6 @@ class FluxInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -216,7 +317,7 @@ class FluxInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxSetTimestepsStep(PipelineBlock):
|
||||
class FluxSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -235,17 +336,15 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"latents",
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
)
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -264,39 +363,127 @@ class FluxSetTimestepsStep(PipelineBlock):
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
|
||||
latents = block_state.latents
|
||||
image_seq_len = latents.shape[1]
|
||||
|
||||
num_inference_steps = block_state.num_inference_steps
|
||||
sigmas = block_state.sigmas
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
scheduler.config.get("base_image_seq_len", 256),
|
||||
scheduler.config.get("max_image_seq_len", 4096),
|
||||
scheduler.config.get("base_shift", 0.5),
|
||||
scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
|
||||
)
|
||||
if components.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
block_state.guidance = guidance
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(PipelineBlock):
|
||||
class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for inference"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("strength", default=0.6),
|
||||
InputParam("guidance_scale", default=3.5),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam(
|
||||
"latent_timestep",
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image generation",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
|
||||
def get_timesteps(scheduler, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
if hasattr(scheduler, "set_begin_index"):
|
||||
scheduler.set_begin_index(t_start * scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
transformer = components.transformer
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
|
||||
transformer,
|
||||
scheduler,
|
||||
batch_size,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
components.vae_scale_factor,
|
||||
block_state.num_inference_steps,
|
||||
block_state.guidance_scale,
|
||||
block_state.sigmas,
|
||||
block_state.device,
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
scheduler, num_inference_steps, block_state.strength, block_state.device
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.sigmas = sigmas
|
||||
block_state.guidance = guidance
|
||||
|
||||
block_state.latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -305,7 +492,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||
return "Prepare latents step that prepares the latents for the text-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
@@ -314,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -402,10 +584,10 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
block_state.latents, block_state.latent_image_ids = self.prepare_latents(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
@@ -418,3 +600,90 @@ class FluxPrepareLatentsStep(PipelineBlock):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the latents for the image-to-image generation process"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
|
||||
),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
|
||||
),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam(
|
||||
"latent_image_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="IDs computed from the image sequence needed for RoPE",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
block_state.device = components._execution_device
|
||||
|
||||
# TODO: implement `check_inputs`
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
if block_state.latents is None:
|
||||
block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
|
||||
components.vae,
|
||||
components.scheduler,
|
||||
block_state.image_latents,
|
||||
block_state.latent_timestep,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import logging
|
||||
from ...video_processor import VaeImageProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
return latents
|
||||
|
||||
|
||||
class FluxDecodeStep(PipelineBlock):
|
||||
class FluxDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam("height", default=1024),
|
||||
InputParam("width", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxLoopDenoiser(PipelineBlock):
|
||||
class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [InputParam("joint_attention_kwargs")]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class FluxLoopAfterDenoiser(PipelineBlock):
|
||||
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -226,5 +223,5 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `FluxLoopDenoiser`\n"
|
||||
" - `FluxLoopAfterDenoiser`\n"
|
||||
"This block supports text2image tasks."
|
||||
"This block supports both text2image and img2img tasks."
|
||||
)
|
||||
|
||||
@@ -19,9 +19,12 @@ import regex as re
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import FluxModularPipeline
|
||||
|
||||
@@ -50,7 +53,110 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class FluxTextEncoderStep(PipelineBlock):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class FluxVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Encoder step that encode the input image into a latent representation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
"preprocess_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents representing the reference image for image-to-image/inpainting generation",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
|
||||
def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components.vae, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
@property
|
||||
@@ -297,7 +403,7 @@ class FluxTextEncoderStep(PipelineBlock):
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
device=block_state.device,
|
||||
num_images_per_prompt=1, # hardcoded for now.
|
||||
num_images_per_prompt=1, # TODO: hardcoded for now.
|
||||
lora_scale=block_state.text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,16 +15,38 @@
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxInputStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep
|
||||
from .encoders import FluxTextEncoderStep
|
||||
from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# before_denoise: text2vid
|
||||
# vae encoder (run before before_denoise)
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img, img2img
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
FluxInputStep,
|
||||
@@ -44,11 +66,27 @@ class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2vid,)
|
||||
# before_denoise: img2img
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxBeforeDenoiseStep]
|
||||
block_names = ["text2image"]
|
||||
block_trigger_inputs = [None]
|
||||
block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
|
||||
block_names = ["text2image", "img2img"]
|
||||
block_trigger_inputs = [None, "image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -56,6 +94,7 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -69,8 +108,8 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image tasks."
|
||||
"This is a auto pipeline block that works for text2image and img2img tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
|
||||
)
|
||||
|
||||
|
||||
@@ -82,19 +121,26 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
|
||||
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
|
||||
|
||||
|
||||
# text2image
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
|
||||
block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
|
||||
block_classes = [
|
||||
FluxTextEncoderStep,
|
||||
FluxAutoVaeEncoderStep,
|
||||
FluxAutoBeforeDenoiseStep,
|
||||
FluxAutoDenoiseStep,
|
||||
FluxAutoDecodeStep,
|
||||
]
|
||||
block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
|
||||
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
|
||||
)
|
||||
|
||||
|
||||
@@ -102,19 +148,29 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
# Setting it after preparation of latents because we rely on `latents`
|
||||
# to calculate `img_seq_len` for `shift`.
|
||||
("set_timesteps", FluxSetTimestepsStep),
|
||||
("prepare_latents", FluxPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxVaeEncoderStep),
|
||||
("input", FluxInputStep),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", FluxImg2ImgPrepareLatentsStep),
|
||||
("denoise", FluxDenoiseStep),
|
||||
("decode", FluxDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep),
|
||||
("image_encoder", FluxAutoVaeEncoderStep),
|
||||
("before_denoise", FluxAutoBeforeDenoiseStep),
|
||||
("denoise", FluxAutoDenoiseStep),
|
||||
("decode", FluxAutoDecodeStep),
|
||||
@@ -122,4 +178,4 @@ AUTO_BLOCKS = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
@@ -21,7 +21,7 @@ from ..modular_pipeline import ModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
|
||||
class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for Flux.
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
intermediate_inputs,
|
||||
outputs,
|
||||
description="",
|
||||
class_name=None,
|
||||
@@ -664,7 +663,7 @@ def make_doc_string(
|
||||
output += configs_str + "\n\n"
|
||||
|
||||
# Add inputs section
|
||||
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
|
||||
output += format_input_params(inputs, indent_level=2)
|
||||
|
||||
# Add outputs section
|
||||
output += "\n\n"
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
|
||||
return latents
|
||||
|
||||
|
||||
class StableDiffusionXLInputStep(PipelineBlock):
|
||||
class StableDiffusionXLInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
InputParam("denoising_start"),
|
||||
# YiYi TODO: do we need num_images_per_prompt here?
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
|
||||
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=None,
|
||||
is_strength_max=True,
|
||||
add_noise=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
elif latents is None and not is_strength_max:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(components, image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
outputs = (latents, noise, image_latents)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
|
||||
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
|
||||
|
||||
block_state.latents, block_state.noise = self.prepare_latents_inpaint(
|
||||
block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
|
||||
components,
|
||||
block_state.batch_size * block_state.num_images_per_prompt,
|
||||
components.num_channels_latents,
|
||||
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
timestep=block_state.latent_timestep,
|
||||
is_strength_max=block_state.is_strength_max,
|
||||
add_noise=block_state.add_noise,
|
||||
return_noise=True,
|
||||
return_image_latents=False,
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"latent_timestep",
|
||||
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
InputParam("width"),
|
||||
InputParam("latents"),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("aesthetic_score", default=6.0),
|
||||
InputParam("negative_aesthetic_score", default=2.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
InputParam("crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("negative_crops_coords_top_left", default=(0, 0)),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -1747,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
|
||||
InputParam("controlnet_conditioning_scale", default=1.0),
|
||||
InputParam("guess_mode", default=False),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
|
||||
InputParam("image"),
|
||||
InputParam("mask_image"),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# YiYi experimenting composible denoise loop
|
||||
# loop step (1): prepare latent input for denoiser
|
||||
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (1): prepare latent input for denoiser (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
)
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
def inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance
|
||||
class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_inference_steps",
|
||||
required=True,
|
||||
@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (2): denoise the latents with guidance (with controlnet)
|
||||
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("cross_attention_kwargs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam(
|
||||
"controlnet_cond",
|
||||
required=True,
|
||||
@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents
|
||||
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
|
||||
|
||||
|
||||
# loop step (3): scheduler step to update latents (with inpainting)
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("eta", default=0.0),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"timesteps",
|
||||
@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_intermediate_inputs(self) -> List[InputParam]:
|
||||
def loop_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"timesteps",
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -57,7 +57,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
InputParam(
|
||||
@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
block_state.device = components._execution_device
|
||||
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
|
||||
)
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.batch_size = image.shape[0]
|
||||
|
||||
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
|
||||
@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
|
||||
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "stable-diffusion-xl"
|
||||
|
||||
@property
|
||||
@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
InputParam("image", required=True),
|
||||
InputParam("mask_image", required=True),
|
||||
InputParam("padding_mask_crop"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
InputParam("generator"),
|
||||
]
|
||||
@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
|
||||
block_state.crops_coords = None
|
||||
block_state.resize_mode = "default"
|
||||
|
||||
block_state.image = components.image_processor.preprocess(
|
||||
image = components.image_processor.preprocess(
|
||||
block_state.image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
crops_coords=block_state.crops_coords,
|
||||
resize_mode=block_state.resize_mode,
|
||||
)
|
||||
block_state.image = block_state.image.to(dtype=torch.float32)
|
||||
image = image.to(dtype=torch.float32)
|
||||
|
||||
block_state.mask = components.mask_processor.preprocess(
|
||||
mask = components.mask_processor.preprocess(
|
||||
block_state.mask_image,
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
resize_mode=block_state.resize_mode,
|
||||
crops_coords=block_state.crops_coords,
|
||||
)
|
||||
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
|
||||
block_state.masked_image = image * (mask < 0.5)
|
||||
|
||||
block_state.batch_size = block_state.image.shape[0]
|
||||
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(
|
||||
components, image=block_state.image, generator=block_state.generator
|
||||
)
|
||||
block_state.batch_size = image.shape[0]
|
||||
image = image.to(device=block_state.device, dtype=block_state.dtype)
|
||||
block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
|
||||
components,
|
||||
block_state.mask,
|
||||
mask,
|
||||
block_state.masked_image,
|
||||
block_state.batch_size,
|
||||
block_state.height,
|
||||
|
||||
@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -94,7 +94,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class WanInputStep(PipelineBlock):
|
||||
class WanInputStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanSetTimestepsStep(PipelineBlock):
|
||||
class WanSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareLatentsStep(PipelineBlock):
|
||||
class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanDecodeStep(PipelineBlock):
|
||||
class WanDecodeStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
PipelineBlock,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanLoopDenoiser(PipelineBlock):
|
||||
class WanLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanLoopAfterDenoiser(PipelineBlock):
|
||||
class WanLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
|
||||
@@ -51,7 +51,7 @@ def prompt_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class WanTextEncoderStep(PipelineBlock):
|
||||
class WanTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -310,7 +310,7 @@ class FluxPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -324,7 +324,7 @@ class FluxControlPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -374,7 +374,7 @@ class FluxControlInpaintPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -419,7 +419,7 @@ class FluxFillPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -333,7 +333,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -358,7 +358,7 @@ class FluxKontextPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -391,7 +391,7 @@ class FluxKontextInpaintPipeline(
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
||||
@@ -201,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = self.tokenizer(
|
||||
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
).to(device)
|
||||
encoder_hidden_states = self.text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
|
||||
@@ -21,11 +21,9 @@ from typing import Dict, Optional, Union
|
||||
|
||||
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
|
||||
from .gguf import GGUFQuantizer
|
||||
from .higgs import HiggsQuantizer
|
||||
from .quantization_config import (
|
||||
BitsAndBytesConfig,
|
||||
GGUFQuantizationConfig,
|
||||
HiggsConfig,
|
||||
QuantizationConfigMixin,
|
||||
QuantizationMethod,
|
||||
QuantoConfig,
|
||||
@@ -41,7 +39,6 @@ AUTO_QUANTIZER_MAPPING = {
|
||||
"gguf": GGUFQuantizer,
|
||||
"quanto": QuantoQuantizer,
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
"higgs": HiggsQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
@@ -50,7 +47,6 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"gguf": GGUFQuantizationConfig,
|
||||
"quanto": QuantoConfig,
|
||||
"torchao": TorchAoConfig,
|
||||
"higgs": HiggsConfig,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .higgs_quantizer import HiggsQuantizer
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/quantizers/quantizer_higgs.py
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...utils import get_module_from_name
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
from ...utils import is_accelerate_available, is_torch_available, logging
|
||||
from ...utils.logging import tqdm
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HiggsQuantizer(DiffusersQuantizer):
|
||||
"""
|
||||
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of
|
||||
full-precision models.
|
||||
"""
|
||||
|
||||
requires_calibration = False
|
||||
requires_parameters_quantization = True
|
||||
required_packages = ["flute-kernel", "fast_hadamard_transform"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def validate_environment(self, device_map, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
|
||||
|
||||
# TODO: enable this.
|
||||
# if not is_flute_available():
|
||||
# raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
|
||||
|
||||
# if not is_hadamard_available():
|
||||
# raise ImportError(
|
||||
# "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
|
||||
# )
|
||||
|
||||
if device_map is None:
|
||||
raise ValueError(
|
||||
"You are attempting to load a HIGGS model without setting device_map."
|
||||
" Please set device_map comprised of 'cuda' devices."
|
||||
)
|
||||
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
|
||||
raise ValueError(
|
||||
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
|
||||
" This is not supported. Please remove the CPU or disk device from the device_map."
|
||||
)
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
|
||||
torch_dtype = torch.float16
|
||||
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
|
||||
raise ValueError(
|
||||
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
|
||||
return torch_dtype
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
unexpected_keys: Optional[list[str]] = None,
|
||||
):
|
||||
from .utils import quantize_with_higgs
|
||||
|
||||
"""
|
||||
Quantizes weights into weight and weight_scale
|
||||
"""
|
||||
flute_dict = quantize_with_higgs(
|
||||
param_value.to(target_device),
|
||||
self.quantization_config.bits,
|
||||
self.quantization_config.p,
|
||||
self.quantization_config.group_size,
|
||||
self.quantization_config.hadamard_size,
|
||||
)
|
||||
del param_value
|
||||
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
module_name = ".".join(param_name.split(".")[:-1])
|
||||
for key, value in flute_dict.items():
|
||||
if key in module._parameters:
|
||||
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
|
||||
elif key in module._buffers:
|
||||
module._buffers[key] = torch.nn.Buffer(value)
|
||||
elif key == "tune_metadata":
|
||||
module.tune_metadata = value
|
||||
self.quantization_config.tune_metadata[module_name] = value.to_dict()
|
||||
else:
|
||||
raise ValueError(f"Unexpected key {key} in module {module}")
|
||||
|
||||
if unexpected_keys is not None and param_name in unexpected_keys:
|
||||
unexpected_keys.remove(param_name)
|
||||
|
||||
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
|
||||
from .utils import HiggsLinear
|
||||
|
||||
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
|
||||
|
||||
def should_update(key: str) -> bool:
|
||||
if key.endswith(".weight") or key.endswith(".bias"):
|
||||
return False
|
||||
full_key = f"{prefix}.{key}"
|
||||
return any(name in key or name in full_key for name in higgs_names)
|
||||
|
||||
return [key for key in missing_keys if not should_update(key)]
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
return False
|
||||
|
||||
def is_serializable(self):
|
||||
return True
|
||||
|
||||
def check_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
from .utils import HiggsLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
|
||||
# Only quantize weights of HiggsLinear modules that are not already quantized
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
**kwargs,
|
||||
):
|
||||
from .utils import replace_with_higgs_linear
|
||||
|
||||
replace_with_higgs_linear(model, quantization_config=self.quantization_config)
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
|
||||
from flute.tune import TuneMetaData, maybe_tune_and_repack
|
||||
from flute.utils import make_workspace_streamk
|
||||
|
||||
from .utils import HiggsLinear
|
||||
|
||||
flute_workspaces = {}
|
||||
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
|
||||
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
|
||||
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
|
||||
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
|
||||
if module.weight.device not in flute_workspaces:
|
||||
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
|
||||
module.workspace = flute_workspaces[module.weight.device]
|
||||
|
||||
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
|
||||
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
|
||||
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
|
||||
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
|
||||
weight=module.weight.data,
|
||||
scales=module.scales.data,
|
||||
metadata=module.tune_metadata,
|
||||
)
|
||||
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
|
||||
|
||||
def _dequantize(self, model):
|
||||
from .utils import dequantize_higgs
|
||||
|
||||
model = dequantize_higgs(model)
|
||||
return model
|
||||
@@ -1,690 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file.
|
||||
|
||||
Taken from:
|
||||
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/integrations/higgs.py
|
||||
"""
|
||||
|
||||
from math import sqrt
|
||||
|
||||
from ...utils import (
|
||||
# TODO enable:
|
||||
# is_flute_available,
|
||||
# is_hadamard_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# if is_flute_available():
|
||||
# if is_hadamard_available():
|
||||
from fast_hadamard_transform import hadamard_transform
|
||||
from flute.integrations.higgs import prepare_data_transposed
|
||||
from flute.tune import TuneMetaData, qgemm_v2
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def pad_to_block(tensor, dims, had_block_size, value=0):
|
||||
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
|
||||
for dim in dims:
|
||||
size = tensor.shape[dim]
|
||||
next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size
|
||||
delta = next_multiple_of_1024 - size
|
||||
pad_dims[-2 * dim - 1] = delta
|
||||
|
||||
return nn.functional.pad(tensor, pad_dims, "constant", value)
|
||||
|
||||
|
||||
def get_higgs_grid(p: int, n: int):
|
||||
if (p, n) == (2, 256):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.501467704772949, 0.17954708635807037],
|
||||
[-0.6761789321899414, 1.2728623151779175],
|
||||
[-1.8025816679000854, 0.7613157629966736],
|
||||
[-0.538287878036499, -2.6028504371643066],
|
||||
[0.8415029644966125, -0.8600977659225464],
|
||||
[0.7023013234138489, 3.3138747215270996],
|
||||
[0.5699077844619751, 2.5782253742218018],
|
||||
[3.292393207550049, -0.6016128063201904],
|
||||
[0.5561617016792297, -1.7723814249038696],
|
||||
[-2.1012380123138428, 0.020958125591278076],
|
||||
[0.46085724234580994, 0.8428705334663391],
|
||||
[1.4548040628433228, -0.6156039237976074],
|
||||
[3.210029363632202, 0.3546904921531677],
|
||||
[0.8893890976905823, -0.5967988967895508],
|
||||
[0.8618854284286499, -3.2061192989349365],
|
||||
[1.1360996961593628, -0.23852407932281494],
|
||||
[1.6646337509155273, -0.9265465140342712],
|
||||
[1.4767773151397705, 1.2476022243499756],
|
||||
[-1.0511897802352905, 1.94503915309906],
|
||||
[-1.56318998336792, -0.3264186680316925],
|
||||
[-0.1829211413860321, 0.2922491431236267],
|
||||
[-0.8950616717338562, -1.3887052536010742],
|
||||
[-0.08206957578659058, -1.329533576965332],
|
||||
[-0.487422913312912, 1.4817842245101929],
|
||||
[-1.6769757270812988, -2.8269758224487305],
|
||||
[-1.5057679414749146, 1.8905963897705078],
|
||||
[1.8335362672805786, 1.0515104532241821],
|
||||
[0.3273945450782776, 1.0491033792495728],
|
||||
[-3.295924186706543, -0.7021600008010864],
|
||||
[-1.8428784608840942, -1.2315762042999268],
|
||||
[-0.8575026392936707, -1.7005949020385742],
|
||||
[-1.120667815208435, 0.6467998027801514],
|
||||
[-0.1588846743106842, -1.804071068763733],
|
||||
[-0.8539647459983826, 0.5645008683204651],
|
||||
[-1.4192019701004028, -0.6175029873847961],
|
||||
[1.0799058675765991, 1.7871345281600952],
|
||||
[1.171311855316162, 0.7511613965034485],
|
||||
[2.162078380584717, 0.8044339418411255],
|
||||
[1.3969420194625854, -1.243762493133545],
|
||||
[-0.23818807303905487, 0.053944624960422516],
|
||||
[2.304199457168579, -1.2667627334594727],
|
||||
[1.4225027561187744, 0.568610668182373],
|
||||
[0.376836895942688, -0.7134661674499512],
|
||||
[2.0404467582702637, 0.4087389409542084],
|
||||
[0.7639489769935608, -1.1367933750152588],
|
||||
[0.3622530400753021, -1.4827953577041626],
|
||||
[0.4100743532180786, 0.36108437180519104],
|
||||
[-1.5867475271224976, -1.618212342262268],
|
||||
[-2.2769672870635986, -1.2132309675216675],
|
||||
[0.9184022545814514, -0.34428009390830994],
|
||||
[-0.3902314603328705, 0.21785245835781097],
|
||||
[3.120687484741211, 1.3077973127365112],
|
||||
[1.587440848350525, -1.6506884098052979],
|
||||
[-1.718808889389038, -0.038405973464250565],
|
||||
[-0.6888407468795776, -0.8402308821678162],
|
||||
[-0.7981445789337158, -1.1117373704910278],
|
||||
[-2.4124443531036377, 1.3419722318649292],
|
||||
[-0.6611530184745789, 0.9939885139465332],
|
||||
[-0.33103418350219727, -0.16702833771705627],
|
||||
[-2.4091389179229736, -2.326857566833496],
|
||||
[1.6610108613967896, -2.159703254699707],
|
||||
[0.014884627424180508, 0.3887578248977661],
|
||||
[0.029668325558304787, 1.8786455392837524],
|
||||
[1.180362582206726, 2.699317216873169],
|
||||
[1.821286678314209, -0.5960053205490112],
|
||||
[-0.44835323095321655, 3.327436685562134],
|
||||
[-0.3714401423931122, -2.1466753482818604],
|
||||
[-1.1103475093841553, -2.4536871910095215],
|
||||
[-0.39110705256462097, 0.6670510172843933],
|
||||
[0.474752813577652, -1.1959707736968994],
|
||||
[-0.013110585510730743, -2.52519154548645],
|
||||
[-2.0836575031280518, -1.703289270401001],
|
||||
[-1.1077687740325928, -0.1252644956111908],
|
||||
[-0.4138077199459076, 1.1837692260742188],
|
||||
[-1.977599024772644, 1.688241720199585],
|
||||
[-1.659559965133667, -2.1387736797332764],
|
||||
[0.03242531046271324, 0.6526556015014648],
|
||||
[0.9127950072288513, 0.6099498867988586],
|
||||
[-0.38478314876556396, 0.433487206697464],
|
||||
[0.27454206347465515, -0.27719801664352417],
|
||||
[0.10388526320457458, 2.2812814712524414],
|
||||
[-0.014394169673323631, -3.177137613296509],
|
||||
[-1.2871228456497192, -0.8961855173110962],
|
||||
[0.5720916986465454, -0.921597957611084],
|
||||
[1.1159656047821045, -0.7609877586364746],
|
||||
[2.4383342266082764, -2.2983546257019043],
|
||||
[-0.294057160615921, -0.9770799875259399],
|
||||
[-0.9342701435089111, 1.107579231262207],
|
||||
[-1.549338698387146, 3.090520143508911],
|
||||
[2.6076579093933105, 2.051239013671875],
|
||||
[-0.9259037375450134, 1.407211184501648],
|
||||
[-0.1747353971004486, 0.540488600730896],
|
||||
[-0.8963701725006104, 0.8271111249923706],
|
||||
[0.6480194926261902, 1.0128909349441528],
|
||||
[0.980783998966217, -0.06156221032142639],
|
||||
[-0.16883476078510284, 1.0601658821105957],
|
||||
[0.5839992761611938, 0.004697148688137531],
|
||||
[-0.34228450059890747, -1.2423977851867676],
|
||||
[2.500824451446533, 0.3665279746055603],
|
||||
[-0.17641609907150269, 1.3529551029205322],
|
||||
[0.05378641560673714, 2.817232847213745],
|
||||
[-1.2391047477722168, 2.354328155517578],
|
||||
[0.630434513092041, -0.668536365032196],
|
||||
[1.7576488256454468, 0.6738647818565369],
|
||||
[0.4435231387615204, 0.6000469326972961],
|
||||
[-0.08794835954904556, -0.11511358618736267],
|
||||
[1.6540337800979614, 0.33995017409324646],
|
||||
[-0.04202975332736969, -0.5375117063522339],
|
||||
[-0.4247745871543884, -0.7897617220878601],
|
||||
[0.06695003807544708, 1.2000739574432373],
|
||||
[-3.2508881092071533, 0.28734830021858215],
|
||||
[-1.613816261291504, 0.4944162368774414],
|
||||
[1.3598989248275757, 0.26117825508117676],
|
||||
[2.308382511138916, 1.3462618589401245],
|
||||
[-1.2137469053268433, -1.9254342317581177],
|
||||
[-0.4889402985572815, 1.8136259317398071],
|
||||
[-0.1870335340499878, -0.3480615019798279],
|
||||
[1.0766386985778809, -1.0627082586288452],
|
||||
[0.4651014506816864, 2.131748914718628],
|
||||
[-0.1306295394897461, -0.7811847925186157],
|
||||
[0.06433182954788208, -1.5397958755493164],
|
||||
[-0.2894323468208313, -0.5789554715156555],
|
||||
[-0.6081662178039551, 0.4845278263092041],
|
||||
[2.697964668273926, -0.18515698611736298],
|
||||
[0.1277363896369934, -0.7221432328224182],
|
||||
[0.8700758218765259, 0.35042452812194824],
|
||||
[0.22088994085788727, 0.495242178440094],
|
||||
[-2.5843818187713623, -0.8000828623771667],
|
||||
[0.6732649803161621, -1.4362232685089111],
|
||||
[-1.5286413431167603, 1.0417330265045166],
|
||||
[-1.1222513914108276, -0.6269875764846802],
|
||||
[-0.9752035140991211, -0.8750635385513306],
|
||||
[-2.6369473934173584, 0.6918523907661438],
|
||||
[0.14478731155395508, -0.041986867785453796],
|
||||
[-1.5629483461380005, 1.4369450807571411],
|
||||
[0.38952457904815674, -2.16428804397583],
|
||||
[-0.16885095834732056, 0.7976621985435486],
|
||||
[-3.12416934967041, 1.256506085395813],
|
||||
[0.6843105554580688, -0.4203019142150879],
|
||||
[1.9345275163650513, 1.934950351715088],
|
||||
[0.012184220366179943, -2.1080918312072754],
|
||||
[-0.6350273489952087, 0.7358828186988831],
|
||||
[-0.837304949760437, -0.6214472651481628],
|
||||
[0.08211923390626907, -0.9472538232803345],
|
||||
[2.9332995414733887, -1.4956780672073364],
|
||||
[1.3806978464126587, -0.2916182279586792],
|
||||
[0.06773144006729126, 0.9285762310028076],
|
||||
[-1.1943119764328003, 1.5963770151138306],
|
||||
[1.6395620107650757, -0.32285431027412415],
|
||||
[-1.390851378440857, -0.08273141086101532],
|
||||
[1.816330909729004, -1.2812227010726929],
|
||||
[0.7921574711799622, -2.1135804653167725],
|
||||
[0.5817914605140686, 1.2644577026367188],
|
||||
[1.929347038269043, -0.2386285960674286],
|
||||
[0.8877345323562622, 1.190008521080017],
|
||||
[1.4732073545455933, 0.8935023546218872],
|
||||
[-2.8518524169921875, -1.5478795766830444],
|
||||
[0.2439267635345459, 0.7576767802238464],
|
||||
[0.5246709585189819, -2.606659412384033],
|
||||
[1.150876760482788, 1.4073830842971802],
|
||||
[-0.2643202245235443, 2.0634236335754395],
|
||||
[1.555483341217041, -0.0023102816194295883],
|
||||
[2.0830578804016113, -1.7225427627563477],
|
||||
[-0.5424830317497253, -1.070199728012085],
|
||||
[0.9168899655342102, 0.8955540060997009],
|
||||
[-0.8120972514152527, 2.696739912033081],
|
||||
[-0.29908373951911926, -1.5310651063919067],
|
||||
[1.2320337295532227, -1.556247353553772],
|
||||
[1.8612544536590576, 0.08704725652933121],
|
||||
[0.22133447229862213, -1.8091708421707153],
|
||||
[-0.4403655230998993, -0.38571012020111084],
|
||||
[-1.88539457321167, 1.192205786705017],
|
||||
[2.239687919616699, 0.004709010478109121],
|
||||
[1.139495611190796, 0.45733731985092163],
|
||||
[-1.507995367050171, 0.19716016948223114],
|
||||
[0.46986445784568787, 1.5422041416168213],
|
||||
[-1.2573751211166382, -0.35984551906585693],
|
||||
[-1.7415345907211304, -0.6020717024803162],
|
||||
[1.0751984119415283, 0.19006384909152985],
|
||||
[2.24186635017395, -0.46343153715133667],
|
||||
[0.3610347509384155, -0.07658443599939346],
|
||||
[-1.3111497163772583, 0.432013601064682],
|
||||
[0.6164408326148987, 0.24538464844226837],
|
||||
[-1.9266542196273804, -0.3256155550479889],
|
||||
[-0.5870336890220642, -0.1879584938287735],
|
||||
[-1.0476511716842651, 0.3677721917629242],
|
||||
[-1.229940414428711, 1.2433830499649048],
|
||||
[0.18550436198711395, 0.22753673791885376],
|
||||
[-0.017921989783644676, 0.12625974416732788],
|
||||
[1.1659504175186157, -0.5020995736122131],
|
||||
[-0.5983408093452454, -1.40438973903656],
|
||||
[0.7519024014472961, -0.16282692551612854],
|
||||
[0.9920787811279297, -1.344896912574768],
|
||||
[-0.8103678226470947, 0.3064485788345337],
|
||||
[0.6956969499588013, 1.8208192586898804],
|
||||
[-2.7830491065979004, -0.2299390584230423],
|
||||
[-0.34681546688079834, 2.4890666007995605],
|
||||
[-1.4452646970748901, -1.2216600179672241],
|
||||
[-2.1872897148132324, 0.8926076292991638],
|
||||
[1.706072211265564, -2.8440372943878174],
|
||||
[1.1119003295898438, -2.4923460483551025],
|
||||
[-2.582794666290283, 2.0973289012908936],
|
||||
[0.04987720400094986, -0.2964983284473419],
|
||||
[-2.063807487487793, -0.7847916483879089],
|
||||
[-0.4068813621997833, 0.9135897755622864],
|
||||
[-0.9814359545707703, -0.3874954879283905],
|
||||
[-1.4227229356765747, 0.7337291240692139],
|
||||
[0.3065044581890106, 1.3125417232513428],
|
||||
[1.2160996198654175, -1.9643305540084839],
|
||||
[-1.2163853645324707, 0.14608727395534515],
|
||||
[-2.3030710220336914, -0.37558120489120483],
|
||||
[0.9232977628707886, 2.1843791007995605],
|
||||
[-0.1989777386188507, 1.651851773262024],
|
||||
[-0.714374840259552, -0.39365994930267334],
|
||||
[-0.7805715799331665, -2.099881887435913],
|
||||
[0.9015759229660034, -1.7053706645965576],
|
||||
[0.1033422127366066, 1.5256654024124146],
|
||||
[-1.8773194551467896, 2.324174165725708],
|
||||
[1.9227174520492554, 2.7441604137420654],
|
||||
[-0.5994020104408264, 0.23984014987945557],
|
||||
[1.3496100902557373, -0.9126054644584656],
|
||||
[-0.8765304088592529, -3.1877026557922363],
|
||||
[-1.2040035724639893, -1.5169521570205688],
|
||||
[1.4261796474456787, 2.150200128555298],
|
||||
[1.463774561882019, 1.6656692028045654],
|
||||
[0.20364105701446533, -0.4988172650337219],
|
||||
[0.5195154547691345, -0.24067887663841248],
|
||||
[-1.1116786003112793, -1.1599653959274292],
|
||||
[-0.8490808606147766, -0.1681060940027237],
|
||||
[0.3189965784549713, -0.9641751646995544],
|
||||
[-0.5664751529693604, -0.5951744318008423],
|
||||
[-1.6347930431365967, -0.9137664437294006],
|
||||
[0.44048091769218445, -0.47259435057640076],
|
||||
[-2.147747039794922, 0.47442489862442017],
|
||||
[1.834734320640564, 1.4462147951126099],
|
||||
[1.1777573823928833, 1.0659226179122925],
|
||||
[-0.9568989872932434, 0.09495053440332413],
|
||||
[-1.838529348373413, 0.2950586676597595],
|
||||
[-0.4800611734390259, 0.014894310384988785],
|
||||
[-0.5235516428947449, -1.7687653303146362],
|
||||
[2.0735011100769043, -0.8825281262397766],
|
||||
[2.637502431869507, 0.8455678224563599],
|
||||
[2.606602907180786, -0.7848446369171143],
|
||||
[-1.1886937618255615, 0.9330510497093201],
|
||||
[0.38082656264305115, 0.13328030705451965],
|
||||
[0.6847941875457764, 0.7384101152420044],
|
||||
[1.2638574838638306, -0.007309418171644211],
|
||||
[0.18292222917079926, -1.22371244430542],
|
||||
[0.8143821954727173, 1.4976691007614136],
|
||||
[0.6571850776672363, 0.48368802666664124],
|
||||
[-0.6991601586341858, 2.150190830230713],
|
||||
[0.8101756572723389, 0.10206498205661774],
|
||||
[-0.08768226951360703, -1.084917664527893],
|
||||
[-0.7208092212677002, 0.03657956421375275],
|
||||
[0.3211449086666107, 1.803687334060669],
|
||||
[-0.7835946083068848, 1.6869111061096191],
|
||||
]
|
||||
)
|
||||
if (p, n) == (2, 64):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.7216711044311523, 0.14431366324424744],
|
||||
[-0.766914427280426, 1.7193410396575928],
|
||||
[-2.2575762271881104, 1.2476624250411987],
|
||||
[1.233758807182312, -2.3560616970062256],
|
||||
[0.8701965808868408, -0.2649352252483368],
|
||||
[1.4506438970565796, 2.1776366233825684],
|
||||
[-0.06305818259716034, 1.9049758911132812],
|
||||
[2.536226511001587, 0.563927412033081],
|
||||
[0.4599496126174927, -1.8745561838150024],
|
||||
[-1.900517225265503, -0.30703988671302795],
|
||||
[0.09386251866817474, 0.8755807280540466],
|
||||
[1.946500539779663, -0.6743080615997314],
|
||||
[2.1338934898376465, 1.4581491947174072],
|
||||
[0.9429940581321716, -0.8038390278816223],
|
||||
[2.0697755813598633, -1.614896535873413],
|
||||
[0.772676408290863, 0.22017823159694672],
|
||||
[1.0689979791641235, -1.525044322013855],
|
||||
[0.6813604831695557, 1.1345642805099487],
|
||||
[0.4706456661224365, 2.606626272201538],
|
||||
[-1.294018030166626, -0.4372096061706543],
|
||||
[-0.09134224057197571, 0.4610418677330017],
|
||||
[-0.7907772064208984, -0.48412787914276123],
|
||||
[0.060459110885858536, -0.9172890186309814],
|
||||
[-0.5855047702789307, 2.56172513961792],
|
||||
[0.11484206467866898, -2.659848213195801],
|
||||
[-1.5893300771713257, 2.188580274581909],
|
||||
[1.6750942468643188, 0.7089915871620178],
|
||||
[-0.445697546005249, 0.7452405095100403],
|
||||
[-1.8539940118789673, -1.8377939462661743],
|
||||
[-1.5791912078857422, -1.017285943031311],
|
||||
[-1.030419945716858, -1.5746369361877441],
|
||||
[-1.9511750936508179, 0.43696075677871704],
|
||||
[-0.3446580767631531, -1.8953213691711426],
|
||||
[-1.4219647645950317, 0.7676230669021606],
|
||||
[-0.9191089272499084, 0.5021472573280334],
|
||||
[0.20464491844177246, 1.3684605360031128],
|
||||
[0.5402919054031372, 0.6699410676956177],
|
||||
[1.8903915882110596, 0.03638288006186485],
|
||||
[0.4723062515258789, -0.6216739416122437],
|
||||
[-0.41345009207725525, -0.22752176225185394],
|
||||
[2.7119064331054688, -0.5111885070800781],
|
||||
[1.065286636352539, 0.6950305700302124],
|
||||
[0.40629103779792786, -0.14339995384216309],
|
||||
[1.2815024852752686, 0.17108257114887238],
|
||||
[0.01785222627222538, -0.43778058886528015],
|
||||
[0.054590027779340744, -1.4225547313690186],
|
||||
[0.3076786696910858, 0.30697619915008545],
|
||||
[-0.9498570561408997, -0.9576997756958008],
|
||||
[-2.4640724658966064, -0.9660449028015137],
|
||||
[1.3714425563812256, -0.39760473370552063],
|
||||
[-0.4857747256755829, 0.2386789172887802],
|
||||
[1.2797833681106567, 1.3097363710403442],
|
||||
[0.5508887767791748, -1.1777795553207397],
|
||||
[-1.384316325187683, 0.1465839296579361],
|
||||
[-0.46556955575942993, -1.2442727088928223],
|
||||
[-0.3915477693080902, -0.7319604158401489],
|
||||
[-1.4005504846572876, 1.3890998363494873],
|
||||
[-0.8647305965423584, 1.0617644786834717],
|
||||
[-0.8901953101158142, -0.01650036871433258],
|
||||
[-0.9893633723258972, -2.4662880897521973],
|
||||
[1.445534110069275, -1.049334168434143],
|
||||
[-0.041650623083114624, 0.012734669260680676],
|
||||
[-0.3302375078201294, 1.26217782497406],
|
||||
[0.6934980154037476, 1.7714335918426514],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (2, 16):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-0.8996632695198059, -1.6360418796539307],
|
||||
[-0.961183488368988, 1.5999565124511719],
|
||||
[-1.882026195526123, 0.678778350353241],
|
||||
[0.36300793290138245, -1.9667866230010986],
|
||||
[-0.6814072728157043, -0.576818585395813],
|
||||
[0.7270012497901917, 0.6186859607696533],
|
||||
[0.3359416127204895, 1.8371193408966064],
|
||||
[1.859930396080017, 0.036668598651885986],
|
||||
[0.17208248376846313, -0.9401724338531494],
|
||||
[-1.7599700689315796, -0.6244229674339294],
|
||||
[-0.8993809223175049, 0.32267823815345764],
|
||||
[0.839488685131073, -0.3017036020755768],
|
||||
[1.5314953327178955, 1.2942044734954834],
|
||||
[-0.0011779458727687597, 0.00022069070837460458],
|
||||
[1.4274526834487915, -1.207889199256897],
|
||||
[-0.16123905777931213, 0.8787511587142944],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (1, 16):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.7325894832611084],
|
||||
[-2.069017171859741],
|
||||
[-1.6180464029312134],
|
||||
[-1.2562311887741089],
|
||||
[-0.9423404335975647],
|
||||
[-0.6567591428756714],
|
||||
[-0.38804829120635986],
|
||||
[-0.12839503586292267],
|
||||
[0.12839503586292267],
|
||||
[0.38804829120635986],
|
||||
[0.6567591428756714],
|
||||
[0.9423404335975647],
|
||||
[1.2562311887741089],
|
||||
[1.6180464029312134],
|
||||
[2.069017171859741],
|
||||
[2.7325894832611084],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (1, 8):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.1519455909729004],
|
||||
[-1.3439092636108398],
|
||||
[-0.7560052871704102],
|
||||
[-0.2450941801071167],
|
||||
[0.2450941801071167],
|
||||
[0.7560052871704102],
|
||||
[1.3439092636108398],
|
||||
[2.1519455909729004],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (1, 4):
|
||||
return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported p={p}, n={n}")
|
||||
|
||||
|
||||
def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
|
||||
assert len(weight.shape) == 2, "Only 2D weights are supported for now"
|
||||
|
||||
grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
|
||||
grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2
|
||||
|
||||
device = weight.device
|
||||
dtype = weight.dtype
|
||||
weight = weight.to(copy=True, dtype=torch.float32)
|
||||
# Pad to Hadamard transform size
|
||||
weight = pad_to_block(weight, [1], hadamard_size)
|
||||
|
||||
# Scale and Hadamard transform
|
||||
mult = weight.shape[1] // hadamard_size
|
||||
weight = weight.reshape(-1, mult, hadamard_size)
|
||||
scales = torch.linalg.norm(weight, axis=-1)
|
||||
weight = hadamard_transform(weight, 1) / scales[:, :, None]
|
||||
|
||||
# Pad to edenn_d and project
|
||||
weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)
|
||||
|
||||
# Quantize
|
||||
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
|
||||
for i in range(0, weight.shape[0], 16):
|
||||
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
|
||||
del weight
|
||||
|
||||
codes = codes.reshape(codes.shape[0], -1)
|
||||
scales = scales / sqrt(hadamard_size)
|
||||
|
||||
weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
|
||||
codes,
|
||||
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
|
||||
grid.to(dtype),
|
||||
num_bits=bits,
|
||||
group_size=group_size,
|
||||
vector_size=p,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
check_correctness=False,
|
||||
)
|
||||
|
||||
return {
|
||||
"weight": weight,
|
||||
"scales": scales,
|
||||
"tables": tables,
|
||||
"tables2": tables2.view(dtype=torch.float16),
|
||||
"tune_metadata": tune_metadata,
|
||||
}
|
||||
|
||||
|
||||
class HiggsLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
num_bits: int,
|
||||
bias=True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
group_size: int = 256,
|
||||
hadamard_size: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.num_bits = num_bits
|
||||
self.group_size = group_size
|
||||
self.hadamard_size = hadamard_size
|
||||
|
||||
assert in_features % group_size == 0
|
||||
assert num_bits in [2, 3, 4]
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
|
||||
)
|
||||
self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
|
||||
self.tables2 = nn.Parameter(
|
||||
torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.workspace = None # must be set externally to be reused among layers
|
||||
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
|
||||
|
||||
def forward(self, x):
|
||||
x = pad_to_block(x, [-1], self.hadamard_size)
|
||||
|
||||
if self.workspace is None:
|
||||
raise Exception("Workspace must be set before calling forward")
|
||||
|
||||
return qgemm_v2(
|
||||
x,
|
||||
self.weight,
|
||||
self.scales,
|
||||
self.tables,
|
||||
self.tables2.view(dtype=torch.float32),
|
||||
self.workspace,
|
||||
self.tune_metadata,
|
||||
hadamard_size=self.hadamard_size,
|
||||
)
|
||||
|
||||
|
||||
def _replace_with_higgs_linear(
|
||||
model, quantization_config=None, current_key_name=None, modules_to_not_convert=None, has_been_replaced=False
|
||||
):
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(current_key_name_str.endswith(key) for key in modules_to_not_convert):
|
||||
with init_empty_weights():
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
# Original size is [3072, 4096]. But after `HiggsLinear`, this is
|
||||
# [768, 4096]. 🤯
|
||||
if name == "context_embedder":
|
||||
print(f"{in_features=}, {out_features=}")
|
||||
model._modules[name] = HiggsLinear(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=module.bias is not None,
|
||||
num_bits=quantization_config.bits,
|
||||
hadamard_size=quantization_config.hadamard_size,
|
||||
group_size=quantization_config.group_size,
|
||||
)
|
||||
if name == "context_embedder":
|
||||
print(model._modules[name].weight.shape)
|
||||
has_been_replaced = True
|
||||
|
||||
# Store the module class in case we need to transpose the weight later
|
||||
model._modules[name].source_cls = type(module)
|
||||
# Force requires grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False)
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_higgs_linear(
|
||||
module,
|
||||
quantization_config=quantization_config,
|
||||
current_key_name=current_key_name,
|
||||
modules_to_not_convert=modules_to_not_convert,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def replace_with_higgs_linear(
|
||||
model,
|
||||
quantization_config=None,
|
||||
current_key_name=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""
|
||||
Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers.
|
||||
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
|
||||
conversion has been successful or not.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
The model to convert, can be any `torch.nn.Module` instance.
|
||||
quantization_config (`HiggsConfig`):
|
||||
The quantization config object that contains the quantization parameters.
|
||||
current_key_name (`list`, *optional*):
|
||||
A list that contains the current key name. This is used for recursion and should not be passed by the user.
|
||||
has_been_replaced (`bool`, *optional*):
|
||||
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
|
||||
should not be passed by the user.
|
||||
"""
|
||||
modules_to_not_convert = quantization_config.modules_to_not_convert or []
|
||||
model, _ = _replace_with_higgs_linear(
|
||||
model, quantization_config, current_key_name, modules_to_not_convert, has_been_replaced
|
||||
)
|
||||
|
||||
has_been_replaced = any(isinstance(replaced_module, HiggsLinear) for _, replaced_module in model.named_modules())
|
||||
if not has_been_replaced:
|
||||
logger.warning(
|
||||
"You are loading your model in Higgs but no linear modules were found in your model."
|
||||
" Please double check your model architecture, or submit an issue on github if you think this is"
|
||||
" a bug."
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def dequantize_higgs(model, current_key_name=None):
|
||||
"""
|
||||
Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
|
||||
current_key_name (list, optional):
|
||||
A list to keep track of the current module names during recursion. Defaults to None.
|
||||
Returns:
|
||||
torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, HiggsLinear):
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
|
||||
model._modules[name] = torch.nn.Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.scales.device,
|
||||
dtype=module.scales.dtype,
|
||||
)
|
||||
|
||||
model._modules[name].weight.data = module(
|
||||
torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
|
||||
).T.contiguous()
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_ = dequantize_higgs(
|
||||
module,
|
||||
current_key_name=current_key_name,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
current_key_name.pop(-1)
|
||||
return model
|
||||
@@ -46,7 +46,6 @@ class QuantizationMethod(str, Enum):
|
||||
GGUF = "gguf"
|
||||
TORCHAO = "torchao"
|
||||
QUANTO = "quanto"
|
||||
HIGGS = "higgs"
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
@@ -725,62 +724,3 @@ class QuantoConfig(QuantizationConfigMixin):
|
||||
accepted_weights = ["float8", "int8", "int4", "int2"]
|
||||
if self.weights_dtype not in accepted_weights:
|
||||
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class HiggsConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
HiggsConfig is a configuration class for quantization using the HIGGS method.
|
||||
|
||||
Args:
|
||||
bits (int, *optional*, defaults to 4):
|
||||
Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4.
|
||||
p (int, *optional*, defaults to 2):
|
||||
Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2.
|
||||
modules_to_not_convert (`list`, *optional*, default to ["lm_head"]):
|
||||
List of linear layers that should not be quantized.
|
||||
hadamard_size (int, *optional*, defaults to 512):
|
||||
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value.
|
||||
Decreasing this below 512 will reduce the quality of the quantization.
|
||||
group_size (int, *optional*, defaults to 256):
|
||||
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance.
|
||||
Default is 256. Must be a divisor of hadamard_size.
|
||||
tune_metadata ('dict', *optional*, defaults to {}):
|
||||
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default
|
||||
is an empty dictionary. Is set automatically during tuning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bits: int = 4,
|
||||
p: int = 2,
|
||||
modules_to_not_convert: Optional[list[str]] = None,
|
||||
hadamard_size: int = 512,
|
||||
group_size: int = 256,
|
||||
tune_metadata: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if tune_metadata is None:
|
||||
tune_metadata = {}
|
||||
self.quant_method = QuantizationMethod.HIGGS
|
||||
self.bits = bits
|
||||
self.p = p
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
self.hadamard_size = hadamard_size
|
||||
self.group_size = group_size
|
||||
self.tune_metadata = tune_metadata
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
||||
"""
|
||||
if self.bits not in [2, 3, 4]:
|
||||
raise ValueError("bits must be 2, 3, or 4")
|
||||
if self.p not in [1, 2]:
|
||||
raise ValueError("p must be 1 or 2. 2 is always better in practice")
|
||||
if self.group_size not in [64, 128, 256]:
|
||||
raise ValueError("group_size must be 64, 128, or 256")
|
||||
if self.hadamard_size % self.group_size != 0:
|
||||
raise ValueError("hadamard_size must be divisible by group_size")
|
||||
|
||||
@@ -82,6 +82,7 @@ from .import_utils import (
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_kernels_available,
|
||||
is_kornia_available,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
|
||||
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FrequencyDecoupledGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
|
||||
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
|
||||
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
|
||||
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -398,6 +399,10 @@ def is_flash_attn_3_available():
|
||||
return _flash_attn_3_available
|
||||
|
||||
|
||||
def is_kornia_available():
|
||||
return _kornia_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -197,20 +197,6 @@ def get_peft_kwargs(
|
||||
"lora_bias": lora_bias,
|
||||
}
|
||||
|
||||
# Example: try load FusionX LoRA into Wan VACE
|
||||
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
|
||||
if exclude_modules:
|
||||
if not is_peft_version(">=", "0.14.0"):
|
||||
msg = """
|
||||
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
|
||||
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
|
||||
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
|
||||
https://github.com/huggingface/diffusers/issues/new
|
||||
"""
|
||||
logger.debug(msg)
|
||||
else:
|
||||
lora_config_kwargs.update({"exclude_modules": exclude_modules})
|
||||
|
||||
return lora_config_kwargs
|
||||
|
||||
|
||||
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
|
||||
def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
|
||||
"""
|
||||
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
|
||||
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
|
||||
doesn't exist in `peft_state_dict`.
|
||||
"""
|
||||
if model_state_dict is None:
|
||||
return
|
||||
all_modules = set()
|
||||
string_to_replace = f"{adapter_name}." if adapter_name else ""
|
||||
|
||||
for name in model_state_dict.keys():
|
||||
if string_to_replace:
|
||||
name = name.replace(string_to_replace, "")
|
||||
if "." in name:
|
||||
module_name = name.rsplit(".", 1)[0]
|
||||
all_modules.add(module_name)
|
||||
|
||||
target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
|
||||
exclude_modules = list(all_modules - target_modules_set)
|
||||
|
||||
return exclude_modules
|
||||
|
||||
+4
-68
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
@@ -292,20 +291,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def _get_exclude_modules(self, pipe):
|
||||
from diffusers.utils.peft_utils import _derive_exclude_modules
|
||||
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
|
||||
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
|
||||
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
|
||||
pipe.unload_lora_weights()
|
||||
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
|
||||
exclude_modules = _derive_exclude_modules(
|
||||
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
|
||||
)
|
||||
return exclude_modules
|
||||
|
||||
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
@@ -2342,58 +2327,6 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_exclude_modules(self):
|
||||
"""
|
||||
Test to check if `exclude_modules` works or not. It works in the following way:
|
||||
we first create a pipeline and insert LoRA config into it. We then derive a `set`
|
||||
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
|
||||
state dict.
|
||||
|
||||
We then create a new LoRA config to include the `exclude_modules` and perform tests.
|
||||
"""
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components).to(torch_device)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
# only supported for `denoiser` now
|
||||
pipe_cp = copy.deepcopy(pipe)
|
||||
pipe_cp, _ = self.add_adapters_to_pipeline(
|
||||
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
|
||||
pipe_cp.to("cpu")
|
||||
del pipe_cp
|
||||
|
||||
denoiser_lora_config.exclude_modules = denoiser_exclude_modules
|
||||
pipe, _ = self.add_adapters_to_pipeline(
|
||||
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
|
||||
)
|
||||
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
|
||||
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
|
||||
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
|
||||
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
|
||||
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should change outputs.",
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Lora outputs should match.",
|
||||
)
|
||||
|
||||
def test_inference_load_delete_load_adapters(self):
|
||||
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
@@ -2467,7 +2400,6 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
|
||||
@@ -2483,6 +2415,10 @@ class PeftLoraLoaderMixinTests:
|
||||
num_blocks_per_group=1,
|
||||
use_stream=use_stream,
|
||||
)
|
||||
# Place other model-level components on `torch_device`.
|
||||
for _, component in pipe.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to(torch_device)
|
||||
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
|
||||
self.assertTrue(group_offload_hook_1 is not None)
|
||||
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# The test exists for cases like
|
||||
# https://github.com/huggingface/diffusers/issues/11874
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_exclude_modules(self):
|
||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
lora_rank = 4
|
||||
target_module = "single_transformer_blocks.0.proj_out"
|
||||
adapter_name = "foo"
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
||||
lora_state_dict = {
|
||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
||||
}
|
||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
||||
config = LoraConfig(
|
||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
||||
)
|
||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
||||
|
||||
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
|
||||
+462
@@ -0,0 +1,462 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SDXLModularTests:
|
||||
"""
|
||||
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
|
||||
"""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
"mask_image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_default_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
sd_pipe = self.get_pipeline()
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs, output="images")
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == expected_image_shape
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
|
||||
"Image Slice does not match expected slice"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularIPAdapterTests:
|
||||
"""
|
||||
This mixin is designed to test IP Adapter.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs_and_blocks(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
|
||||
assert "ip_adapter_image" in parameters, (
|
||||
"`ip_adapter_image` argument must be supported by the `__call__` method"
|
||||
)
|
||||
assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
|
||||
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
assert "ip_adapter_image" not in parameters, (
|
||||
"`ip_adapter_image` argument must be removed from the `__call__` method"
|
||||
)
|
||||
|
||||
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
|
||||
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
|
||||
|
||||
def _get_dummy_masks(self, input_size: int = 64):
|
||||
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
|
||||
_masks[0, :, :, : int(input_size / 2)] = 1
|
||||
return _masks
|
||||
|
||||
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
parameters = blocks.input_names
|
||||
if "image" in parameters and "strength" in parameters:
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "np"
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for IP-Adapter.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
|
||||
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs, output="images")
|
||||
else:
|
||||
output_without_adapter = expected_pipe_slice
|
||||
|
||||
# 1. Single IP-Adapter test cases
|
||||
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
output_without_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
output_with_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
assert max_diff_without_adapter_scale < expected_max_diff, (
|
||||
"Output without ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
|
||||
|
||||
# 2. Multi IP-Adapter test cases
|
||||
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
|
||||
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
output_without_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
output_with_multi_adapter_scale = pipe(**inputs, output="images")
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
|
||||
"Output without multi-ip-adapter must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_multi_adapter_scale > 1e-2, (
|
||||
"Output with multi-ip-adapter scale must be different from normal inference"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularControlNetTests:
|
||||
"""
|
||||
This mixin is designed to test ControlNet.
|
||||
"""
|
||||
|
||||
def test_pipeline_inputs(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
parameters = blocks.input_names
|
||||
|
||||
assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
|
||||
assert "controlnet_conditioning_scale" in parameters, (
|
||||
"`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
|
||||
)
|
||||
|
||||
def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = torch.randn(
|
||||
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
|
||||
device=torch_device,
|
||||
)
|
||||
inputs["control_image"] = image
|
||||
return inputs
|
||||
|
||||
def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
r"""Tests for ControlNet.
|
||||
|
||||
The following scenarios are tested:
|
||||
- Single ControlNet with scale=0 should produce same output as no ControlNet.
|
||||
- Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
|
||||
"""
|
||||
# Raising the tolerance for this test when it's run on a CPU because we
|
||||
# compare against static slices and that can be shaky (with a VVVV low probability).
|
||||
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
|
||||
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass without controlnet
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_controlnet = pipe(**inputs, output="images")
|
||||
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 0.0
|
||||
output_without_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 42.0
|
||||
output_with_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
|
||||
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
|
||||
|
||||
assert max_diff_without_controlnet_scale < expected_max_diff, (
|
||||
"Output without controlnet must be same as normal inference"
|
||||
)
|
||||
assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
|
||||
|
||||
def test_controlnet_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularGuiderTests:
|
||||
def test_guider_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass with CFG not applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.5966781,
|
||||
0.62939394,
|
||||
0.48465094,
|
||||
0.51573336,
|
||||
0.57593524,
|
||||
0.47035995,
|
||||
0.53410417,
|
||||
0.51436996,
|
||||
0.47313565,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLImg2ImgModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
inputs["image"] = image
|
||||
inputs["strength"] = 0.8
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.56943184,
|
||||
0.4702148,
|
||||
0.48048905,
|
||||
0.6235963,
|
||||
0.551138,
|
||||
0.49629188,
|
||||
0.60031277,
|
||||
0.5688907,
|
||||
0.43996853,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class SDXLInpaintingModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
# create mask
|
||||
image[8:, 8:, :] = 255
|
||||
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
|
||||
|
||||
inputs["image"] = init_image
|
||||
inputs["mask_image"] = mask_image
|
||||
inputs["strength"] = 1.0
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.40872607,
|
||||
0.38842705,
|
||||
0.34893104,
|
||||
0.47837183,
|
||||
0.43792963,
|
||||
0.5332134,
|
||||
0.3716843,
|
||||
0.47274873,
|
||||
0.45000193,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
@@ -0,0 +1,358 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
require_torch,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
including:
|
||||
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
|
||||
- test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
|
||||
- test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
|
||||
- test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
|
||||
- test_to_device: check if the pipeline's __call__ method can handle different devices
|
||||
"""
|
||||
|
||||
# Canonical parameters that are passed to `__call__` regardless
|
||||
# of the type of pipeline. They are always optional and have common
|
||||
# sense default values.
|
||||
optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_images_per_prompt",
|
||||
"latents",
|
||||
"output_type",
|
||||
]
|
||||
)
|
||||
# this is modular specific: generator needs to be a intermediate input because it's mutable
|
||||
intermediate_params = frozenset(
|
||||
[
|
||||
"generator",
|
||||
]
|
||||
)
|
||||
|
||||
def get_generator(self, seed):
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
return generator
|
||||
|
||||
@property
|
||||
def pipeline_class(self) -> Union[Callable, ModularPipeline]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def repo(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_pipeline(self):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_pipeline(self)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `params` in the child test class. "
|
||||
"`params` are checked for if all values are present in `__call__`'s signature."
|
||||
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
|
||||
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
|
||||
"image pipelines, including prompts and prompt embedding overrides."
|
||||
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
|
||||
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
|
||||
"with non-configurable height and width arguments should set the attribute as "
|
||||
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `batch_params` in the child test class. "
|
||||
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
|
||||
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
|
||||
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
|
||||
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
|
||||
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
|
||||
"image pipeline `negative_prompt` is not batched should set the attribute as "
|
||||
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
optional_parameters = pipe.default_call_parameters
|
||||
|
||||
def _check_for_parameters(parameters, expected_parameters, param_type):
|
||||
remaining_parameters = {param for param in parameters if param not in expected_parameters}
|
||||
assert len(remaining_parameters) == 0, (
|
||||
f"Required {param_type} parameters not present: {remaining_parameters}"
|
||||
)
|
||||
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# prepare batched inputs
|
||||
batched_inputs = []
|
||||
for batch_size in batch_sizes:
|
||||
batched_input = {}
|
||||
batched_input.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_input[name] = batch_size * [value]
|
||||
|
||||
if batch_generator and "generator" in inputs:
|
||||
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_input["batch_size"] = batch_size
|
||||
|
||||
batched_inputs.append(batched_input)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
|
||||
output = pipe(**batched_input, output="images")
|
||||
assert len(output) == batch_size, "Output is different from expected batch size"
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
expected_max_diff=1e-4,
|
||||
):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
output = pipe(**inputs, output="images")
|
||||
output_batch = pipe(**batched_inputs, output="images")
|
||||
|
||||
assert output_batch.shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
|
||||
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe_fp16 = self.get_pipeline()
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
output = pipe(**inputs, output="images")
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
|
||||
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.cpu()
|
||||
output_fp16 = output_fp16.cpu()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
|
||||
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
|
||||
|
||||
@require_accelerator
|
||||
def test_to_device(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
assert all(device == torch_device for device in model_devices), (
|
||||
"All pipeline components are not on accelerator device"
|
||||
)
|
||||
|
||||
def test_inference_is_not_nan_cpu(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to("cpu")
|
||||
|
||||
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
|
||||
|
||||
@require_accelerator
|
||||
def test_inference_is_not_nan(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
pipe = self.get_pipeline()
|
||||
|
||||
if "num_images_per_prompt" not in pipe.blocks.input_names:
|
||||
return
|
||||
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
@require_accelerator
|
||||
def test_components_auto_cpu_offload_inference_consistent(self):
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
|
||||
cm = ComponentsManager()
|
||||
cm.enable_auto_cpu_offload(device=torch_device)
|
||||
offload_pipe = self.get_pipeline(components_manager=cm)
|
||||
|
||||
image_slices = []
|
||||
for pipe in [base_pipe, offload_pipe]:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_default_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"image",
|
||||
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# Text guided image variation with an image mask
|
||||
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
# image variation with an image mask
|
||||
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
[
|
||||
"example_image",
|
||||
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
|
||||
]
|
||||
)
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
|
||||
|
||||
# image params
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
|
||||
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
|
||||
|
||||
|
||||
# batch params
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
|
||||
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
|
||||
|
||||
IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
|
||||
|
||||
IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
|
||||
|
||||
UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
|
||||
|
||||
TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
|
||||
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
|
||||
|
||||
# callback params
|
||||
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
|
||||
|
||||
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
@require_bitsandbytes_version_greater("0.46.1")
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
super().test_torch_compile()
|
||||
|
||||
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
|
||||
" Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
|
||||
)
|
||||
def test_torch_compile(self):
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
super()._test_torch_compile(torch_dtype=torch.float16)
|
||||
|
||||
@@ -212,6 +212,7 @@ class GGUFSingleFileTesterMixin:
|
||||
|
||||
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
|
||||
torch_dtype = torch.bfloat16
|
||||
model_cls = FluxTransformer2DModel
|
||||
expected_memory_use_in_gb = 5
|
||||
@@ -296,6 +297,16 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_loading_gguf_diffusers_format(self):
|
||||
model = self.model_cls.from_single_file(
|
||||
self.diffusers_ckpt_path,
|
||||
subfolder="transformer",
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
config="black-forest-labs/FLUX.1-dev",
|
||||
)
|
||||
model.to("cuda")
|
||||
model(**self.get_dummy_inputs())
|
||||
|
||||
|
||||
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
|
||||
|
||||
@@ -56,12 +56,18 @@ class QuantCompileTests:
|
||||
pipe.transformer.compile(fullgraph=True)
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
|
||||
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.transformer.compile()
|
||||
# regional compilation is better for offloading.
|
||||
# see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
|
||||
if getattr(pipe.transformer, "_repeated_blocks"):
|
||||
pipe.transformer.compile_repeated_blocks(fullgraph=True)
|
||||
else:
|
||||
pipe.transformer.compile()
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
Reference in New Issue
Block a user