Compare commits

..

31 Commits

Author SHA1 Message Date
sayakpaul 8d8621ec72 Merge branch 'main' into folderize-loaders 2025-04-29 00:28:06 +08:00
sayakpaul 0c37895440 consistency 2025-04-29 00:26:28 +08:00
sayakpaul 9bebdf225d fix repo consistency. 2025-04-29 00:11:52 +08:00
sayakpaul c05114d5ec resolve conflicts. 2025-04-29 00:04:19 +08:00
sayakpaul a57a5ab4c0 resolve conflicts. 2025-04-29 00:02:38 +08:00
sayakpaul 4b1c7dc81a resolve conflicts. 2025-04-17 09:27:54 +05:30
Sayak Paul 1590325a60 Merge branch 'main' into folderize-loaders 2025-04-16 18:41:54 +05:30
sayakpaul e4dd7c5333 updates 2025-04-16 18:26:27 +05:30
sayakpaul d6430c79a3 updates 2025-04-16 18:11:39 +05:30
sayakpaul 1597ae6ac9 updates 2025-04-16 17:37:17 +05:30
sayakpaul 11a23d11fe updates 2025-04-16 17:29:26 +05:30
sayakpaul 6b8b225aca single file utils. 2025-04-16 17:26:30 +05:30
sayakpaul 27d2401e59 partially complete single_file_utils 2025-04-16 16:52:54 +05:30
sayakpaul 1ddfe14220 single_file 2025-04-16 16:30:39 +05:30
sayakpaul 0e8d1d25eb ip_adapter 2025-04-16 15:59:01 +05:30
sayakpaul 546446ae21 ip_adapter. 2025-04-16 15:52:57 +05:30
sayakpaul ea3f0b8d68 update 2025-04-16 15:49:14 +05:30
sayakpaul f0ea9ff2e2 deprecate lora loader from loaders easily. 2025-04-16 15:47:40 +05:30
sayakpaul 1b7c286974 fix 2025-04-16 15:09:23 +05:30
sayakpaul 6138cc1720 updates 2025-04-16 13:01:48 +05:30
sayakpaul ea0ce4bfab fixes 2025-04-16 12:50:09 +05:30
sayakpaul f2aa2f91dc fix 2025-04-16 12:46:04 +05:30
sayakpaul 4faac73219 update 2025-04-16 12:38:58 +05:30
sayakpaul d870e3c9a6 update 2025-04-16 12:35:09 +05:30
sayakpaul 178b884673 updates 2025-04-16 12:29:16 +05:30
sayakpaul 2da3cb4a8c fixes 2025-04-16 12:27:37 +05:30
sayakpaul ea3ba4f431 fies 2025-04-16 12:26:30 +05:30
sayakpaul 21b2566933 fixes 2025-04-16 12:23:16 +05:30
sayakpaul a71334b861 fixes 2025-04-16 12:22:12 +05:30
sayakpaul eb47a67d50 fix 2025-04-16 12:18:30 +05:30
sayakpaul 8267677a24 start folderizing the loaders. 2025-04-16 12:02:06 +05:30
108 changed files with 12889 additions and 25376 deletions
+2 -2
View File
@@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
## IPAdapterMixin
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin
## SD3IPAdapterMixin
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin
- all
- is_ip_adapter_active
+26 -18
View File
@@ -39,58 +39,66 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## StableDiffusionLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin
## StableDiffusionXLLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin
## SD3LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin
## FluxLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin
## Mochi1LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin
## LTXVideoLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin
## SanaLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin
## HunyuanVideoLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin
## Lumina2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
## CogView4LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin
## WanLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
## CogView4LoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
## CogView4LoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
## WanLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin
## HiDreamImageLoraLoaderMixin
@@ -98,4 +106,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
[[autodoc]] loaders.lora.lora_base.LoraBaseMixin
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# SD3Transformer2D
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and [SD3Transformer2DModel], check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
@@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## SD3Transformer2DLoadersMixin
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin
- all
- _load_ip_adapter_weights
+1 -48
View File
@@ -86,7 +86,6 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -5433,50 +5432,4 @@ cropped_image = gen_image.crop((0, 0, width_init, height_init))
cropped_image.save("data/result.png")
````
### Result
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
# Stable Diffusion 3 InstructPix2Pix Pipeline
This the implementation of the Stable Diffusion 3 InstructPix2Pix Pipeline, based on the HuggingFace Diffusers.
## Example Usage
This pipeline aims to edit image based on user's instruction by using SD3
````py
import torch
from diffusers import SD3Transformer2DModel
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
resolution = 512
image = load_image("https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png").resize(
(resolution, resolution)
)
edit_instruction = "Turn sky into a sunny one"
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix", torch_dtype=torch.float16).to('cuda')
pipe.transformer = SD3Transformer2DModel.from_pretrained("CaptainZZZ/sd3-instructpix2pix",torch_dtype=torch.float16).to('cuda')
edited_image = pipe(
prompt=edit_instruction,
image=image,
height=resolution,
width=resolution,
guidance_scale=7.5,
image_guidance_scale=1.5,
num_inference_steps=30,
).images[0]
edited_image.save("edited_image.png")
````
|Original|Edited|
|---|---|
|![Original image](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/StableDiffusion3InstructPix2Pix/mountain.png)|![Edited image](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/StableDiffusion3InstructPix2Pix/edited.png)
### Note
This model is trained on 512x512, so input size is better on 512x512.
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
Editing at Scale", many thanks to their contribution!
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
File diff suppressed because it is too large Load Diff
+2 -15
View File
@@ -639,15 +639,6 @@ def parse_args(input_args=None):
action="store_true",
help="Enable model cpu offload and save memory.",
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -745,13 +736,9 @@ def get_train_dataset(args, accelerator):
def prepare_train_dataset(dataset, accelerator):
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -760,7 +747,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
+3 -41
View File
@@ -134,25 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
try:
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
except (AttributeError, KeyError):
supported_interpolation_modes = [
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
raise ValueError(
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
)
transform = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution),
]
)
validation_image = transform(validation_image)
validation_image = validation_image.resize((args.resolution, args.resolution))
images = []
@@ -605,15 +587,6 @@ def parse_args(input_args=None):
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -759,20 +732,9 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
def prepare_train_dataset(dataset, accelerator):
try:
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
except (AttributeError, KeyError):
supported_interpolation_modes = [
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
raise ValueError(
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
)
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation_mode),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -781,7 +743,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation_mode),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
+1 -72
View File
@@ -34,12 +34,10 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"modular_pipelines": [],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -132,26 +130,12 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
"TangentialClassifierFreeGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"LayerSkipConfig",
"SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -261,15 +245,6 @@ else:
"StableDiffusionMixin",
]
)
_import_structure["modular_pipelines"].extend(
[
"ModularLoader",
"ModularPipeline",
"ModularPipelineBlocks",
"ComponentSpec",
"ComponentsManager",
]
)
_import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
@@ -548,24 +523,6 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_objects # noqa F403
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
]
else:
_import_structure["modular_pipelines"].extend(
[
"StableDiffusionXLAutoPipeline",
"StableDiffusionXLModularLoader",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -771,22 +728,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
)
from .hooks import (
FasterCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
@@ -894,13 +839,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .modular_pipelines import (
ModularLoader,
ModularPipeline,
ModularPipelineBlocks,
ComponentSpec,
ComponentsManager,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
@@ -1156,16 +1094,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
-29
View File
@@ -1,29 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]
@@ -1,184 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_apg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled():
num_conditions += 1
return num_conditions
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
-177
View File
@@ -1,177 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AutoGuidance(BaseGuidance):
"""
AutoGuidance: https://huggingface.co/papers/2406.02507
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
auto_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided.
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
dropout (`float`, *optional*):
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
dropout: Optional[float] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
self.auto_guidance_config = auto_guidance_config
self.dropout = dropout
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if auto_guidance_layers is None and auto_guidance_config is None:
raise ValueError(
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
)
if auto_guidance_layers is not None and auto_guidance_config is not None:
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None):
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
if auto_guidance_layers is not None:
if isinstance(auto_guidance_layers, int):
auto_guidance_layers = [auto_guidance_layers]
if not isinstance(auto_guidance_layers, list):
raise ValueError(
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
)
auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers]
if isinstance(auto_guidance_config, LayerSkipConfig):
auto_guidance_config = [auto_guidance_config]
if not isinstance(auto_guidance_config, list):
raise ValueError(
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
)
self.auto_guidance_config = auto_guidance_config
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_ag_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_ag_enabled():
num_conditions += 1
return num_conditions
def _is_ag_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -1,132 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to tradeoff between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
@@ -1,148 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeZeroStarGuidance(BaseGuidance):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
else:
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond_dtype = cond.dtype
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.to(dtype=cond_dtype)
-215
View File
@@ -1,215 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
import torch
from ..utils import get_logger
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseGuidance:
r"""Base class providing the skeleton for implementing guidance techniques."""
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0):
self._start = start
self._stop = stop
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(
f"Expected `start` to be between 0.0 and 1.0, but got {start}."
)
if not (start <= stop <= 1.0):
raise ValueError(
f"Expected `stop` to be between {start} and 1.0, but got {stop}."
)
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def disable(self):
self._enabled = False
def enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
"""
Set the input fields for the guidance technique. The input fields are used to specify the names of the
returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is
obtained from the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once
it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
which is used to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with
a guidance method that requires it). If a tuple of length 2 is provided, the first element must
be the conditional data identifier and the second element must be the unconditional data identifier
or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic.
"""
self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
"""
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
modifications made during `prepare_models`.
"""
pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
if len(data) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
)
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
return self.forward(**forward_inputs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
@property
def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property
def is_unconditional(self) -> bool:
return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of
the `BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once
it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
which is used to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with
a guidance method that requires it). If a tuple of length 2 is provided, the first element must
be the conditional data identifier and the second element must be the unconditional data identifier
or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
@@ -1,251 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SkipLayerGuidance(BaseGuidance):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -1,244 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SmoothedEnergyGuidance(BaseGuidance):
"""
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
in the future without warning or guarantee of reproducibility. This implementation assumes:
- Generated images are square (height == width)
- The model does not combine different modalities together (e.g., text and image latent streams are
not combined together such as Flux)
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
seg_guidance_scale (`float`, defaults to `3.0`):
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
seg_blur_sigma (`float`, defaults to `9999999.0`):
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
The threshold above which the blur is considered infinite.
seg_guidance_start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not
provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of
`SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
def __init__(
self,
guidance_scale: float = 7.5,
seg_guidance_scale: float = 2.8,
seg_blur_sigma: float = 9999999.0,
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
self.seg_blur_sigma = seg_blur_sigma
self.seg_blur_threshold_inf = seg_blur_threshold_inf
self.seg_guidance_start = seg_guidance_start
self.seg_guidance_stop = seg_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= seg_guidance_start < 1.0):
raise ValueError(
f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}."
)
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
raise ValueError(
f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}."
)
if seg_guidance_layers is None and seg_guidance_config is None:
raise ValueError(
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
)
if seg_guidance_layers is not None and seg_guidance_config is not None:
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
if seg_guidance_layers is not None:
if isinstance(seg_guidance_layers, int):
seg_guidance_layers = [seg_guidance_layers]
if not isinstance(seg_guidance_layers, list):
raise ValueError(
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
)
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
seg_guidance_config = [seg_guidance_config]
if not isinstance(seg_guidance_config, list):
raise ValueError(
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
)
self.seg_guidance_config = seg_guidance_config
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_seg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_cond_seg
pred = pred + self.seg_guidance_scale * shift
elif not self._is_seg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_seg = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_seg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero
@@ -1,137 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class TangentialClassifierFreeGuidance(BaseGuidance):
"""
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_tcfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._num_outputs_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_tcfg_enabled():
num_conditions += 1
return num_conditions
def _is_tcfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
Vh_modified = Vh.clone()
Vh_modified[:, 1] = 0
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift
return pred
-2
View File
@@ -5,7 +5,5 @@ if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
-43
View File
@@ -1,43 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
if submodule_name == fqn:
return submodule
return None
-271
View File
@@ -1,271 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Type
from ..models.attention import BasicTransformerBlock
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class TransformerBlockMetadata:
skip_block_output_fn: Callable[[Any], Any]
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
class AttentionProcessorRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class TransformerBlockRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
def _register_attention_processors_metadata():
# AttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=AttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
),
)
# CogView4AttnProcessor
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_transformer_blocks_metadata():
# BasicTransformerBlock
TransformerBlockRegistry.register(
model_class=BasicTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return encoder_hidden_states, hidden_states
_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on
_register_attention_processors_metadata()
_register_transformer_blocks_metadata()
+11 -27
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple
import torch
@@ -55,7 +55,7 @@ class ModuleGroup:
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
@@ -115,13 +115,8 @@ class ModuleGroup:
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
@@ -167,15 +162,9 @@ class ModuleGroup:
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch_accelerator_module.current_stream().synchronize()
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -440,10 +429,8 @@ def apply_group_offloading(
if use_stream:
if torch.cuda.is_available():
stream = torch.cuda.Stream()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
stream = torch.Stream()
else:
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
raise ValueError("Using streams for data transfer requires a CUDA device.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
@@ -481,7 +468,7 @@ def _apply_group_offloading_block_level(
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
@@ -499,7 +486,7 @@ def _apply_group_offloading_block_level(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -512,10 +499,7 @@ def _apply_group_offloading_block_level(
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if stream is not None and num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
)
num_blocks_per_group = 1
raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -585,7 +569,7 @@ def _apply_group_offloading_leaf_level(
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
@@ -605,7 +589,7 @@ def _apply_group_offloading_leaf_level(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-231
View File
@@ -1,231 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
# either remove or make it serializable
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`.
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
skip_attention (`bool`, defaults to `True`):
Whether to skip attention blocks.
skip_ff (`bool`, defaults to `True`):
Whether to skip feed-forward blocks.
skip_attention_scores (`bool`, defaults to `False`):
Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
projections as the output of scaled dot product attention.
dropout (`float`, defaults to `1.0`):
The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
skipped layers are fully retained, which is equivalent to not skipping any layers.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
dropout: float = 1.0
def __post_init__(self):
if not (0 <= self.dropout <= 1):
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
value = kwargs.get("value", None)
if value is None:
value = args[2]
return value
return func(*args, **kwargs)
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
if not math.isclose(self.dropout, 1.0):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class FeedForwardSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class TransformerBlockSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
output = self._metadata.skip_block_output_fn(module, *args, **kwargs)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip_hook(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.")
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name)
if config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook(config.dropout)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
@@ -1,158 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from ..utils import get_logger
from ._common import _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
@dataclass
class SmoothedEnergyGuidanceConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layer to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`.
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`.
If `None`, `to_q` is used by default.
"""
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: List[str] = None
class SmoothedEnergyGuidanceHook(ModelHook):
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
super().__init__()
self.blur_sigma = blur_sigma
self.blur_threshold_inf = blur_threshold_inf
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
return smoothed_output
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
for identifier in config._query_proj_identifiers:
query_proj = getattr(submodule, identifier, None)
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
continue
logger.debug(
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
)
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
"""
This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian
blur. However, some models use joint text-visual token attention for which this may not be suitable. Additionally,
this implementation also assumes that the visual tokens come from a square image/video. In practice, despite
these assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results
for Smoothed Energy Guidance.
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
in the future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
seq_len_sqrt = int(math.sqrt(seq_len))
num_square_tokens = seq_len_sqrt * seq_len_sqrt
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
if is_inf:
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
kernel_size_half = (kernel_size - 1) / 2
x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
kernel1d = kernel1d.to(query)
kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
query_slice = F.pad(query_slice, padding, mode="reflect")
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
else:
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone()
return query
+12 -18
View File
@@ -54,14 +54,14 @@ if is_transformers_available():
_import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"] = ["FromSingleFileMixin"]
_import_structure["lora_pipeline"] = [
_import_structure["single_file.single_file"] = ["FromSingleFileMixin"]
_import_structure["lora.lora_pipeline"] = [
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
@@ -80,11 +80,10 @@ if is_torch_available():
"HiDreamImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
_import_structure["ip_adapter.ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
"ModularIPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -92,20 +91,14 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin
from .single_file import FromOriginalModelMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
ModularIPAdapterMixin,
)
from .lora_pipeline import (
from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
from .lora import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
@@ -113,6 +106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraBaseMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,9 @@
from ...utils.import_utils import is_torch_available, is_transformers_available
if is_torch_available():
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
if is_transformers_available():
from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
@@ -0,0 +1,879 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ...utils import (
USE_PEFT_BACKEND,
_get_detailed_type,
_get_model_file,
_is_valid_type,
is_accelerate_available,
is_torch_version,
is_transformers_available,
logging,
)
from ..unet.unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available():
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
from ...models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
JointAttnProcessor2_0,
SD3IPAdapterJointAttnProcessor2_0,
)
logger = logging.get_logger(__name__)
class IPAdapterMixin:
"""Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
`image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
torch_dtype=self.dtype,
).to(self.device)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into unet
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
extra_loras = unet._load_ip_adapter_loras(state_dicts)
if extra_loras != {}:
if not USE_PEFT_BACKEND:
logger.warning("PEFT backend is required to load these weights.")
else:
# apply the IP Adapter Face ID LoRA weights
peft_config = getattr(unet, "peft_config", {})
for k, lora in extra_loras.items():
if f"faceid_{k}" not in peft_config:
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
def set_ip_adapter_scale(self, scale):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
# To use style block only
scale = {
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style+layout blocks
scale = {
"down": {"block_2": [0.0, 1.0]},
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style and layout from 2 reference images
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
pipeline.set_ip_adapter_scale(scales)
```
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if not isinstance(scale, list):
scale = [scale]
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
if isinstance(scale_config, dict):
for k, s in scale_config.items():
if attn_name.startswith(k):
attn_processor.scale[i] = s
else:
attn_processor.scale[i] = scale_config
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove hidden encoder
self.unet.encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = None
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
self.unet.text_encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = "text_proj"
# restore original Unet attention processors layers
attn_procs = {}
for name, value in self.unet.attn_processors.items():
attn_processor_class = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
)
attn_procs[name] = (
attn_processor_class
if isinstance(
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
)
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
Can be either:
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
for key in f.keys():
if any(key.startswith(prefix) for prefix in image_proj_keys):
diffusers_name = ".".join(key.split(".")[1:])
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
diffusers_name = (
".".join(key.split(".")[1:])
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
.replace("processor.", "")
)
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_pretrained_model_name_or_path is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
image_encoder = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_pretrained_model_name_or_path,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
dtype=image_encoder_dtype,
)
.to(self.device)
.eval()
)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
number of IP adapters and each must match the number of blocks.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
def LinearStrengthModel(start, finish, size):
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
scale_type = Union[int, float]
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
num_layers = self.transformer.config.num_layers
# Single value for all layers of all IP-Adapters
if isinstance(scale, scale_type):
scale = [scale for _ in range(num_ip_adapters)]
# List of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
scale = [scale]
# Invalid scale type
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
if len(scale) != num_ip_adapters:
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
raise ValueError(
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
)
# Scalars are transformed to lists with length num_layers
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
# Set scales. zip over scale_configs prevents going into single transformer layers
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
attn_processor.scale = scale
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove hidden encoder
self.transformer.encoder_hid_proj = None
self.transformer.config.encoder_hid_dim_type = None
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
class SD3IPAdapterMixin:
"""Mixin for handling StableDiffusion 3 IP Adapters."""
@property
def is_ip_adapter_active(self) -> bool:
"""Checks if IP-Adapter is loaded and scale > 0.
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
the image context is irrelevant.
Returns:
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
"""
scales = [
attn_proc.scale
for attn_proc in self.transformer.attn_processors.values()
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
]
return len(scales) > 0 and any(scale > 0 for scale in scales)
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
weight_name: str = "ip-adapter.safetensors",
subfolder: Optional[str] = None,
image_encoder_folder: Optional[str] = "image_encoder",
**kwargs,
) -> None:
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
weight_name (`str`, defaults to "ip-adapter.safetensors"):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
subfolder (`str`, *optional*):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
`image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# Load the main state dict first
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_folder is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
if image_encoder_folder.count("/") == 0:
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
else:
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
# Commons args for loading image encoder and image processor
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage,
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
self.register_modules(
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
image_encoder=SiglipVisionModel.from_pretrained(
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
).to(self.device),
)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
)
# Load IP-Adapter into transformer
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: float) -> None:
"""
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
the model to produce more diverse images, but they may not be as aligned with the image prompt.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.set_ip_adapter_scale(0.6)
>>> ...
```
Args:
scale (float):
IP-Adapter scale to be set.
"""
for attn_processor in self.transformer.attn_processors.values():
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
attn_processor.scale = scale
def unload_ip_adapter(self) -> None:
"""
Unloads the IP Adapter weights.
Example:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# Remove image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=None)
# Remove feature extractor
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=None)
# Remove image projection
self.transformer.image_proj = None
# Restore original attention processors layers
attn_procs = {
name: (
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
)
for name, value in self.transformer.attn_processors.items()
}
self.transformer.set_attn_processor(attn_procs)
@@ -0,0 +1,168 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ...utils import is_accelerate_available, is_torch_version, logging
logger = logging.get_logger(__name__)
class FluxTransformer2DLoadersMixin:
"""
Load layers into a [`FluxTransformer2DModel`].
"""
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 0
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
num_image_text_embed = 4
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
num_image_text_embed = 16
# IP-Adapter
num_image_text_embeds += [num_image_text_embed]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
dtype=self.dtype,
device=self.device,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device_map = {"": self.device}
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
key_id += 1
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
@@ -0,0 +1,170 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Dict
from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ...models.embeddings import IPAdapterTimeImageProjection
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ...utils import is_accelerate_available, is_torch_version, logging
logger = logging.get_logger(__name__)
class SD3Transformer2DLoadersMixin:
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
def _convert_ip_adapter_attn_to_diffusers(
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
) -> Dict:
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# IP-Adapter cross attention parameters
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
# Dict where key is transformer layer index, value is attention processor's state dict
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
for key, weights in state_dict.items():
idx, name = key.split(".", maxsplit=1)
layer_state_dict[int(idx)][name] = weights
# Create IP-Adapter attention processor & load state_dict
attn_procs = {}
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for idx, name in enumerate(self.attn_processors.keys()):
with init_context():
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size,
ip_hidden_states_dim=ip_hidden_states_dim,
head_dim=self.config.attention_head_dim,
timesteps_emb_dim=timesteps_emb_dim,
)
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
) -> IPAdapterTimeImageProjection:
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
# Convert to diffusers
updated_state_dict = {}
for key, value in state_dict.items():
# InstantX/SD3.5-Large-IP-Adapter
if key.startswith("layers."):
idx = key.split(".")[1]
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
updated_state_dict[key] = value
# Image projetion parameters
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
output_dim = updated_state_dict["proj_out.weight"].shape[0]
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
num_queries = updated_state_dict["latents"].shape[1]
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
# Image projection
with init_context():
image_proj = IPAdapterTimeImageProjection(
embed_dim=embed_dim,
output_dim=output_dim,
hidden_dim=hidden_dim,
heads=heads,
num_queries=num_queries,
timestep_in_dim=timestep_in_dim,
)
if not low_cpu_mem_usage:
image_proj.load_state_dict(updated_state_dict, strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
return image_proj
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
Args:
state_dict (`Dict`):
State dict with keys "ip_adapter", which contains parameters for attention processors, and
"image_proj", which contains parameters for image projection net.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
+25
View File
@@ -0,0 +1,25 @@
from ...utils import is_peft_available, is_torch_available, is_transformers_available
if is_torch_available():
from .lora_base import LoraBaseMixin
if is_transformers_available():
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
)
+935
View File
@@ -0,0 +1,935 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ...models.modeling_utils import ModelMixin, load_state_dict
from ...utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
is_transformers_available,
is_transformers_version,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
if is_transformers_available():
from transformers import PreTrainedModel
from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
"""
Fuses LoRAs for the text encoder.
Args:
text_encoder (`torch.nn.Module`):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
"""
merge_kwargs = {"safe_merge": safe_fusing}
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
# For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. "
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
def unfuse_text_encoder_lora(text_encoder):
"""
Unfuses LoRAs for the text encoder.
Args:
text_encoder (`torch.nn.Module`):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
def set_adapters_for_text_encoder(
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
"""
Sets the adapter layers for the text encoder.
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if text_encoder is None:
raise ValueError(
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
)
def process_weights(adapter_names, weights):
# Expand weights into a list, one entry per adapter
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
# Set None values to default of 1.0
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
weights = [w if w is not None else 1.0 for w in weights]
return weights
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
"""
Disables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
"""
Enables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=True)
def _remove_text_encoder_monkey_patch(text_encoder):
recurse_remove_peft_layers(text_encoder)
if getattr(text_encoder, "peft_config", None) is not None:
del text_encoder.peft_config
text_encoder._hf_peft_config_loaded = None
def _fetch_state_dict(
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
proxies,
token,
revision,
subfolder,
user_agent,
allow_pickle,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
def _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
if len(targeted_files) == 0:
return
# "scheduler" does not correspond to a LoRA checkpoint.
# "optimizer" does not correspond to a LoRA checkpoint
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = list(
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
)
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
)
weight_name = targeted_files[0]
return weight_name
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
state_dict = convert_state_dict_to_diffusers(state_dict)
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
_lora_loadable_modules = []
num_fused_loras = 0
def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@classmethod
def save_lora_weights(cls, **kwargs):
raise NotImplementedError("`save_lora_weights()` not implemented.")
@classmethod
def lora_state_dict(cls, **kwargs):
raise NotImplementedError("`lora_state_dict()` is not implemented.")
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
return _fetch_state_dict(*args, **kwargs)
@classmethod
def _best_guess_weight_name(cls, *args, **kwargs):
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
return _best_guess_weight_name(*args, **kwargs)
def unload_lora_weights(self):
"""
Unloads the LoRA parameters.
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
>>> pipeline.unload_lora_weights()
>>> ...
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.unload_lora()
elif issubclass(model.__class__, PreTrainedModel):
_remove_text_encoder_monkey_patch(model)
def fuse_lora(
self,
components: List[str] = [],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
if "fuse_unet" in kwargs:
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
deprecate(
"fuse_unet",
"1.0.0",
depr_message,
)
if "fuse_transformer" in kwargs:
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
deprecate(
"fuse_transformer",
"1.0.0",
depr_message,
)
if "fuse_text_encoder" in kwargs:
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
deprecate(
"fuse_text_encoder",
"1.0.0",
depr_message,
)
if len(components) == 0:
raise ValueError("`components` cannot be an empty list.")
for fuse_component in components:
if fuse_component not in self._lora_loadable_modules:
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
model = getattr(self, fuse_component, None)
if model is not None:
# check if diffusers model
if issubclass(model.__class__, ModelMixin):
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
# handle transformers models.
if issubclass(model.__class__, PreTrainedModel):
fuse_text_encoder_lora(
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
self.num_fused_loras += 1
def unfuse_lora(self, components: List[str] = [], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
if "unfuse_unet" in kwargs:
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
deprecate(
"unfuse_unet",
"1.0.0",
depr_message,
)
if "unfuse_transformer" in kwargs:
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
deprecate(
"unfuse_transformer",
"1.0.0",
depr_message,
)
if "unfuse_text_encoder" in kwargs:
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
deprecate(
"unfuse_text_encoder",
"1.0.0",
depr_message,
)
if len(components) == 0:
raise ValueError("`components` cannot be an empty list.")
for fuse_component in components:
if fuse_component not in self._lora_loadable_modules:
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
model = getattr(self, fuse_component, None)
if model is not None:
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
self.num_fused_loras -= 1
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
if isinstance(adapter_weights, dict):
components_passed = set(adapter_weights.keys())
lora_components = set(self._lora_loadable_modules)
invalid_components = sorted(components_passed - lora_components)
if invalid_components:
logger.warning(
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
"to the invalid components will be removed and ignored."
)
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
adapter_weights = copy.deepcopy(adapter_weights)
# Expand weights into a list, one entry per adapter
if not isinstance(adapter_weights, list):
adapter_weights = [adapter_weights] * len(adapter_names)
if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
# eg ["adapter1", "adapter2"]
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
missing_adapters = set(adapter_names) - all_adapters
if len(missing_adapters) > 0:
raise ValueError(
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
)
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
invert_list_adapters = {
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
for adapter in all_adapters
}
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
model = getattr(self, component)
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
component_adapter_weights = weights.pop(component, None)
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
logger.warning(
(
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
)
)
else:
component_adapter_weights = weights
_component_adapter_weights.setdefault(component, [])
_component_adapter_weights[component].append(component_adapter_weights)
if issubclass(model.__class__, ModelMixin):
model.set_adapters(adapter_names, _component_adapter_weights[component])
elif issubclass(model.__class__, PreTrainedModel):
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
def disable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.disable_lora()
elif issubclass(model.__class__, PreTrainedModel):
disable_lora_for_text_encoder(model)
def enable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.enable_lora()
elif issubclass(model.__class__, PreTrainedModel):
enable_lora_for_text_encoder(model)
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Args:
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
adapter_names (`Union[List[str], str]`):
The names of the adapter to delete. Can be a single string or a list of strings
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if isinstance(adapter_names, str):
adapter_names = [adapter_names]
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.delete_adapters(adapter_names)
elif issubclass(model.__class__, PreTrainedModel):
for adapter_name in adapter_names:
delete_adapter_layers(model, adapter_name)
def get_active_adapters(self) -> List[str]:
"""
Gets the list of the current active adapters.
Example:
```python
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
).to("cuda")
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipeline.get_active_adapters()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError(
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
)
active_adapters = []
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None and issubclass(model.__class__, ModelMixin):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
break
return active_adapters
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
Gets the current list of all available adapters in the pipeline.
"""
if not USE_PEFT_BACKEND:
raise ValueError(
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
)
set_adapters = {}
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if (
model is not None
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
and hasattr(model, "peft_config")
):
set_adapters[component] = list(model.peft_config.keys())
return set_adapters
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
device (`Union[torch.device, str, int]`):
Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
if adapter_name in module.lora_magnitude_vector:
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
adapter_name
].to(device)
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
@staticmethod
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
save_function: Callable,
safe_serialization: bool,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
if weight_name is None:
if safe_serialization:
weight_name = LORA_WEIGHT_NAME_SAFE
else:
weight_name = LORA_WEIGHT_NAME
save_path = Path(save_directory, weight_name).as_posix()
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
# if _lora_scale has not been set, return 1
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def enable_lora_hotswap(self, **kwargs) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
for key, component in self.components.items():
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
component.enable_lora_hotswap(**kwargs)
@@ -17,7 +17,7 @@ from typing import List
import torch
from ..utils import is_peft_version, logging, state_dict_all_zero
from ...utils import is_peft_version, logging, state_dict_all_zero
logger = logging.get_logger(__name__)
File diff suppressed because it is too large Load Diff
+36 -897
View File
@@ -12,927 +12,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
is_transformers_available,
is_transformers_version,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
if is_transformers_available():
from transformers import PreTrainedModel
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
from ..utils import deprecate
from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
"""
Fuses LoRAs for the text encoder.
from .lora.lora_base import fuse_text_encoder_lora
Args:
text_encoder (`torch.nn.Module`):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
"""
merge_kwargs = {"safe_merge": safe_fusing}
deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead."
deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message)
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
# For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. "
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
return fuse_text_encoder_lora(
text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
def unfuse_text_encoder_lora(text_encoder):
"""
Unfuses LoRAs for the text encoder.
from .lora.lora_base import unfuse_text_encoder_lora
Args:
text_encoder (`torch.nn.Module`):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead."
deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message)
return unfuse_text_encoder_lora(text_encoder)
def set_adapters_for_text_encoder(
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
adapter_names,
text_encoder=None,
text_encoder_weights=None,
):
"""
Sets the adapter layers for the text encoder.
from .lora.lora_base import set_adapters_for_text_encoder
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if text_encoder is None:
raise ValueError(
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
)
deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead."
deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message)
def process_weights(adapter_names, weights):
# Expand weights into a list, one entry per adapter
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
# Set None values to default of 1.0
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
weights = [w if w is not None else 1.0 for w in weights]
return weights
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
"""
Disables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
"""
Enables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=True)
def _remove_text_encoder_monkey_patch(text_encoder):
recurse_remove_peft_layers(text_encoder)
if getattr(text_encoder, "peft_config", None) is not None:
del text_encoder.peft_config
text_encoder._hf_peft_config_loaded = None
def _fetch_state_dict(
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
proxies,
token,
revision,
subfolder,
user_agent,
allow_pickle,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
def _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
if len(targeted_files) == 0:
return
# "scheduler" does not correspond to a LoRA checkpoint.
# "optimizer" does not correspond to a LoRA checkpoint
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = list(
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
return set_adapters_for_text_encoder(
adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights
)
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
)
weight_name = targeted_files[0]
return weight_name
def disable_lora_for_text_encoder(text_encoder=None):
from .lora.lora_base import disable_lora_for_text_encoder
deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead."
deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message)
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
return disable_lora_for_text_encoder(text_encoder=text_encoder)
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
from peft import LoraConfig
def enable_lora_for_text_encoder(text_encoder=None):
from .lora.lora_base import enable_lora_for_text_encoder
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
prefix = text_encoder_name if prefix is None else prefix
deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead."
deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message)
# Safe prefix to check with.
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
return enable_lora_for_text_encoder(text_encoder=text_encoder)
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
state_dict = convert_state_dict_to_diffusers(state_dict)
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
_lora_loadable_modules = []
num_fused_loras = 0
def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@classmethod
def save_lora_weights(cls, **kwargs):
raise NotImplementedError("`save_lora_weights()` not implemented.")
@classmethod
def lora_state_dict(cls, **kwargs):
raise NotImplementedError("`lora_state_dict()` is not implemented.")
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
return _fetch_state_dict(*args, **kwargs)
@classmethod
def _best_guess_weight_name(cls, *args, **kwargs):
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
return _best_guess_weight_name(*args, **kwargs)
def unload_lora_weights(self):
"""
Unloads the LoRA parameters.
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
>>> pipeline.unload_lora_weights()
>>> ...
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.unload_lora()
elif issubclass(model.__class__, PreTrainedModel):
_remove_text_encoder_monkey_patch(model)
def fuse_lora(
self,
components: List[str] = [],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
if "fuse_unet" in kwargs:
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
deprecate(
"fuse_unet",
"1.0.0",
depr_message,
)
if "fuse_transformer" in kwargs:
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
deprecate(
"fuse_transformer",
"1.0.0",
depr_message,
)
if "fuse_text_encoder" in kwargs:
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
deprecate(
"fuse_text_encoder",
"1.0.0",
depr_message,
)
if len(components) == 0:
raise ValueError("`components` cannot be an empty list.")
for fuse_component in components:
if fuse_component not in self._lora_loadable_modules:
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
model = getattr(self, fuse_component, None)
if model is not None:
# check if diffusers model
if issubclass(model.__class__, ModelMixin):
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
# handle transformers models.
if issubclass(model.__class__, PreTrainedModel):
fuse_text_encoder_lora(
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
self.num_fused_loras += 1
def unfuse_lora(self, components: List[str] = [], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
if "unfuse_unet" in kwargs:
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
deprecate(
"unfuse_unet",
"1.0.0",
depr_message,
)
if "unfuse_transformer" in kwargs:
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
deprecate(
"unfuse_transformer",
"1.0.0",
depr_message,
)
if "unfuse_text_encoder" in kwargs:
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
deprecate(
"unfuse_text_encoder",
"1.0.0",
depr_message,
)
if len(components) == 0:
raise ValueError("`components` cannot be an empty list.")
for fuse_component in components:
if fuse_component not in self._lora_loadable_modules:
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
model = getattr(self, fuse_component, None)
if model is not None:
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
self.num_fused_loras -= 1
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
if isinstance(adapter_weights, dict):
components_passed = set(adapter_weights.keys())
lora_components = set(self._lora_loadable_modules)
invalid_components = sorted(components_passed - lora_components)
if invalid_components:
logger.warning(
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
"to the invalid components will be removed and ignored."
)
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
adapter_weights = copy.deepcopy(adapter_weights)
# Expand weights into a list, one entry per adapter
if not isinstance(adapter_weights, list):
adapter_weights = [adapter_weights] * len(adapter_names)
if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
# eg ["adapter1", "adapter2"]
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
missing_adapters = set(adapter_names) - all_adapters
if len(missing_adapters) > 0:
raise ValueError(
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
)
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
invert_list_adapters = {
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
for adapter in all_adapters
}
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is None:
logger.warning(f"Model {component} not found in pipeline.")
continue
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
component_adapter_weights = weights.pop(component, None)
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
logger.warning(
(
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
)
)
else:
component_adapter_weights = weights
_component_adapter_weights.setdefault(component, [])
_component_adapter_weights[component].append(component_adapter_weights)
if issubclass(model.__class__, ModelMixin):
model.set_adapters(adapter_names, _component_adapter_weights[component])
elif issubclass(model.__class__, PreTrainedModel):
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
def disable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.disable_lora()
elif issubclass(model.__class__, PreTrainedModel):
disable_lora_for_text_encoder(model)
def enable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.enable_lora()
elif issubclass(model.__class__, PreTrainedModel):
enable_lora_for_text_encoder(model)
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Args:
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
adapter_names (`Union[List[str], str]`):
The names of the adapter to delete. Can be a single string or a list of strings
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if isinstance(adapter_names, str):
adapter_names = [adapter_names]
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
if issubclass(model.__class__, ModelMixin):
model.delete_adapters(adapter_names)
elif issubclass(model.__class__, PreTrainedModel):
for adapter_name in adapter_names:
delete_adapter_layers(model, adapter_name)
def get_active_adapters(self) -> List[str]:
"""
Gets the list of the current active adapters.
Example:
```python
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
).to("cuda")
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipeline.get_active_adapters()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError(
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
)
active_adapters = []
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None and issubclass(model.__class__, ModelMixin):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
break
return active_adapters
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
Gets the current list of all available adapters in the pipeline.
"""
if not USE_PEFT_BACKEND:
raise ValueError(
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
)
set_adapters = {}
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if (
model is not None
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
and hasattr(model, "peft_config")
):
set_adapters[component] = list(model.peft_config.keys())
return set_adapters
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
device (`Union[torch.device, str, int]`):
Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
if adapter_name in module.lora_magnitude_vector:
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
adapter_name
].to(device)
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
@staticmethod
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
save_function: Callable,
safe_serialization: bool,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
if weight_name is None:
if safe_serialization:
weight_name = LORA_WEIGHT_NAME_SAFE
else:
weight_name = LORA_WEIGHT_NAME
save_path = Path(save_directory, weight_name).as_posix()
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
# if _lora_scale has not been set, return 1
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def enable_lora_hotswap(self, **kwargs) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
for key, component in self.components.items():
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
component.enable_lora_hotswap(**kwargs)
class LoraBaseMixin(LoraBaseMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead."
deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message)
super().__init__(*args, **kwargs)
File diff suppressed because it is too large Load Diff
+3 -3
View File
@@ -35,8 +35,8 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales
from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet.unet_loader_utils import _maybe_expand_lora_scales
logger = logging.get_logger(__name__)
@@ -99,7 +99,7 @@ class PeftAdapterMixin:
_prepare_lora_hotswap_kwargs: Optional[dict] = None
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
# Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+24 -530
View File
@@ -11,42 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from typing_extensions import Self
from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
SingleFileComponentError,
_is_legacy_scheduler_kwargs,
_is_model_weights_in_cached_folder,
_legacy_load_clip_tokenizer,
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
create_diffusers_t5_model_from_checkpoint,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
is_t5_in_single_file,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
from ..utils import deprecate
from .single_file.single_file import FromSingleFileMixin
def load_single_file_sub_model(
@@ -64,502 +30,30 @@ def load_single_file_sub_model(
disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
from .single_file.single_file import load_single_file_sub_model
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead."
deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message)
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_tokenizer = (
is_transformers_available()
and issubclass(class_obj, PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
return load_single_file_sub_model(
library_name,
class_name,
name,
checkpoint,
pipelines,
is_pipeline_module,
cached_model_config_path,
original_config,
local_files_only,
torch_dtype,
is_legacy_loading,
disable_mmap,
**kwargs,
)
diffusers_module = importlib.import_module(__name__.split(".")[0])
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
if is_diffusers_single_file_model:
load_method = getattr(class_obj, "from_single_file")
# We cannot provide two different config options to the `from_single_file` method
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
if original_config:
cached_model_config_path = None
loaded_sub_model = load_method(
pretrained_model_link_or_path_or_dict=checkpoint,
original_config=original_config,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs,
)
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
loaded_sub_model = create_diffusers_clip_model_from_ldm(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
)
elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)
elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
)
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
loaded_sub_model = _legacy_load_scheduler(
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
)
else:
if not hasattr(class_obj, "from_pretrained"):
raise ValueError(
(
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
" a supported loading method."
)
)
loading_kwargs = {}
loading_kwargs.update(
{
"pretrained_model_name_or_path": cached_model_config_path,
"subfolder": name,
"local_files_only": local_files_only,
}
)
# Schedulers and Tokenizers don't make use of torch_dtype
# Skip passing it to those objects
if issubclass(class_obj, torch.nn.Module):
loading_kwargs.update({"torch_dtype": torch_dtype})
if is_diffusers_model or is_transformers_model:
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
load_method = getattr(class_obj, "from_pretrained")
loaded_sub_model = load_method(**loading_kwargs)
return loaded_sub_model
def _map_component_types_to_config_dict(component_types):
diffusers_module = importlib.import_module(__name__.split(".")[0])
config_dict = {}
component_types.pop("self", None)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
for component_name, component_value in component_types.items():
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_transformers_tokenizer = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif is_scheduler_enum or is_scheduler:
if is_scheduler_enum:
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
# if the type hint is a KarrassDiffusionSchedulers enum
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
elif is_scheduler:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif (
is_transformers_model or is_transformers_tokenizer
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["transformers", component_value[0].__name__]
else:
config_dict[component_name] = [None, None]
return config_dict
def _infer_pipeline_config_dict(pipeline_class):
parameters = inspect.signature(pipeline_class.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
component_types = pipeline_class._get_signature_types()
# Ignore parameters that are not required for the pipeline
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
config_dict = _map_component_types_to_config_dict(component_types)
return config_dict
def _download_diffusers_model_config_from_hub(
pretrained_model_name_or_path,
cache_dir,
revision,
proxies,
force_download=None,
local_files_only=None,
token=None,
):
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
cached_model_path = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=local_files_only,
token=token,
allow_patterns=allow_patterns,
)
return cached_model_path
class FromSingleFileMixin:
"""
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
original_config_file (`str`, *optional*):
The path to the original config file that was used to train the model. If not provided, the config file
will be inferred from the checkpoint file.
config (`str`, *optional*):
Can be either:
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
Examples:
```py
>>> from diffusers import StableDiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = StableDiffusionPipeline.from_single_file(
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
... )
>>> # Download pipeline from local file
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_single_file(
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
... torch_dtype=torch.float16,
... )
>>> pipeline.to("cuda")
```
"""
original_config_file = kwargs.pop("original_config_file", None)
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if original_config_file is not None:
deprecation_message = (
"`original_config_file` argument is deprecated and will be removed in future versions."
"please use the `original_config` argument instead."
)
deprecate("original_config_file", "1.0.0", deprecation_message)
original_config = original_config_file
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
)
# We shouldn't allow configuring individual models components through a Pipeline creation method
# These model kwargs should be deprecated
scaling_factor = kwargs.get("scaling_factor", None)
if scaling_factor is not None:
deprecation_message = (
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
"and will be ignored in future versions."
)
deprecate("scaling_factor", "1.0.0", deprecation_message)
if original_config is not None:
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
from ..pipelines.pipeline_utils import _get_pipeline_class
pipeline_class = _get_pipeline_class(cls, config=None)
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if config is None:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
else:
default_pretrained_model_config_name = config
if not os.path.isdir(default_pretrained_model_config_name):
# Provided config is a repo_id
if default_pretrained_model_config_name.count("/") > 1:
raise ValueError(
f'The provided config "{config}"'
" is neither a valid local path nor a valid repo id. Please check the parameter."
)
try:
# Attempt to download the config files for the pipeline
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=local_files_only,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
except LocalEntryNotFoundError:
# `local_files_only=True` but a local diffusers format model config is not available in the cache
# If `original_config` is not provided, we need override `local_files_only` to False
# to fetch the config files from the hub so that we have a way
# to configure the pipeline components.
if original_config is None:
logger.warning(
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
"Attempting to download the necessary config files for this pipeline.\n"
)
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=False,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
else:
# For backwards compatibility
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred. \n"
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
"the necessary config files.\n"
)
is_legacy_loading = True
cached_model_config_path = None
config_dict = _infer_pipeline_config_dict(pipeline_class)
config_dict["_class_name"] = pipeline_class.__name__
else:
# Provided config is a path to a local directory attempt to load directly.
cached_model_config_path = default_pretrained_model_config_name
config_dict = pipeline_class.load_config(cached_model_config_path)
# pop out "_ignore_files" as it is only needed for download
config_dict.pop("_ignore_files", None)
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
from diffusers import pipelines
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
for name, (library_name, class_name) in logging.tqdm(
sorted(init_dict.items()), desc="Loading pipeline components..."
):
loaded_sub_model = None
is_pipeline_module = hasattr(pipelines, library_name)
if name in passed_class_obj:
loaded_sub_model = passed_class_obj[name]
else:
try:
loaded_sub_model = load_single_file_sub_model(
library_name=library_name,
class_name=class_name,
name=name,
checkpoint=checkpoint,
is_pipeline_module=is_pipeline_module,
cached_model_config_path=cached_model_config_path,
pipelines=pipelines,
torch_dtype=torch_dtype,
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
raise SingleFileComponentError(
(
f"{e.message}\n"
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
f"\n"
f"{name} = {class_name}.from_pretrained('...')\n"
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
f"\n"
)
)
init_kwargs[name] = loaded_sub_model
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# deprecated kwargs
load_safety_checker = kwargs.pop("load_safety_checker", None)
if load_safety_checker is not None:
deprecation_message = (
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
)
deprecate("load_safety_checker", "1.0.0", deprecation_message)
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
init_kwargs.update(safety_checker_components)
pipe = pipeline_class(**init_kwargs)
return pipe
class FromSingleFileMixin(FromSingleFileMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead."
deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message)
super().__init__(*args, **kwargs)
@@ -0,0 +1,8 @@
from ...utils import is_torch_available, is_transformers_available
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
if is_transformers_available():
from .single_file import FromSingleFileMixin
@@ -0,0 +1,565 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from typing_extensions import Self
from ...utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
SingleFileComponentError,
_is_legacy_scheduler_kwargs,
_is_model_weights_in_cached_folder,
_legacy_load_clip_tokenizer,
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
create_diffusers_t5_model_from_checkpoint,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
is_t5_in_single_file,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
def load_single_file_sub_model(
library_name,
class_name,
name,
checkpoint,
pipelines,
is_pipeline_module,
cached_model_config_path,
original_config=None,
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
disable_mmap=False,
**kwargs,
):
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_tokenizer = (
is_transformers_available()
and issubclass(class_obj, PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
diffusers_module = importlib.import_module(__name__.split(".")[0])
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
if is_diffusers_single_file_model:
load_method = getattr(class_obj, "from_single_file")
# We cannot provide two different config options to the `from_single_file` method
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
if original_config:
cached_model_config_path = None
loaded_sub_model = load_method(
pretrained_model_link_or_path_or_dict=checkpoint,
original_config=original_config,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
disable_mmap=disable_mmap,
**kwargs,
)
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
loaded_sub_model = create_diffusers_clip_model_from_ldm(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
)
elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)
elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
)
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
loaded_sub_model = _legacy_load_scheduler(
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
)
else:
if not hasattr(class_obj, "from_pretrained"):
raise ValueError(
(
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
" a supported loading method."
)
)
loading_kwargs = {}
loading_kwargs.update(
{
"pretrained_model_name_or_path": cached_model_config_path,
"subfolder": name,
"local_files_only": local_files_only,
}
)
# Schedulers and Tokenizers don't make use of torch_dtype
# Skip passing it to those objects
if issubclass(class_obj, torch.nn.Module):
loading_kwargs.update({"torch_dtype": torch_dtype})
if is_diffusers_model or is_transformers_model:
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
load_method = getattr(class_obj, "from_pretrained")
loaded_sub_model = load_method(**loading_kwargs)
return loaded_sub_model
def _map_component_types_to_config_dict(component_types):
diffusers_module = importlib.import_module(__name__.split(".")[0])
config_dict = {}
component_types.pop("self", None)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
for component_name, component_value in component_types.items():
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_transformers_tokenizer = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif is_scheduler_enum or is_scheduler:
if is_scheduler_enum:
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
# if the type hint is a KarrassDiffusionSchedulers enum
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
elif is_scheduler:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif (
is_transformers_model or is_transformers_tokenizer
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["transformers", component_value[0].__name__]
else:
config_dict[component_name] = [None, None]
return config_dict
def _infer_pipeline_config_dict(pipeline_class):
parameters = inspect.signature(pipeline_class.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
component_types = pipeline_class._get_signature_types()
# Ignore parameters that are not required for the pipeline
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
config_dict = _map_component_types_to_config_dict(component_types)
return config_dict
def _download_diffusers_model_config_from_hub(
pretrained_model_name_or_path,
cache_dir,
revision,
proxies,
force_download=None,
local_files_only=None,
token=None,
):
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
cached_model_path = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=local_files_only,
token=token,
allow_patterns=allow_patterns,
)
return cached_model_path
class FromSingleFileMixin:
"""
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
original_config_file (`str`, *optional*):
The path to the original config file that was used to train the model. If not provided, the config file
will be inferred from the checkpoint file.
config (`str`, *optional*):
Can be either:
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
Examples:
```py
>>> from diffusers import StableDiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = StableDiffusionPipeline.from_single_file(
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
... )
>>> # Download pipeline from local file
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_single_file(
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
... torch_dtype=torch.float16,
... )
>>> pipeline.to("cuda")
```
"""
original_config_file = kwargs.pop("original_config_file", None)
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if original_config_file is not None:
deprecation_message = (
"`original_config_file` argument is deprecated and will be removed in future versions."
"please use the `original_config` argument instead."
)
deprecate("original_config_file", "1.0.0", deprecation_message)
original_config = original_config_file
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
)
# We shouldn't allow configuring individual models components through a Pipeline creation method
# These model kwargs should be deprecated
scaling_factor = kwargs.get("scaling_factor", None)
if scaling_factor is not None:
deprecation_message = (
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
"and will be ignored in future versions."
)
deprecate("scaling_factor", "1.0.0", deprecation_message)
if original_config is not None:
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
from ..pipelines.pipeline_utils import _get_pipeline_class
pipeline_class = _get_pipeline_class(cls, config=None)
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
)
if config is None:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
else:
default_pretrained_model_config_name = config
if not os.path.isdir(default_pretrained_model_config_name):
# Provided config is a repo_id
if default_pretrained_model_config_name.count("/") > 1:
raise ValueError(
f'The provided config "{config}"'
" is neither a valid local path nor a valid repo id. Please check the parameter."
)
try:
# Attempt to download the config files for the pipeline
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=local_files_only,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
except LocalEntryNotFoundError:
# `local_files_only=True` but a local diffusers format model config is not available in the cache
# If `original_config` is not provided, we need override `local_files_only` to False
# to fetch the config files from the hub so that we have a way
# to configure the pipeline components.
if original_config is None:
logger.warning(
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
"Attempting to download the necessary config files for this pipeline.\n"
)
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
local_files_only=False,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
else:
# For backwards compatibility
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred. \n"
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
"the necessary config files.\n"
)
is_legacy_loading = True
cached_model_config_path = None
config_dict = _infer_pipeline_config_dict(pipeline_class)
config_dict["_class_name"] = pipeline_class.__name__
else:
# Provided config is a path to a local directory attempt to load directly.
cached_model_config_path = default_pretrained_model_config_name
config_dict = pipeline_class.load_config(cached_model_config_path)
# pop out "_ignore_files" as it is only needed for download
config_dict.pop("_ignore_files", None)
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
from diffusers import pipelines
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
for name, (library_name, class_name) in logging.tqdm(
sorted(init_dict.items()), desc="Loading pipeline components..."
):
loaded_sub_model = None
is_pipeline_module = hasattr(pipelines, library_name)
if name in passed_class_obj:
loaded_sub_model = passed_class_obj[name]
else:
try:
loaded_sub_model = load_single_file_sub_model(
library_name=library_name,
class_name=class_name,
name=name,
checkpoint=checkpoint,
is_pipeline_module=is_pipeline_module,
cached_model_config_path=cached_model_config_path,
pipelines=pipelines,
torch_dtype=torch_dtype,
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
disable_mmap=disable_mmap,
**kwargs,
)
except SingleFileComponentError as e:
raise SingleFileComponentError(
(
f"{e.message}\n"
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
f"\n"
f"{name} = {class_name}.from_pretrained('...')\n"
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
f"\n"
)
)
init_kwargs[name] = loaded_sub_model
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# deprecated kwargs
load_safety_checker = kwargs.pop("load_safety_checker", None)
if load_safety_checker is not None:
deprecation_message = (
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
)
deprecate("load_safety_checker", "1.0.0", deprecation_message)
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
init_kwargs.update(safety_checker_components)
pipe = pipeline_class(**init_kwargs)
return pipe
@@ -0,0 +1,440 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import re
from contextlib import nullcontext
from typing import Optional
import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from ... import __version__
from ...quantizers import DiffusersAutoQuantizer
from ...utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
fetch_diffusers_config,
fetch_original_config,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import dispatch_model, init_empty_weights
from ...models.modeling_utils import load_model_dict_into_meta
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
},
"UNet2DConditionModel": {
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
"default_subfolder": "unet",
"legacy_kwargs": {
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
},
},
"AutoencoderKL": {
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
"default_subfolder": "vae",
},
"ControlNetModel": {
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
},
"SD3Transformer2DModel": {
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"MotionAdapter": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"SparseControlNetModel": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"FluxTransformer2DModel": {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTXVideo": {
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
"MochiTransformer3DModel": {
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AuraFlowTransformer2DModel": {
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"Lumina2Transformer2DModel": {
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
"default_subfolder": "transformer",
},
"SanaTransformer2DModel": {
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
}
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)
if issubclass(cls, loadable_class):
return loadable_class_str
return None
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters
mapping_kwargs = {}
for parameter in parameters:
if parameter in kwargs:
mapping_kwargs[parameter] = kwargs[parameter]
return mapping_kwargs
class FromOriginalModelMixin:
"""
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path_or_dict (`str`, *optional*):
Can be either:
- A link to the `.safetensors` or `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
- A path to a local *file* containing the weights of the component model.
- A state dict containing the component model weights.
config (`str`, *optional*):
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
configs in Diffusers format.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
original_config (`str`, *optional*):
Dict or path to a yaml file containing the configuration for the model in its original format.
If a dict is provided, it will be used to initialize the model configuration.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
```py
>>> from diffusers import StableCascadeUNet
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
```
"""
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
if mapping_class_name is None:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
if pretrained_model_link_or_path is not None:
deprecation_message = (
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
)
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if config is not None and original_config is not None:
raise ValueError(
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if quantization_config is not None:
user_agent["quant"] = quantization_config.quant_method.value
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
else:
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path_or_dict,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
user_agent=user_agent,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
else:
hf_quantizer = None
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config is not None:
if "config_mapping_fn" in mapping_functions:
config_mapping_fn = mapping_functions["config_mapping_fn"]
else:
config_mapping_fn = None
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
)
if isinstance(original_config, str):
# If original_config is a URL or filepath fetch the original_config dict
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
diffusers_model_config = config_mapping_fn(
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
)
else:
if config is not None:
if isinstance(config, str):
default_pretrained_model_config_name = config
else:
raise ValueError(
(
"Invalid `config` argument. Please provide a string representing a repo id"
"or path to a local Diffusers model repo."
)
)
else:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
if "default_subfolder" in mapping_functions:
subfolder = mapping_functions["default_subfolder"]
subfolder = subfolder or config.pop(
"subfolder", None
) # some configs contain a subfolder key, e.g. StableCascadeUNet
diffusers_model_config = cls.load_config(
pretrained_model_name_or_path=default_pretrained_model_config_name,
subfolder=subfolder,
local_files_only=local_files_only,
token=token,
revision=config_revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
# Map legacy kwargs to new kwargs
if "legacy_kwargs" in mapping_functions:
legacy_kwargs = mapping_functions["legacy_kwargs"]
for legacy_key, new_key in legacy_kwargs.items():
if legacy_key in kwargs:
kwargs[new_key] = kwargs.pop(legacy_key)
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
keep_in_fp32_modules = [keep_in_fp32_modules]
else:
keep_in_fp32_modules = []
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
device_map=None,
state_dict=diffusers_format_checkpoint,
keep_in_fp32_modules=keep_in_fp32_modules,
)
device_map = None
if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
empty_state_dict = model.state_dict()
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
device_map = {"": param_device}
load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device_map=device_map,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype)
model.eval()
if device_map is not None:
device_map_kwargs = {"device_map": device_map}
dispatch_model(model, **device_map_kwargs)
return model
File diff suppressed because it is too large Load Diff
+9 -422
View File
@@ -11,430 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import re
from contextlib import nullcontext
from typing import Optional
import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
fetch_diffusers_config,
fetch_original_config,
load_single_file_checkpoint,
from ..utils import deprecate
from .single_file.single_file_model import (
SINGLE_FILE_LOADABLE_CLASSES, # noqa: F401
FromOriginalModelMixin,
)
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import dispatch_model, init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
},
"UNet2DConditionModel": {
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
"default_subfolder": "unet",
"legacy_kwargs": {
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
},
},
"AutoencoderKL": {
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
"default_subfolder": "vae",
},
"ControlNetModel": {
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
},
"SD3Transformer2DModel": {
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"MotionAdapter": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"SparseControlNetModel": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"FluxTransformer2DModel": {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTXVideo": {
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
"MochiTransformer3DModel": {
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"HunyuanVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AuraFlowTransformer2DModel": {
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"Lumina2Transformer2DModel": {
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
"default_subfolder": "transformer",
},
"SanaTransformer2DModel": {
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
}
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)
if issubclass(cls, loadable_class):
return loadable_class_str
return None
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters
mapping_kwargs = {}
for parameter in parameters:
if parameter in kwargs:
mapping_kwargs[parameter] = kwargs[parameter]
return mapping_kwargs
class FromOriginalModelMixin:
"""
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path_or_dict (`str`, *optional*):
Can be either:
- A link to the `.safetensors` or `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
- A path to a local *file* containing the weights of the component model.
- A state dict containing the component model weights.
config (`str`, *optional*):
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
configs in Diffusers format.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
original_config (`str`, *optional*):
Dict or path to a yaml file containing the configuration for the model in its original format.
If a dict is provided, it will be used to initialize the model configuration.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
```py
>>> from diffusers import StableCascadeUNet
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
```
"""
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
if mapping_class_name is None:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
if pretrained_model_link_or_path is not None:
deprecation_message = (
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
)
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if config is not None and original_config is not None:
raise ValueError(
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if quantization_config is not None:
user_agent["quant"] = quantization_config.quant_method.value
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
else:
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path_or_dict,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
user_agent=user_agent,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
else:
hf_quantizer = None
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config is not None:
if "config_mapping_fn" in mapping_functions:
config_mapping_fn = mapping_functions["config_mapping_fn"]
else:
config_mapping_fn = None
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
)
if isinstance(original_config, str):
# If original_config is a URL or filepath fetch the original_config dict
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
diffusers_model_config = config_mapping_fn(
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
)
else:
if config is not None:
if isinstance(config, str):
default_pretrained_model_config_name = config
else:
raise ValueError(
(
"Invalid `config` argument. Please provide a string representing a repo id"
"or path to a local Diffusers model repo."
)
)
else:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
if "default_subfolder" in mapping_functions:
subfolder = mapping_functions["default_subfolder"]
subfolder = subfolder or config.pop(
"subfolder", None
) # some configs contain a subfolder key, e.g. StableCascadeUNet
diffusers_model_config = cls.load_config(
pretrained_model_name_or_path=default_pretrained_model_config_name,
subfolder=subfolder,
local_files_only=local_files_only,
token=token,
revision=config_revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
# Map legacy kwargs to new kwargs
if "legacy_kwargs" in mapping_functions:
legacy_kwargs = mapping_functions["legacy_kwargs"]
for legacy_key, new_key in legacy_kwargs.items():
if legacy_key in kwargs:
kwargs[new_key] = kwargs.pop(legacy_key)
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
keep_in_fp32_modules = [keep_in_fp32_modules]
else:
keep_in_fp32_modules = []
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
device_map=None,
state_dict=diffusers_format_checkpoint,
keep_in_fp32_modules=keep_in_fp32_modules,
)
device_map = None
if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
empty_state_dict = model.state_dict()
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
device_map = {"": param_device}
load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device_map=device_map,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype)
model.eval()
if device_map is not None:
device_map_kwargs = {"device_map": device_map}
dispatch_model(model, **device_map_kwargs)
return model
class FromOriginalModelMixin(FromOriginalModelMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead."
deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message)
super().__init__(*args, **kwargs)
File diff suppressed because it is too large Load Diff
+7 -164
View File
@@ -11,170 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
from ..utils import deprecate
from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin
if is_accelerate_available():
pass
logger = logging.get_logger(__name__)
class FluxTransformer2DLoadersMixin:
"""
Load layers into a [`FluxTransformer2DModel`].
"""
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 0
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
num_image_text_embed = 4
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
num_image_text_embed = 16
# IP-Adapter
num_image_text_embeds += [num_image_text_embed]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
dtype=self.dtype,
device=self.device,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device_map = {"": self.device}
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
key_id += 1
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead."
deprecate("diffusers.loaders.ip_adapter.FluxTransformer2DLoadersMixin", "0.36", deprecation_message)
super().__init__(*args, **kwargs)
+7 -155
View File
@@ -11,160 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Dict
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils import deprecate
from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin
logger = logging.get_logger(__name__)
class SD3Transformer2DLoadersMixin:
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
def _convert_ip_adapter_attn_to_diffusers(
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
) -> Dict:
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# IP-Adapter cross attention parameters
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
# Dict where key is transformer layer index, value is attention processor's state dict
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
for key, weights in state_dict.items():
idx, name = key.split(".", maxsplit=1)
layer_state_dict[int(idx)][name] = weights
# Create IP-Adapter attention processor & load state_dict
attn_procs = {}
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for idx, name in enumerate(self.attn_processors.keys()):
with init_context():
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size,
ip_hidden_states_dim=ip_hidden_states_dim,
head_dim=self.config.attention_head_dim,
timesteps_emb_dim=timesteps_emb_dim,
)
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
) -> IPAdapterTimeImageProjection:
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
# Convert to diffusers
updated_state_dict = {}
for key, value in state_dict.items():
# InstantX/SD3.5-Large-IP-Adapter
if key.startswith("layers."):
idx = key.split(".")[1]
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
updated_state_dict[key] = value
# Image projetion parameters
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
output_dim = updated_state_dict["proj_out.weight"].shape[0]
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
num_queries = updated_state_dict["latents"].shape[1]
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
# Image projection
with init_context():
image_proj = IPAdapterTimeImageProjection(
embed_dim=embed_dim,
output_dim=output_dim,
hidden_dim=hidden_dim,
heads=heads,
num_queries=num_queries,
timestep_in_dim=timestep_in_dim,
)
if not low_cpu_mem_usage:
image_proj.load_state_dict(updated_state_dict, strict=True)
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
return image_proj
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
Args:
state_dict (`Dict`):
State dict with keys "ip_adapter", which contains parameters for attention processors, and
"image_proj", which contains parameters for image projection net.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead."
deprecate("diffusers.loaders.ip_adapter.SD3Transformer2DLoadersMixin", "0.36", deprecation_message)
super().__init__(*args, **kwargs)
+5
View File
@@ -0,0 +1,5 @@
from ...utils import is_torch_available
if is_torch_available():
from .unet import UNet2DConditionLoadersMixin
@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..models.embeddings import (
from ...models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
IPAdapterFaceIDPlusImageProjection,
@@ -30,8 +30,8 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ...utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
@@ -43,9 +43,9 @@ from ..utils import (
is_torch_version,
logging,
)
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
from ..lora.lora_base import _func_optionally_disable_offloading
from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from ..utils import AttnProcsLayers
logger = logging.get_logger(__name__)
@@ -247,7 +247,7 @@ class UNet2DConditionLoadersMixin:
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ...models.attention_processor import CustomDiffusionAttnProcessor
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -395,7 +395,7 @@ class UNet2DConditionLoadersMixin:
return is_model_cpu_offload, is_sequential_cpu_offload
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
# Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
@@ -408,7 +408,6 @@ class UNet2DConditionLoadersMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def save_attn_procs(
@@ -452,7 +451,7 @@ class UNet2DConditionLoadersMixin:
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
```
"""
from ..models.attention_processor import (
from ...models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
@@ -514,7 +513,7 @@ class UNet2DConditionLoadersMixin:
logger.info(f"Model weights saved in {save_path}")
def _get_custom_diffusion_state_dict(self):
from ..models.attention_processor import (
from ...models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
@@ -760,7 +759,7 @@ class UNet2DConditionLoadersMixin:
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
from ..models.attention_processor import (
from ...models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -14,12 +14,12 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union
from ..utils import logging
from ...utils import logging
if TYPE_CHECKING:
# import here to avoid circular imports
from ..models import UNet2DConditionModel
from ...models import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -17,8 +17,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
@@ -21,7 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
@@ -19,7 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -17,7 +17,7 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -1068,15 +1068,17 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.ones(
attention_mask = torch.zeros(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
attention_mask = attention_mask.masked_fill(mask_indices, False)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -20,8 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
@@ -19,8 +19,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
@@ -19,8 +19,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
@@ -1,84 +0,0 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
"ModularLoader",
"PipelineState",
"BlockState",
]
_import_structure["modular_pipeline_utils"] = [
"ComponentSpec",
"ConfigSpec",
"InputParam",
"OutputParam",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularLoader,
ModularPipelineBlocks,
ModularPipeline,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
from .components_manager import ComponentsManager
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -1,934 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union, Dict, Any
import copy
import torch
import time
from dataclasses import dataclass
from ..utils import (
is_accelerate_available,
logging,
)
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
import uuid
if is_accelerate_available():
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
from accelerate.state import PartialState
from accelerate.utils import send_to_device
from accelerate.utils.memory import clear_device_cache
from accelerate.utils.modeling import convert_file_size_to_int
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to
benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch
discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
Arguments:
return_buffers (`bool`, *optional*, defaults to `True`):
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are
tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm
layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
"""
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem
class CustomOffloadHook(ModelHook):
"""
A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
on the given device. Optionally offloads other models to the CPU before the forward pass is called.
Args:
execution_device(`str`, `int` or `torch.device`, *optional*):
The device on which the model should be executed. Will default to the MPS device if it's available, then
GPU 0 if there is a GPU, and finally to the CPU.
"""
def __init__(
self,
execution_device: Optional[Union[str, int, torch.device]] = None,
other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
self.execution_device = execution_device if execution_device is not None else PartialState().default_device
self.other_hooks = other_hooks
self.offload_strategy = offload_strategy
self.model_id = None
def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
self.offload_strategy = offload_strategy
def add_other_hook(self, hook: "UserCustomOffloadHook"):
"""
Add a hook to the list of hooks to consider for offloading.
"""
if self.other_hooks is None:
self.other_hooks = []
self.other_hooks.append(hook)
def init_hook(self, module):
return module.to("cpu")
def pre_forward(self, module, *args, **kwargs):
if module.device != self.execution_device:
if self.other_hooks is not None:
hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
# offload all other hooks
start_time = time.perf_counter()
if self.offload_strategy is not None:
hooks_to_offload = self.offload_strategy(
hooks=hooks_to_offload,
model_id=self.model_id,
model=module,
execution_device=self.execution_device,
)
end_time = time.perf_counter()
logger.info(
f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
)
for hook in hooks_to_offload:
logger.info(
f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
)
hook.offload()
if hooks_to_offload:
clear_device_cache()
module.to(self.execution_device)
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
class UserCustomOffloadHook:
"""
A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
the hook or remove it entirely.
"""
def __init__(self, model_id, model, hook):
self.model_id = model_id
self.model = model
self.hook = hook
def offload(self):
self.hook.init_hook(self.model)
def attach(self):
add_hook_to_module(self.model, self.hook)
self.hook.model_id = self.model_id
def remove(self):
remove_hook_from_module(self.model)
self.hook.model_id = None
def add_other_hook(self, hook: "UserCustomOffloadHook"):
self.hook.add_other_hook(hook)
def custom_offload_with_hook(
model_id: str,
model: torch.nn.Module,
execution_device: Union[str, int, torch.device] = None,
offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
user_hook.attach()
return user_hook
class AutoOffloadStrategy:
"""
Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
the available memory on the device.
"""
def __init__(self, memory_reserve_margin="3GB"):
self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)
def __call__(self, hooks, model_id, model, execution_device):
if len(hooks) == 0:
return []
current_module_size = get_memory_footprint(model)
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
min_memory_offload = current_module_size - mem_on_device
logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")
# exlucde models that's not currently loaded on the device
module_sizes = dict(
sorted(
{hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
key=lambda x: x[1],
reverse=True,
)
)
def search_best_candidate(module_sizes, min_memory_offload):
"""
search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
minimum memory offload size. the combination of models should add up to the smallest modulesize that is
larger than `min_memory_offload`
"""
model_ids = list(module_sizes.keys())
best_candidate = None
best_size = float("inf")
for r in range(1, len(model_ids) + 1):
for candidate_model_ids in combinations(model_ids, r):
candidate_size = sum(
module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
)
if candidate_size < min_memory_offload:
continue
else:
if best_candidate is None or candidate_size < best_size:
best_candidate = candidate_model_ids
best_size = candidate_size
return best_candidate
best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)
if best_offload_model_ids is None:
# if no combination is found, meaning that we cannot meet the memory requirement, offload all models
logger.warning("no combination of models to offload to cpu is found, offloading all models")
hooks_to_offload = hooks
else:
hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]
return hooks_to_offload
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None
self._auto_offload_enabled = False
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
"""
Lookup component_ids by name, collection, or load_id.
"""
if components is None:
components = self.components
if name:
ids_by_name = set()
for component_id, component in components.items():
comp_name = self._id_to_name(component_id)
if comp_name == name:
ids_by_name.add(component_id)
else:
ids_by_name = set(components.keys())
if collection:
ids_by_collection = set()
for component_id, component in components.items():
if component_id in self.collections[collection]:
ids_by_collection.add(component_id)
else:
ids_by_collection = set(components.keys())
if load_id:
ids_by_load_id = set()
for name, component in components.items():
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
ids_by_load_id.add(name)
else:
ids_by_load_id = set(components.keys())
ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
return ids
@staticmethod
def _id_to_name(component_id: str):
return "_".join(component_id.split("_")[:-1])
def add(self, name, component, collection: Optional[str] = None):
component_id = f"{name}_{uuid.uuid4()}"
# check for duplicated components
for comp_id, comp in self.components.items():
if comp == component:
comp_name = self._id_to_name(comp_id)
if comp_name == name:
logger.warning(
f"component '{name}' already exists as '{comp_id}'"
)
component_id = comp_id
break
else:
logger.warning(
f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
# check for duplicated load_id and warn (we do not delete for you)
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
if components_with_same_load_id:
existing = ", ".join(components_with_same_load_id)
logger.warning(
f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
# add component to components manager
self.components[component_id] = component
self.added_time[component_id] = time.time()
if collection:
if collection not in self.collections:
self.collections[collection] = set()
if not component_id in self.collections[collection]:
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
for comp_id in comp_ids_in_collection:
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
self.remove(comp_id)
self.collections[collection].add(component_id)
logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
else:
logger.info(f"Added component '{name}' as '{component_id}'")
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
return component_id
def remove(self, component_id: str = None):
if component_id not in self.components:
logger.warning(f"Component '{component_id}' not found in ComponentsManager")
return
component = self.components.pop(component_id)
self.added_time.pop(component_id)
for collection in self.collections:
if component_id in self.collections[collection]:
self.collections[collection].remove(component_id)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
else:
if isinstance(component, torch.nn.Module):
component.to("cpu")
del component
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
as_name_component_tuples: bool = False):
"""
Select components by name with simple pattern matching.
Args:
names: Component name(s) or pattern(s)
Patterns:
- "unet" : match any component with base name "unet" (e.g., unet_123abc)
- "!unet" : everything except components with base name "unet"
- "unet*" : anything with base name starting with "unet"
- "!unet*" : anything with base name NOT starting with "unet"
- "*unet*" : anything with base name containing "unet"
- "!*unet*" : anything with base name NOT containing "unet"
- "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
- "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
collection: Optional collection to filter by
load_id: Optional load_id to filter by
as_name_component_tuples: If True, returns a list of (name, component) tuples using base names
instead of a dictionary with component IDs as keys
Returns:
Dictionary mapping component IDs to components,
or list of (base_name, component) tuples if as_name_component_tuples=True
"""
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
components = {k: self.components[k] for k in selected_ids}
# Helper to extract base name from component_id
def get_base_name(component_id):
parts = component_id.split('_')
# If the last part looks like a UUID, remove it
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return component_id
if names is None:
if as_name_component_tuples:
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
else:
return components
# Create mapping from component_id to base_name for all components
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
def matches_pattern(component_id, pattern, exact_match=False):
"""
Helper function to check if a component matches a pattern based on its base name.
Args:
component_id: The component ID to check
pattern: The pattern to match against
exact_match: If True, only exact matches to base_name are considered
"""
base_name = base_names[component_id]
# Exact match with base name
if exact_match:
return pattern == base_name
# Prefix match (ends with *)
elif pattern.endswith('*'):
prefix = pattern[:-1]
return base_name.startswith(prefix)
# Contains match (starts with *)
elif pattern.startswith('*'):
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
return search in base_name
# Exact match (no wildcards)
else:
return pattern == base_name
if isinstance(names, str):
# Check if this is a "not" pattern
is_not_pattern = names.startswith('!')
if is_not_pattern:
names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |)
if '|' in names:
terms = names.split('|')
matches = {}
for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
# Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern
if is_not_pattern:
should_include = not should_include
if should_include:
matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name
matches = {
comp_id: comp for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
# Prefix match (ends with *)
elif names.endswith('*'):
prefix = names[:-1]
matches = {
comp_id: comp for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else:
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
# Contains match (starts with *)
elif names.startswith('*'):
search = names[1:-1] if names.endswith('*') else names[1:]
matches = {
comp_id: comp for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()):
matches = {
comp_id: comp for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
else:
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
else:
return matches
elif isinstance(names, list):
results = {}
for name in names:
result = self.get(name, collection, load_id, as_name_component_tuples=False)
results.update(result)
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
else:
return results
else:
raise ValueError(f"Invalid type for names: {type(names)}")
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"):
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
remove_hook_from_module(component, recurse=True)
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
all_hooks = []
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):
hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
all_hooks.append(hook)
for hook in all_hooks:
other_hooks = [h for h in all_hooks if h is not hook]
for other_hook in other_hooks:
if other_hook.hook.execution_device == hook.hook.execution_device:
hook.add_other_hook(other_hook)
self.model_hooks = all_hooks
self._auto_offload_enabled = True
self._auto_offload_device = device
def disable_auto_cpu_offload(self):
if self.model_hooks is None:
self._auto_offload_enabled = False
return
for hook in self.model_hooks:
hook.offload()
hook.remove()
if self.model_hooks:
clear_device_cache()
self.model_hooks = None
self._auto_offload_enabled = False
# YiYi TODO: add quantization info
def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
"""Get comprehensive information about a component.
Args:
component_id: Name of the component to get info for
fields: Optional field(s) to return. Can be a string for single field or list of fields.
If None, returns all fields.
Returns:
Dictionary containing requested component metadata.
If fields is specified, returns only those fields.
If a single field is requested as string, returns just that field's value.
"""
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
component = self.components[component_id]
# Build complete info dict first
info = {
"model_id": component_id,
"added_time": self.added_time[component_id],
"collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None,
}
# Additional info for torch.nn.Module components
if isinstance(component, torch.nn.Module):
# Check for hook information
has_hook = hasattr(component, "_hf_hook")
execution_device = None
if has_hook and hasattr(component._hf_hook, "execution_device"):
execution_device = component._hf_hook.execution_device
info.update({
"class_name": component.__class__.__name__,
"size_gb": get_memory_footprint(component) / (1024**3),
"adapters": None, # Default to None
"has_hook": has_hook,
"execution_device": execution_device,
})
# Get adapters if applicable
if hasattr(component, "peft_config"):
info["adapters"] = list(component.peft_config.keys())
# Check for IP-Adapter scales
if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
processors = copy.deepcopy(component.attn_processors)
# First check if any processor is an IP-Adapter
processor_types = [v.__class__.__name__ for v in processors.values()]
if any("IPAdapter" in ptype for ptype in processor_types):
# Then get scales only from IP-Adapter processors
scales = {
k: v.scale
for k, v in processors.items()
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
}
if scales:
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
# If fields specified, filter info
if fields is not None:
if isinstance(fields, str):
# Single field requested, return just that value
return {fields: info.get(fields)}
else:
# List of fields requested, return dict with just those fields
return {k: v for k, v in info.items() if k in fields}
return info
def __repr__(self):
# Helper to get simple name without UUID
def get_simple_name(name):
# Extract the base name by splitting on underscore and taking first part
# This assumes names are in format "name_uuid"
parts = name.split('_')
# If we have at least 2 parts and the last part looks like a UUID, remove it
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return name
# Extract load_id if available
def get_load_id(component):
if hasattr(component, "_diffusers_load_id"):
return component._diffusers_load_id
return "N/A"
# Format device info compactly
def format_device(component, info):
if not info["has_hook"]:
return str(getattr(component, 'device', 'N/A'))
else:
device = str(getattr(component, 'device', 'N/A'))
exec_device = str(info['execution_device'] or 'N/A')
return f"{device}({exec_device})"
# Get all simple names to calculate width
simple_names = [get_simple_name(id) for id in self.components.keys()]
# Get max length of load_ids for models
load_ids = [
get_load_id(component)
for component in self.components.values()
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
]
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
# Get all collections for each component
component_collections = {}
for name in self.components.keys():
component_collections[name] = []
for coll, comps in self.collections.items():
if name in comps:
component_collections[name].append(coll)
if not component_collections[name]:
component_collections[name] = ["N/A"]
# Find the maximum collection name length
all_collections = [coll for colls in component_collections.values() for coll in colls]
max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
col_widths = {
"name": max(15, max(len(name) for name in simple_names)),
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
"device": 15, # Reduced since using more compact format
"dtype": 15,
"size": 10,
"load_id": max_load_id_len,
"collection": max_collection_len
}
# Create the header lines
sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
output = "Components:\n" + sep_line
# Separate components into models and others
models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}
# Models section
if models:
output += "Models:\n" + dash_line
# Column headers
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | "
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
output += dash_line
# Model entries
for name, component in models.items():
info = self.get_model_info(name)
simple_name = get_simple_name(name)
device_str = format_device(component, info)
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
load_id = get_load_id(component)
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | "
output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
output += dash_line
# Other components section
if others:
if models: # Add extra newline if we had models section
output += "\n"
output += "Other Components:\n" + dash_line
# Column headers for other components
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n"
output += dash_line
# Other component entries
for name, component in others.items():
info = self.get_model_info(name)
simple_name = get_simple_name(name)
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n"
output += dash_line
# Add additional component info
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
for name in self.components:
info = self.get_model_info(name)
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
simple_name = get_simple_name(name)
output += f"\n{simple_name}:\n"
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
output += f" IP-Adapter: Enabled\n"
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
return output
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
"""
Load components from a pretrained model and add them to the manager.
Args:
pretrained_model_name_or_path (str): The path or identifier of the pretrained model
prefix (str, optional): Prefix to add to all component names loaded from this model.
If provided, components will be named as "{prefix}_{component_name}"
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
"""
subfolder = kwargs.pop("subfolder", None)
# YiYi TODO: extend AutoModel to support non-diffusers models
if subfolder:
from ..models import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
component_name = f"{prefix}_{subfolder}" if prefix else subfolder
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
else:
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
"""
Get a single component by name. Raises an error if multiple components match or none are found.
Args:
name: Component name or pattern
collection: Optional collection to filter by
load_id: Optional load_id to filter by
Returns:
A single component
Raises:
ValueError: If no components match or multiple components match
"""
# if component_id is provided, return the component
if component_id is not None and (name is not None or collection is not None or load_id is not None):
raise ValueError(" if component_id is provided, name, collection, and load_id must be None")
elif component_id is not None:
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
return self.components[component_id]
results = self.get(name, collection, load_id)
if not results:
raise ValueError(f"No components found matching '{name}'")
if len(results) > 1:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values()))
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
"""Summarizes a dictionary by finding common prefixes that share the same value.
For a dictionary with dot-separated keys like:
{
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
}
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
{
'down_blocks': [0.6],
'up_blocks': [0.3]
}
"""
# First group by values - convert lists to tuples to make them hashable
value_to_keys = {}
for key, value in d.items():
value_tuple = tuple(value) if isinstance(value, list) else value
if value_tuple not in value_to_keys:
value_to_keys[value_tuple] = []
value_to_keys[value_tuple].append(key)
def find_common_prefix(keys: List[str]) -> str:
"""Find the shortest common prefix among a list of dot-separated keys."""
if not keys:
return ""
if len(keys) == 1:
return keys[0]
# Split all keys into parts
key_parts = [k.split('.') for k in keys]
# Find how many initial parts are common
common_length = 0
for parts in zip(*key_parts):
if len(set(parts)) == 1: # All parts at this position are the same
common_length += 1
else:
break
if common_length == 0:
return ""
# Return the common prefix
return '.'.join(key_parts[0][:common_length])
# Create summary by finding common prefixes for each value group
summary = {}
for value_tuple, keys in value_to_keys.items():
prefix = find_common_prefix(keys)
if prefix: # Only add if we found a common prefix
# Convert tuple back to list if it was originally a list
value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
summary[prefix] = value
else:
summary[""] = value # Use empty string if no common prefix
return summary
File diff suppressed because it is too large Load Diff
@@ -1,616 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import inspect
from dataclasses import dataclass, asdict, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
from collections import OrderedDict
if is_torch_available():
import torch
class InsertableOrderedDict(OrderedDict):
def insert(self, key, value, index):
items = list(self.items())
# Remove key if it already exists to avoid duplicates
items = [(k, v) for k, v in items if k != key]
# Insert at the specified index
items.insert(index, (key, value))
# Clear and update self
self.clear()
self.update(items)
# Return self for method chaining
return self
# YiYi TODO:
# 1. validate the dataclass fields
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
@dataclass
class ComponentSpec:
"""Specification for a pipeline component.
A component can be created in two ways:
1. From scratch using __init__ with a config dict
2. using `from_pretrained`
Attributes:
name: Name of the component
type_hint: Type of the component (e.g. UNet2DConditionModel)
description: Optional description of the component
config: Optional config dict for __init__ creation
repo: Optional repo path for from_pretrained creation
subfolder: Optional subfolder in repo
variant: Optional variant in repo
revision: Optional revision in repo
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
"""
name: Optional[str] = None
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict[str, Any]] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec):
return False
return (self.name == other.name and
self.load_id == other.load_id and
self.default_creation_method == other.default_creation_method)
@classmethod
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` or `load` method")
# throw a error if component is created with `create` method but not a subclass of ConfigMixin
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
raise ValueError(
"We currently only support creating ComponentSpec from a component with "
"created with `ComponentSpec.load` method"
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
)
type_hint = component.__class__
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
if isinstance(component, ConfigMixin):
config = component.config
else:
config = None
load_spec = cls.decode_load_id(component._diffusers_load_id)
return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
@classmethod
def loading_fields(cls) -> List[str]:
"""
Return the names of all loadingrelated fields
(i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property
def load_id(self) -> str:
"""
Unique identifier for this spec's pretrained load,
composed of repo|subfolder|variant|revision (no empty segments).
"""
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
"""
Decode a load_id string back into a dictionary of loading fields and values.
Args:
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
where None values are represented as "null"
Returns:
Dict mapping loading field names to their values. e.g.
{
"repo": "path/to/repo",
"subfolder": "subfolder",
"variant": "variant",
"revision": "revision"
}
If a segment value is "null", it's replaced with None.
Returns None if load_id is "null" (indicating component not created with `load` method).
"""
# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
if load_id == "null":
return result
# Split the load_id
parts = load_id.split("|")
# Map parts to loading fields by position
for i, part in enumerate(parts):
if i < len(loading_fields):
# Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part
return result
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
# the config info is lost in the process
# remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError(
f"`type_hint` is required when using from_config creation method."
)
config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs)
else:
signature_params = inspect.signature(self.type_hint.__init__).parameters
init_kwargs = {}
for k, v in config.items():
if k in signature_params:
init_kwargs[k] = v
for k, v in kwargs.items():
if k in signature_params:
init_kwargs[k] = v
component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null"
if hasattr(component, "config"):
self.config = component.config
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
# update type_hint if AutoModel load successfully
self.type_hint = component.__class__
else:
try:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} using load method: {e}")
self.repo = repo
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
default: Any
description: Optional[str] = None
# YiYi Notes: both inputs and intermediates_inputs are InputParam objects
# however some fields are not relevant for intermediates_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediates_inputs, we only use default from inputs, so it is ignored if it is set for intermediates_inputs
# -> should we use different class for inputs and intermediates_inputs?
@dataclass
class InputParam:
"""Specification for an input parameter."""
name: str = None
type_hint: Any = None
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
def format_inputs_short(inputs):
"""
Format input parameters into a string representation, with required params first followed by optional ones.
Args:
inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
Returns:
str: Formatted string of input parameters
Example:
>>> inputs = [
... InputParam(name="prompt", required=True),
... InputParam(name="image", required=True),
... InputParam(name="guidance_scale", required=False, default=7.5),
... InputParam(name="num_inference_steps", required=False, default=50)
... ]
>>> format_inputs_short(inputs)
'prompt, image, guidance_scale=7.5, num_inference_steps=50'
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str
def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs):
"""
Formats intermediate inputs and outputs of a block into a string representation.
Args:
intermediates_inputs: List of intermediate input parameters
required_intermediates_inputs: List of required intermediate input names
intermediates_outputs: List of intermediate output parameters
Returns:
str: Formatted string like:
Intermediates:
- inputs: Required(latents), dtype
- modified: latents # variables that appear in both inputs and outputs
- outputs: images # new outputs only
"""
# Handle inputs
input_parts = []
for inp in intermediates_inputs:
if inp.name in required_intermediates_inputs:
input_parts.append(f"Required({inp.name})")
else:
if inp.name is None and inp.kwargs_type is not None:
inp_name = "*_" + inp.kwargs_type
else:
inp_name = inp.name
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}
modified_parts = []
new_output_parts = []
for out in intermediates_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)
result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
if modified_parts:
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)"
def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""Format a list of InputParam or OutputParam objects into a readable string representation.
Args:
params: List of InputParam or OutputParam objects to format
header: Header text to use (e.g. "Args" or "Returns")
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all parameters
"""
if not params:
return ""
base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []
def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
lines = []
current_line = []
current_length = 0
for word in words:
word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
current_length = len(word)
else:
current_line.append(word)
current_length += word_length
if current_line:
lines.append(" ".join(current_line))
return f"\n{indent}".join(lines)
# Add the header
formatted_params.append(f"{base_indent}{header}:")
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
# YiYi Notes: remove this line if we remove kwargs_type
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"
# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(
r'\[(.*?)\]\((https?://[^\s\)]+)\)',
r'[\1](\2)',
param.description
)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
def format_input_params(input_params, indent_level=4, max_line_length=115):
"""Format a list of InputParam objects into a readable string representation.
Args:
input_params: List of InputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all input parameters
"""
return format_params(input_params, "Inputs", indent_level, max_line_length)
def format_output_params(output_params, indent_level=4, max_line_length=115):
"""Format a list of OutputParam objects into a readable string representation.
Args:
output_params: List of OutputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all output parameters
"""
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
Args:
components: List of ComponentSpec objects to format
indent_level: Number of spaces to indent each component line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between components (default: True)
Returns:
A formatted string representing all components
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ConfigSpec objects into a readable string representation.
Args:
configs: List of ConfigSpec objects to format
indent_level: Number of spaces to indent each config line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between configs (default: True)
Returns:
A formatted string representing all configs
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None):
"""
Generates a formatted documentation string describing the pipeline block's parameters and structure.
Args:
inputs: List of input parameters
intermediates_inputs: List of intermediate input parameters
outputs: List of output parameters
description (str, *optional*): Description of the block
class_name (str, *optional*): Name of the class to include in the documentation
expected_components (List[ComponentSpec], *optional*): List of expected components
expected_configs (List[ConfigSpec], *optional*): List of expected configurations
Returns:
str: A formatted string containing information about components, configs, call parameters,
intermediate inputs/outputs, and final outputs.
"""
output = ""
# Add class name if provided
if class_name:
output += f"class {class_name}\n\n"
# Add description
if description:
desc_lines = description.strip().split('\n')
aligned_desc = '\n'.join(' ' + line for line in desc_lines)
output += aligned_desc + "\n\n"
# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
output += components_str + "\n\n"
# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output
@@ -1,519 +0,0 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
import json
import os
from typing import Union, List, Optional, Tuple
import torch
import PIL
import numpy as np
import logging
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS ={
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name, if not part of a group, return None
e.g. "prompt_embeds" -> "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
class ModularNode(ConfigMixin):
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if not inp.name in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.blocks.keys())[-1]
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {
"module": self.mellon_config,
"name_mapping": self.name_mapping
}
# Save the config to file
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.blocks.loader.update(**params_components)
output = self.blocks.run(**params_run, output=return_output_names)
return output
@@ -1,53 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"]
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
_import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"]
_import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"]
_import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
from .modular_loader import StableDiffusionXLModularLoader
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
File diff suppressed because it is too large Load Diff
@@ -1,215 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
import numpy as np
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...configuration_utils import FrozenDict
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ..modular_pipeline import (
AutoPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
@staticmethod
def upcast_vae(components):
dtype = components.vae.dtype
components.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
components.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
components.vae.post_quant_conv.to(dtype)
components.vae.decoder.conv_in.to(dtype)
components.vae.decoder.mid_block.to(dtype)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if block_state.needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
components.vae = components.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
block_state.has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
block_state.has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
else:
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if block_state.needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
# apply watermark if available
if hasattr(components, "watermark") and components.watermark is not None:
block_state.images = components.watermark.apply_watermark(block_state.images)
block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type)
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \
"only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"),
InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.")
]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images]
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
block_names = ["decode", "mask_overlay"]
@property
def description(self):
return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \
"This is a sequential pipeline blocks:\n" + \
" - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \
" - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
block_names = ["inpaint", "non-inpaint"]
block_trigger_inputs = ["padding_mask_crop", None]
@property
def description(self):
return "Decode step that decode the denoised latents into images outputs.\n" + \
"This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
" - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
" - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
File diff suppressed because it is too large Load Diff
@@ -1,858 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor, unwrap_module
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...configuration_utils import FrozenDict
from transformers import (
CLIPTextModel,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...schedulers import EulerDiscreteScheduler
from ...guiders import ClassifierFreeGuidance
from .modular_loader import StableDiffusionXLModularLoader
from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionXLIPAdapterStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"ip_adapter_image",
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter"
)
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
@staticmethod
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(components.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = components.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = components.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = components.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLTextEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return(
"Text Encoder step that generate text_embeddings to guide the image generation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
InputParam("cross_attention_kwargs"),
InputParam("clip_skip"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"),
OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"),
OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"),
OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2]
text_encoders = (
[components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2]
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None
)
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
block_state.pooled_prompt_embeds,
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Vae Encoder step that encode the input image into a latent representation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
ComponentSpec(
"mask_processor",
VaeImageProcessor,
config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
return (
"Vae encoder step that prepares the image and mask for the inpainting process"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height"),
InputParam("width"),
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Coped from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# do not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop)
block_state.resize_mode = "fill"
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
block_state.image = block_state.image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.add_block_state(state, block_state)
return components, state
# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
# Encode
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return "Vae encoder step that encode the image inputs into their latent representations.\n" + \
"This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \
" - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \
" - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin):
block_classes = [StableDiffusionXLIPAdapterStep]
block_names = ["ip_adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided."
@@ -1,121 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..modular_pipeline_utils import InsertableOrderedDict
# Import all the necessary block classes
from .denoise import (
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDenoiseLoop,
StableDiffusionXLInpaintDenoiseLoop
)
from .before_denoise import (
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep
)
from .encoders import (
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLVaeEncoderStep,
StableDiffusionXLInpaintVaeEncoderStep,
StableDiffusionXLIPAdapterStep
)
from .decoders import (
StableDiffusionXLDecodeStep,
StableDiffusionXLInpaintDecodeStep,
StableDiffusionXLAutoDecodeStep
)
# YiYi notes: comment out for now, work on this later
# block mapping
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseLoop),
("decode", StableDiffusionXLDecodeStep)
])
IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseLoop),
("decode", StableDiffusionXLDecodeStep)
])
INPAINT_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLInpaintDenoiseLoop),
("decode", StableDiffusionXLInpaintDecodeStep)
])
CONTROLNET_BLOCKS = InsertableOrderedDict([
("controlnet_input", StableDiffusionXLControlNetInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([
("controlnet_input", StableDiffusionXLControlNetUnionInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
IP_ADAPTER_BLOCKS = InsertableOrderedDict([
("ip_adapter", StableDiffusionXLIPAdapterStep),
])
AUTO_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep)
])
SDXL_SUPPORTED_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"controlnet_union": CONTROLNET_UNION_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS
}
@@ -1,174 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
import numpy as np
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...image_processor import PipelineImageInput
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularLoader
from ..modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
# YiYi Notes: model specific components:
## (1) it should inherit from ModularLoader
## (2) acts like a container that holds components and configs
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how to use together with Components_manager?
class StableDiffusionXLModularLoader(
ModularLoader,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
@property
def default_sample_size(self):
default_sample_size = 128
if hasattr(self, "unet") and self.unet is not None:
default_sample_size = self.unet.config.sample_size
return default_sample_size
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_unet(self):
num_channels_unet = 4
if hasattr(self, "unet") and self.unet is not None:
num_channels_unet = self.unet.config.in_channels
return num_channels_unet
@property
def num_channels_latents(self):
num_channels_latents = 4
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
# YiYi Notes: not used yet, maintain a list of schema that can be used across all pipeline blocks
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"),
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
"masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
"latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"),
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
"negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images")
}
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
}
@@ -1,43 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .denoise import StableDiffusionXLAutoDenoiseStep
from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"]
@property
def description(self):
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \
"- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \
"- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \
"- to run the controlnet workflow, you need to provide `control_image`\n" + \
"- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \
"- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
"- for text-to-image generation, all you need to provide is `prompt`"
+7 -8
View File
@@ -246,15 +246,14 @@ def _get_connected_pipeline(pipeline_cls):
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
def _get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
model_name = _get_model(pipeline_class_name)
def get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
model_name = get_model(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
@@ -912,6 +912,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -925,11 +931,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -867,6 +867,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -880,11 +886,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -609,6 +609,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -622,11 +628,6 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
+2 -15
View File
@@ -75,11 +75,6 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
if provider_options is None:
provider_options = []
elif not isinstance(provider_options, list):
provider_options = [provider_options]
return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
@@ -179,10 +174,7 @@ class OnnxRuntimeModel:
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(
Path(model_id, model_file_name).as_posix(),
provider=provider,
sess_options=sess_options,
provider_options=kwargs.pop("provider_options"),
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
@@ -198,12 +190,7 @@ class OnnxRuntimeModel:
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(
model_cache_path,
provider=provider,
sess_options=sess_options,
provider_options=kwargs.pop("provider_options"),
)
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
return cls(model=model, **kwargs)
@classmethod
@@ -917,6 +917,12 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -930,11 +936,6 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -707,6 +707,12 @@ class StableDiffusionXLPAGImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -720,11 +726,6 @@ class StableDiffusionXLPAGImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -331,20 +331,6 @@ def maybe_raise_or_warn(
)
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
def simple_get_class_obj(library_name, class_name):
from diffusers import pipelines
is_pipeline_module = hasattr(pipelines, library_name)
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
return class_obj
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
@@ -853,10 +839,7 @@ def _fetch_class_library_tuple(module):
library = not_compiled_module.__module__
# retrieve class_name
if isinstance(not_compiled_module, type):
class_name = not_compiled_module.__name__
else:
class_name = not_compiled_module.__class__.__name__
class_name = not_compiled_module.__class__.__name__
return (library, class_name)
+4 -23
View File
@@ -58,7 +58,6 @@ from ..utils import (
_is_valid_type,
is_accelerate_available,
is_accelerate_version,
is_hpu_available,
is_torch_npu_available,
is_torch_version,
is_transformers_version,
@@ -427,7 +426,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -444,7 +443,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
@@ -452,20 +450,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
# Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
os.environ["PT_HPU_GPU_MIGRATION"] = "1"
logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
import habana_frameworks.torch # noqa: F401
# HPU hardware check
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
@@ -1120,11 +1104,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
automatically detect the available accelerator and use.
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1248,7 +1230,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -1948,10 +1930,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
}
optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else []
missing_modules = (
set(expected_modules)
- set(optional_components)
- set(pipeline._optional_components)
- set(pipeline_kwargs.keys())
- set(true_optional_modules)
)
@@ -695,6 +695,12 @@ class StableDiffusionXLImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -708,11 +714,6 @@ class StableDiffusionXLImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
-1
View File
@@ -71,7 +71,6 @@ from .import_utils import (
is_gguf_version,
is_google_colab,
is_hf_hub_version,
is_hpu_available,
is_inflect_available,
is_invisible_watermark_available,
is_k_diffusion_available,
-15
View File
@@ -1388,21 +1388,6 @@ class LDMSuperResolutionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ModularLoader(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -2432,21 +2432,6 @@ class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLModularLoader(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+4 -81
View File
@@ -15,16 +15,13 @@
"""Utilities to dynamically load objects from the Hub."""
import importlib
import signal
import inspect
import json
import os
import re
import shutil
import sys
import threading
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request
@@ -40,8 +37,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions():
@@ -159,87 +154,15 @@ def check_imports(filename):
return get_relative_imports(filename)
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
if trust_remote_code is None:
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
def get_class_in_module(class_name, module_path, force_reload=False):
def get_class_in_module(class_name, module_path):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
module_spec.loader.exec_module(module)
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
if class_name is None:
return find_pipeline_class(module)
return getattr(module, class_name)
@@ -531,4 +454,4 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module)
return get_class_in_module(class_name, final_module.replace(".py", ""))
-4
View File
@@ -353,10 +353,6 @@ def is_timm_available():
return _timm_available
def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
-5
View File
@@ -90,11 +90,6 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
+8 -8
View File
@@ -352,7 +352,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -403,7 +403,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
@@ -486,7 +486,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -541,7 +541,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -590,7 +590,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -653,7 +653,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
@@ -668,7 +668,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
@@ -734,7 +734,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
+3 -3
View File
@@ -1017,7 +1017,7 @@ class PeftLoraLoaderMixinTests:
)
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
logger = logging.get_logger("diffusers.loaders.lora_base")
logger = logging.get_logger("diffusers.loaders.lora.lora_base")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
@@ -1824,7 +1824,7 @@ class PeftLoraLoaderMixinTests:
elif lora_module == "text_encoder_2":
prefix = "text_encoder_2"
logger = logging.get_logger("diffusers.loaders.lora_base")
logger = logging.get_logger("diffusers.loaders.lora.lora_base")
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
@@ -1925,7 +1925,7 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -22,7 +22,6 @@ from parameterized import parameterized
from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -135,32 +134,18 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
# fmt: off
[
33,
Expectations(
{
("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
("mps", None): torch.tensor(
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
),
}
),
[-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
],
[
47,
Expectations(
{
("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
("mps", None): torch.tensor(
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
),
}
),
[0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
],
# fmt: on
]
)
def test_stable_diffusion(self, seed, expected_slices):
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
model = self.get_sd_vae_model()
image = self.get_sd_image(seed)
generator = self.get_generator(seed)
@@ -171,9 +156,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
expected_slice = expected_slices.get_expectation()
assert torch_all_close(output_slice, expected_slice, atol=5e-3)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
@@ -17,14 +17,7 @@ import unittest
import torch
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_torch_compile,
require_torch_2,
require_torch_gpu,
slow,
torch_device,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -96,21 +89,6 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
@@ -179,21 +157,6 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
@@ -260,21 +223,6 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
@@ -342,18 +290,3 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, u
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@require_torch_gpu
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
torch._dynamo.reset()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model = torch.compile(model, fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@@ -11,12 +11,10 @@ from diffusers import (
UNet2DModel,
)
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
nightly,
require_torch_2,
require_torch_accelerator,
require_torch_gpu,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
@@ -170,17 +168,17 @@ class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@nightly
@require_torch_accelerator
@require_torch_gpu
class ConsistencyModelPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
generator = torch.manual_seed(seed)
@@ -266,19 +264,11 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
# Ensure usage of flash attention in torch 2.0
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
image = pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slices = Expectations(
{
("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]),
("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
}
)
expected_slice = expected_slices.get_expectation()
expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@@ -35,7 +35,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_big_accelerator,
require_big_gpu_with_torch_cuda,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
@@ -210,8 +210,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
@nightly
@require_big_accelerator
@pytest.mark.big_accelerator
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
@@ -33,11 +33,10 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
)
@@ -396,17 +395,17 @@ class MarigoldIntrinsicsPipelineFastTests(MarigoldIntrinsicsPipelineTesterMixin,
@slow
@require_torch_accelerator
@require_torch_gpu
class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def _test_marigold_intrinsics(
self,
@@ -425,7 +424,7 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
from_pretrained_kwargs["torch_dtype"] = torch.float16
pipe = MarigoldIntrinsicsPipeline.from_pretrained(model_id, **from_pretrained_kwargs)
if device in ["cuda", "xpu"]:
if device == "cuda":
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
@@ -465,10 +464,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
def test_marigold_intrinsics_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=False,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.62127, 0.61906, 0.61687, 0.61946, 0.61903, 0.61961, 0.61808, 0.62099, 0.62894]),
num_inference_steps=1,
@@ -478,10 +477,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.62109, 0.61914, 0.61719, 0.61963, 0.61914, 0.61963, 0.61816, 0.62109, 0.62891]),
num_inference_steps=1,
@@ -491,10 +490,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
def test_marigold_intrinsics_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=2024,
expected_slice=np.array([0.64111, 0.63916, 0.63623, 0.63965, 0.63916, 0.63965, 0.6377, 0.64062, 0.64941]),
num_inference_steps=1,
@@ -504,10 +503,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
def test_marigold_intrinsics_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.60254, 0.60059, 0.59961, 0.60156, 0.60107, 0.60205, 0.60254, 0.60449, 0.61133]),
num_inference_steps=2,
@@ -517,10 +516,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.64551, 0.64453, 0.64404, 0.64502, 0.64844, 0.65039, 0.64502, 0.65039, 0.65332]),
num_inference_steps=1,
@@ -530,10 +529,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]),
num_inference_steps=1,
@@ -544,10 +543,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]),
num_inference_steps=1,
@@ -558,10 +557,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
device="cuda",
generator_seed=0,
expected_slice=np.array([0.65332, 0.64697, 0.64648, 0.64844, 0.64697, 0.64111, 0.64941, 0.64209, 0.65332]),
num_inference_steps=1,
@@ -24,15 +24,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
floats_tensor,
nightly,
require_accelerator,
require_torch_accelerator,
torch_device,
)
from diffusers.utils.testing_utils import floats_tensor, nightly, require_accelerator, require_torch_gpu, torch_device
class SafeDiffusionPipelineFastTests(unittest.TestCase):
@@ -40,13 +32,13 @@ class SafeDiffusionPipelineFastTests(unittest.TestCase):
# clean up the VRAM before each test
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
@property
def dummy_image(self):
@@ -270,19 +262,19 @@ class SafeDiffusionPipelineFastTests(unittest.TestCase):
@nightly
@require_torch_accelerator
@require_torch_gpu
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def test_harm_safe_stable_diffusion(self):
sd_pipe = StableDiffusionPipeline.from_pretrained(
@@ -316,14 +308,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slices = Expectations(
{
("xpu", 3): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
("cuda", 7): [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176],
("cuda", 8): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
}
)
expected_slice = expected_slices.get_expectation()
expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
assert image.shape == (1, 512, 512, 3)
@@ -350,15 +335,6 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
expected_slices = Expectations(
{
("xpu", 3): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
("cuda", 7): [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719],
("cuda", 8): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
}
)
expected_slice = expected_slices.get_expectation()
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -389,14 +365,8 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slices = Expectations(
{
("xpu", 3): [0.3244, 0.3355, 0.3260, 0.3123, 0.3246, 0.3426, 0.3109, 0.3471, 0.4001],
("cuda", 7): [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297],
("cuda", 8): [0.3605, 0.3684, 0.3712, 0.3624, 0.3675, 0.3726, 0.3494, 0.3748, 0.4044],
}
)
expected_slice = expected_slices.get_expectation()
expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -419,16 +389,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slices = Expectations(
{
("xpu", 3): [0.6178, 0.6260, 0.6194, 0.6435, 0.6265, 0.6461, 0.6567, 0.6576, 0.6444],
("cuda", 7): [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443],
("cuda", 8): [0.5892, 0.5959, 0.5914, 0.6123, 0.5982, 0.6141, 0.6180, 0.6262, 0.6171],
}
)
print(f"image_slice: {image_slice.flatten()}")
expected_slice = expected_slices.get_expectation()
expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
assert image.shape == (1, 512, 512, 3)
@@ -484,14 +445,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
image = output.images
image_slice = image[0, -3:, -3:, -1]
expected_slices = Expectations(
{
("xpu", 3): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
("cuda", 7): np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561]),
("cuda", 8): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
}
)
expected_slice = expected_slices.get_expectation()
expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
assert image.shape == (1, 512, 512, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+5 -5
View File
@@ -1485,8 +1485,8 @@ class PipelineTesterMixin:
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -1677,11 +1677,11 @@ class PipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs)[0]
pipe.enable_model_cpu_offload()
pipe.enable_model_cpu_offload(device=torch_device)
inputs = self.get_dummy_inputs(generator_device)
output_with_offload_twice = pipe(**inputs)[0]
@@ -2226,7 +2226,7 @@ class PipelineTesterMixin:
def enable_group_offload_on_component(pipe, group_offloading_kwargs):
# We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
# tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
# tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of
# the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
# warmup forward pass (even with dummy small inputs) is recommended.
for component_name in [
@@ -22,13 +22,13 @@ from diffusers import (
UniDiffuserTextDecoder,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
nightly,
require_torch_2,
require_torch_accelerator,
require_torch_gpu,
run_test_in_subprocess,
torch_device,
)
@@ -577,24 +577,24 @@ class UniDiffuserPipelineFastTests(
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
@unittest.skip(
"Test not supported because it has a bunch of direct configs at init and also, this pipeline isn't used that much now."
"Test not supported becauseit has a bunch of direct configs at init and also, this pipeline isn't used that much now."
)
def test_encode_prompt_works_in_isolation():
pass
@nightly
@require_torch_accelerator
@require_torch_gpu
class UniDiffuserPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0, generate_latents=False):
generator = torch.manual_seed(seed)
@@ -705,17 +705,17 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
@nightly
@require_torch_accelerator
@require_torch_gpu
class UniDiffuserPipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0, generate_latents=False):
generator = torch.manual_seed(seed)
@@ -5,7 +5,7 @@ import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
@@ -18,9 +18,7 @@ import unittest
import torch
from diffusers import (
AutoencoderKL,
)
from diffusers import AutoencoderKL
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -16,9 +16,7 @@
import gc
import unittest
from diffusers import (
AutoencoderKLWan,
)
from diffusers import AutoencoderKLWan
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -18,13 +18,11 @@ import unittest
import torch
from diffusers import (
WanTransformer3DModel,
)
from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_big_accelerator,
require_big_gpu_with_torch_cuda,
require_torch_accelerator,
torch_device,
)
@@ -62,7 +60,7 @@ class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
)
@require_big_accelerator
@require_big_gpu_with_torch_cuda
@require_torch_accelerator
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
model_class = WanTransformer3DModel

Some files were not shown because too many files have changed in this diff Show More