Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d8621ec72 | |||
| 0c37895440 | |||
| 9bebdf225d | |||
| c05114d5ec | |||
| a57a5ab4c0 | |||
| 4b1c7dc81a | |||
| 1590325a60 | |||
| e4dd7c5333 | |||
| d6430c79a3 | |||
| 1597ae6ac9 | |||
| 11a23d11fe | |||
| 6b8b225aca | |||
| 27d2401e59 | |||
| 1ddfe14220 | |||
| 0e8d1d25eb | |||
| 546446ae21 | |||
| ea3f0b8d68 | |||
| f0ea9ff2e2 | |||
| 1b7c286974 | |||
| 6138cc1720 | |||
| ea0ce4bfab | |||
| f2aa2f91dc | |||
| 4faac73219 | |||
| d870e3c9a6 | |||
| 178b884673 | |||
| 2da3cb4a8c | |||
| ea3ba4f431 | |||
| 21b2566933 | |||
| a71334b861 | |||
| eb47a67d50 | |||
| 8267677a24 |
@@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
|
||||
|
||||
## IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin
|
||||
|
||||
## SD3IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
|
||||
[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin
|
||||
- all
|
||||
- is_ip_adapter_active
|
||||
|
||||
|
||||
@@ -39,58 +39,66 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
## StableDiffusionLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin
|
||||
|
||||
## StableDiffusionXLLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin
|
||||
|
||||
## SD3LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin
|
||||
|
||||
## FluxLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
## Mochi1LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin
|
||||
## AuraFlowLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin
|
||||
|
||||
## LTXVideoLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin
|
||||
|
||||
## SanaLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin
|
||||
|
||||
## HunyuanVideoLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin
|
||||
|
||||
## Lumina2LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
|
||||
|
||||
## CogView4LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin
|
||||
|
||||
## WanLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
|
||||
|
||||
## CogView4LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
|
||||
|
||||
## CogView4LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
|
||||
|
||||
## WanLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
|
||||
|
||||
## AmusedLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
## HiDreamImageLoraLoaderMixin
|
||||
|
||||
@@ -98,4 +106,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
[[autodoc]] loaders.lora.lora_base.LoraBaseMixin
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# SD3Transformer2D
|
||||
|
||||
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
|
||||
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and [SD3Transformer2DModel], check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
|
||||
|
||||
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
|
||||
|
||||
@@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
## SD3Transformer2DLoadersMixin
|
||||
|
||||
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
|
||||
[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin
|
||||
- all
|
||||
- _load_ip_adapter_weights
|
||||
@@ -86,7 +86,6 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
|
||||
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
|
||||
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
|
||||
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
```py
|
||||
@@ -5433,50 +5432,4 @@ cropped_image = gen_image.crop((0, 0, width_init, height_init))
|
||||
cropped_image.save("data/result.png")
|
||||
````
|
||||
### Result
|
||||
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
|
||||
|
||||
|
||||
# Stable Diffusion 3 InstructPix2Pix Pipeline
|
||||
This the implementation of the Stable Diffusion 3 InstructPix2Pix Pipeline, based on the HuggingFace Diffusers.
|
||||
|
||||
## Example Usage
|
||||
This pipeline aims to edit image based on user's instruction by using SD3
|
||||
````py
|
||||
import torch
|
||||
from diffusers import SD3Transformer2DModel
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
resolution = 512
|
||||
image = load_image("https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png").resize(
|
||||
(resolution, resolution)
|
||||
)
|
||||
edit_instruction = "Turn sky into a sunny one"
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix", torch_dtype=torch.float16).to('cuda')
|
||||
|
||||
pipe.transformer = SD3Transformer2DModel.from_pretrained("CaptainZZZ/sd3-instructpix2pix",torch_dtype=torch.float16).to('cuda')
|
||||
|
||||
edited_image = pipe(
|
||||
prompt=edit_instruction,
|
||||
image=image,
|
||||
height=resolution,
|
||||
width=resolution,
|
||||
guidance_scale=7.5,
|
||||
image_guidance_scale=1.5,
|
||||
num_inference_steps=30,
|
||||
).images[0]
|
||||
|
||||
edited_image.save("edited_image.png")
|
||||
````
|
||||
|Original|Edited|
|
||||
|---|---|
|
||||
||
|
||||
|
||||
### Note
|
||||
This model is trained on 512x512, so input size is better on 512x512.
|
||||
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
|
||||
Editing at Scale", many thanks to their contribution!
|
||||
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -639,15 +639,6 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
help="Enable model cpu offload and save memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -745,13 +736,9 @@ def get_train_dataset(args, accelerator):
|
||||
|
||||
|
||||
def prepare_train_dataset(dataset, accelerator):
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
|
||||
if interpolation is None:
|
||||
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -760,7 +747,7 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
|
||||
@@ -134,25 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
|
||||
try:
|
||||
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
|
||||
except (AttributeError, KeyError):
|
||||
supported_interpolation_modes = [
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
|
||||
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
|
||||
)
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
]
|
||||
)
|
||||
validation_image = transform(validation_image)
|
||||
validation_image = validation_image.resize((args.resolution, args.resolution))
|
||||
|
||||
images = []
|
||||
|
||||
@@ -605,15 +587,6 @@ def parse_args(input_args=None):
|
||||
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
default="lanczos",
|
||||
choices=[
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
],
|
||||
help="The image interpolation method to use for resizing images.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -759,20 +732,9 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
|
||||
|
||||
|
||||
def prepare_train_dataset(dataset, accelerator):
|
||||
try:
|
||||
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
|
||||
except (AttributeError, KeyError):
|
||||
supported_interpolation_modes = [
|
||||
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
|
||||
]
|
||||
raise ValueError(
|
||||
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
|
||||
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
|
||||
)
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
@@ -781,7 +743,7 @@ def prepare_train_dataset(dataset, accelerator):
|
||||
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=interpolation_mode),
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
|
||||
@@ -239,7 +239,6 @@ else:
|
||||
"KarrasVePipeline",
|
||||
"LDMPipeline",
|
||||
"LDMSuperResolutionPipeline",
|
||||
"ModularPipeline",
|
||||
"PNDMPipeline",
|
||||
"RePaintPipeline",
|
||||
"ScoreSdeVePipeline",
|
||||
@@ -494,12 +493,10 @@ else:
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"StableDiffusionXLPAGImg2ImgPipeline",
|
||||
"StableDiffusionXLPAGInpaintPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableDiffusionXLAutoPipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
"StableVideoDiffusionPipeline",
|
||||
@@ -837,7 +834,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
LDMSuperResolutionPipeline,
|
||||
ModularPipeline,
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
@@ -1070,12 +1066,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLModularPipeline,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
StableDiffusionXLPAGInpaintPipeline,
|
||||
StableDiffusionXLPAGPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableDiffusionXLAutoPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
StableVideoDiffusionPipeline,
|
||||
|
||||
@@ -1,745 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
)
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
"""
|
||||
This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
|
||||
"""
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0 and not self._disable_guidance
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
|
||||
# a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead
|
||||
disable_guidance = guider_kwargs.get("disable_guidance", False)
|
||||
guidance_scale = guider_kwargs.get("guidance_scale", None)
|
||||
if guidance_scale is None:
|
||||
raise ValueError("guidance_scale is not provided in guider_kwargs")
|
||||
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
|
||||
batch_size = guider_kwargs.get("batch_size", None)
|
||||
if batch_size is None:
|
||||
raise ValueError("batch_size is not provided in guider_kwargs")
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._batch_size = batch_size
|
||||
self._disable_guidance = disable_guidance
|
||||
|
||||
def reset_guider(self, pipeline):
|
||||
pass
|
||||
|
||||
def maybe_update_guider(self, pipeline, timestep):
|
||||
pass
|
||||
|
||||
def maybe_update_input(self, pipeline, cond_input):
|
||||
pass
|
||||
|
||||
def _maybe_split_prepared_input(self, cond):
|
||||
"""
|
||||
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
|
||||
|
||||
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
|
||||
It determines whether to split the input based on its batch size relative to the expected batch size.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): The conditional input tensor to process.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The negative conditional input (uncond_input)
|
||||
- The positive conditional input (cond_input)
|
||||
"""
|
||||
if cond.shape[0] == self.batch_size * 2:
|
||||
neg_cond = cond[0 : self.batch_size]
|
||||
cond = cond[self.batch_size :]
|
||||
return neg_cond, cond
|
||||
elif cond.shape[0] == self.batch_size:
|
||||
return cond, cond
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape: {cond.shape}")
|
||||
|
||||
def _is_prepared_input(self, cond):
|
||||
"""
|
||||
Check if the input is already prepared for Classifier-Free Guidance (CFG).
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): The conditional input tensor to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the input is already prepared, False otherwise.
|
||||
"""
|
||||
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
|
||||
|
||||
return cond_tensor.shape[0] == self.batch_size * 2
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
cond_input: Union[torch.Tensor, List[torch.Tensor]],
|
||||
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Prepare the input for CFG.
|
||||
|
||||
Args:
|
||||
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
|
||||
The conditional input. It can be a single tensor or a
|
||||
list of tensors. It must have the same length as `negative_cond_input`.
|
||||
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
|
||||
single tensor or a list of tensors. It must have the same length as `cond_input`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
|
||||
"""
|
||||
|
||||
# we check if cond_input already has CFG applied, and split if it is the case.
|
||||
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
|
||||
return cond_input
|
||||
|
||||
if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
|
||||
if isinstance(cond_input, list):
|
||||
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
|
||||
else:
|
||||
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
|
||||
|
||||
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
|
||||
raise ValueError(
|
||||
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
|
||||
)
|
||||
|
||||
if isinstance(cond_input, (list, tuple)):
|
||||
if not self.do_classifier_free_guidance:
|
||||
return cond_input
|
||||
|
||||
if len(negative_cond_input) != len(cond_input):
|
||||
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
|
||||
prepared_input = []
|
||||
for neg_cond, cond in zip(negative_cond_input, cond_input):
|
||||
if neg_cond.shape[0] != cond.shape[0]:
|
||||
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
|
||||
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
|
||||
return prepared_input
|
||||
|
||||
elif isinstance(cond_input, torch.Tensor):
|
||||
if not self.do_classifier_free_guidance:
|
||||
return cond_input
|
||||
else:
|
||||
return torch.cat([negative_cond_input, cond_input], dim=0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(cond_input)}")
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if not self.do_classifier_free_guidance:
|
||||
return model_output
|
||||
|
||||
noise_pred_uncond, noise_pred_text = model_output.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
return noise_pred
|
||||
|
||||
|
||||
class PAGGuider:
|
||||
"""
|
||||
This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pag_applied_layers: Union[str, List[str]],
|
||||
pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
|
||||
PAGCFGIdentitySelfAttnProcessor2_0(),
|
||||
PAGIdentitySelfAttnProcessor2_0(),
|
||||
),
|
||||
):
|
||||
r"""
|
||||
Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
|
||||
|
||||
Args:
|
||||
pag_applied_layers (`str` or `List[str]`):
|
||||
One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
|
||||
PAG is to be applied. A few ways of expected usage are as follows:
|
||||
- Single layers specified as - "blocks.{layer_index}"
|
||||
- Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
|
||||
- Multiple layers as a block name - "mid"
|
||||
- Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
|
||||
pag_attn_processors:
|
||||
(`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
|
||||
PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
|
||||
processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
|
||||
attention processor is for PAG with CFG disabled (unconditional only).
|
||||
"""
|
||||
|
||||
if not isinstance(pag_applied_layers, list):
|
||||
pag_applied_layers = [pag_applied_layers]
|
||||
if pag_attn_processors is not None:
|
||||
if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
|
||||
raise ValueError("Expected a tuple of two attention processors")
|
||||
|
||||
for i in range(len(pag_applied_layers)):
|
||||
if not isinstance(pag_applied_layers[i], str):
|
||||
raise ValueError(
|
||||
f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
|
||||
)
|
||||
|
||||
self.pag_applied_layers = pag_applied_layers
|
||||
self._pag_attn_processors = pag_attn_processors
|
||||
|
||||
def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance):
|
||||
r"""
|
||||
Set the attention processor for the PAG layers.
|
||||
"""
|
||||
pag_attn_processors = self._pag_attn_processors
|
||||
pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]
|
||||
|
||||
def is_self_attn(module: nn.Module) -> bool:
|
||||
r"""
|
||||
Check if the module is self-attention module based on its name.
|
||||
"""
|
||||
return isinstance(module, Attention) and not module.is_cross_attention
|
||||
|
||||
def is_fake_integral_match(layer_id, name):
|
||||
layer_id = layer_id.split(".")[-1]
|
||||
name = name.split(".")[-1]
|
||||
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
|
||||
|
||||
for layer_id in pag_applied_layers:
|
||||
# for each PAG layer input, we find corresponding self-attention layers in the unet model
|
||||
target_modules = []
|
||||
|
||||
for name, module in model.named_modules():
|
||||
# Identify the following simple cases:
|
||||
# (1) Self Attention layer existing
|
||||
# (2) Whether the module name matches pag layer id even partially
|
||||
# (3) Make sure it's not a fake integral match if the layer_id ends with a number
|
||||
# For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
|
||||
if (
|
||||
is_self_attn(module)
|
||||
and re.search(layer_id, name) is not None
|
||||
and not is_fake_integral_match(layer_id, name)
|
||||
):
|
||||
logger.debug(f"Applying PAG to layer: {name}")
|
||||
target_modules.append(module)
|
||||
|
||||
if len(target_modules) == 0:
|
||||
raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")
|
||||
|
||||
for module in target_modules:
|
||||
module.processor = pag_attn_proc
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and not self._disable_guidance
|
||||
|
||||
@property
|
||||
def do_perturbed_attention_guidance(self):
|
||||
return self._pag_scale > 0 and not self._disable_guidance
|
||||
|
||||
@property
|
||||
def do_pag_adaptive_scaling(self):
|
||||
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def pag_scale(self):
|
||||
return self._pag_scale
|
||||
|
||||
@property
|
||||
def pag_adaptive_scale(self):
|
||||
return self._pag_adaptive_scale
|
||||
|
||||
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
|
||||
pag_scale = guider_kwargs.get("pag_scale", 3.0)
|
||||
pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0)
|
||||
|
||||
batch_size = guider_kwargs.get("batch_size", None)
|
||||
if batch_size is None:
|
||||
raise ValueError("batch_size is a required argument for PAGGuider")
|
||||
|
||||
guidance_scale = guider_kwargs.get("guidance_scale", None)
|
||||
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
|
||||
disable_guidance = guider_kwargs.get("disable_guidance", False)
|
||||
|
||||
if guidance_scale is None:
|
||||
raise ValueError("guidance_scale is a required argument for PAGGuider")
|
||||
|
||||
self._pag_scale = pag_scale
|
||||
self._pag_adaptive_scale = pag_adaptive_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
self._disable_guidance = disable_guidance
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._batch_size = batch_size
|
||||
if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None:
|
||||
pipeline.original_attn_proc = pipeline.unet.attn_processors
|
||||
self._set_pag_attn_processor(
|
||||
model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer,
|
||||
pag_applied_layers=self.pag_applied_layers,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
def reset_guider(self, pipeline):
|
||||
if (
|
||||
self.do_perturbed_attention_guidance
|
||||
and hasattr(pipeline, "original_attn_proc")
|
||||
and pipeline.original_attn_proc is not None
|
||||
):
|
||||
pipeline.unet.set_attn_processor(pipeline.original_attn_proc)
|
||||
pipeline.original_attn_proc = None
|
||||
|
||||
def maybe_update_guider(self, pipeline, timestep):
|
||||
pass
|
||||
|
||||
def maybe_update_input(self, pipeline, cond_input):
|
||||
pass
|
||||
|
||||
def _is_prepared_input(self, cond):
|
||||
"""
|
||||
Check if the input is already prepared for Perturbed Attention Guidance (PAG).
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): The conditional input tensor to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the input is already prepared, False otherwise.
|
||||
"""
|
||||
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
|
||||
|
||||
return cond_tensor.shape[0] == self.batch_size * 3
|
||||
|
||||
def _maybe_split_prepared_input(self, cond):
|
||||
"""
|
||||
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
|
||||
|
||||
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
|
||||
It determines whether to split the input based on its batch size relative to the expected batch size.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): The conditional input tensor to process.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The negative conditional input (uncond_input)
|
||||
- The positive conditional input (cond_input)
|
||||
"""
|
||||
if cond.shape[0] == self.batch_size * 3:
|
||||
neg_cond = cond[0 : self.batch_size]
|
||||
cond = cond[self.batch_size : self.batch_size * 2]
|
||||
return neg_cond, cond
|
||||
elif cond.shape[0] == self.batch_size:
|
||||
return cond, cond
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape: {cond.shape}")
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
|
||||
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]:
|
||||
"""
|
||||
Prepare the input for CFG.
|
||||
|
||||
Args:
|
||||
cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
|
||||
The conditional input. It can be a single tensor or a
|
||||
list of tensors. It must have the same length as `negative_cond_input`.
|
||||
negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
|
||||
The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
|
||||
length as `cond_input`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input.
|
||||
"""
|
||||
|
||||
# we check if cond_input already has CFG applied, and split if it is the case.
|
||||
|
||||
if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance:
|
||||
return cond_input
|
||||
|
||||
if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance:
|
||||
if isinstance(cond_input, list):
|
||||
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
|
||||
else:
|
||||
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
|
||||
|
||||
if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None:
|
||||
raise ValueError(
|
||||
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
|
||||
)
|
||||
|
||||
if isinstance(cond_input, (list, tuple)):
|
||||
if not self.do_perturbed_attention_guidance:
|
||||
return cond_input
|
||||
|
||||
if len(negative_cond_input) != len(cond_input):
|
||||
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
|
||||
|
||||
prepared_input = []
|
||||
for neg_cond, cond in zip(negative_cond_input, cond_input):
|
||||
if neg_cond.shape[0] != cond.shape[0]:
|
||||
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
|
||||
|
||||
cond = torch.cat([cond] * 2, dim=0)
|
||||
if self.do_classifier_free_guidance:
|
||||
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
|
||||
else:
|
||||
prepared_input.append(cond)
|
||||
|
||||
return prepared_input
|
||||
|
||||
elif isinstance(cond_input, torch.Tensor):
|
||||
if not self.do_perturbed_attention_guidance:
|
||||
return cond_input
|
||||
|
||||
cond_input = torch.cat([cond_input] * 2, dim=0)
|
||||
if self.do_classifier_free_guidance:
|
||||
return torch.cat([negative_cond_input, cond_input], dim=0)
|
||||
else:
|
||||
return cond_input
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}")
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if not self.do_perturbed_attention_guidance:
|
||||
return model_output
|
||||
|
||||
if self.do_pag_adaptive_scaling:
|
||||
pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0)
|
||||
else:
|
||||
pag_scale = self._pag_scale
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3)
|
||||
noise_pred = (
|
||||
noise_pred_uncond
|
||||
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
+ pag_scale * (noise_pred_text - noise_pred_perturb)
|
||||
)
|
||||
else:
|
||||
noise_pred_text, noise_pred_perturb = model_output.chunk(2)
|
||||
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
|
||||
return noise_pred
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
class APGGuider:
|
||||
"""
|
||||
This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
|
||||
"""
|
||||
|
||||
def normalized_guidance(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: MomentumBuffer = None,
|
||||
norm_threshold: float = 0.0,
|
||||
eta: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
|
||||
Models](https://arxiv.org/pdf/2410.02416)
|
||||
"""
|
||||
diff = pred_cond - pred_uncond
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
||||
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
|
||||
return pred_guided
|
||||
|
||||
@property
|
||||
def adaptive_projected_guidance_momentum(self):
|
||||
return self._adaptive_projected_guidance_momentum
|
||||
|
||||
@property
|
||||
def adaptive_projected_guidance_rescale_factor(self):
|
||||
return self._adaptive_projected_guidance_rescale_factor
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0 and not self._disable_guidance
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
|
||||
disable_guidance = guider_kwargs.get("disable_guidance", False)
|
||||
guidance_scale = guider_kwargs.get("guidance_scale", None)
|
||||
if guidance_scale is None:
|
||||
raise ValueError("guidance_scale is not provided in guider_kwargs")
|
||||
adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
|
||||
adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
|
||||
"adaptive_projected_guidance_rescale_factor", 15.0
|
||||
)
|
||||
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
|
||||
batch_size = guider_kwargs.get("batch_size", None)
|
||||
if batch_size is None:
|
||||
raise ValueError("batch_size is not provided in guider_kwargs")
|
||||
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._batch_size = batch_size
|
||||
self._disable_guidance = disable_guidance
|
||||
if adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
|
||||
else:
|
||||
self.momentum_buffer = None
|
||||
self.scheduler = pipeline.scheduler
|
||||
|
||||
def reset_guider(self, pipeline):
|
||||
pass
|
||||
|
||||
def maybe_update_guider(self, pipeline, timestep):
|
||||
pass
|
||||
|
||||
def maybe_update_input(self, pipeline, cond_input):
|
||||
pass
|
||||
|
||||
def _maybe_split_prepared_input(self, cond):
|
||||
"""
|
||||
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
|
||||
|
||||
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
|
||||
It determines whether to split the input based on its batch size relative to the expected batch size.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): The conditional input tensor to process.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The negative conditional input (uncond_input)
|
||||
- The positive conditional input (cond_input)
|
||||
"""
|
||||
if cond.shape[0] == self.batch_size * 2:
|
||||
neg_cond = cond[0 : self.batch_size]
|
||||
cond = cond[self.batch_size :]
|
||||
return neg_cond, cond
|
||||
elif cond.shape[0] == self.batch_size:
|
||||
return cond, cond
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape: {cond.shape}")
|
||||
|
||||
def _is_prepared_input(self, cond):
|
||||
"""
|
||||
Check if the input is already prepared for Classifier-Free Guidance (CFG).
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor): The conditional input tensor to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the input is already prepared, False otherwise.
|
||||
"""
|
||||
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
|
||||
|
||||
return cond_tensor.shape[0] == self.batch_size * 2
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
cond_input: Union[torch.Tensor, List[torch.Tensor]],
|
||||
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Prepare the input for CFG.
|
||||
|
||||
Args:
|
||||
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
|
||||
The conditional input. It can be a single tensor or a
|
||||
list of tensors. It must have the same length as `negative_cond_input`.
|
||||
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
|
||||
single tensor or a list of tensors. It must have the same length as `cond_input`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
|
||||
"""
|
||||
|
||||
# we check if cond_input already has CFG applied, and split if it is the case.
|
||||
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
|
||||
return cond_input
|
||||
|
||||
if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
|
||||
if isinstance(cond_input, list):
|
||||
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
|
||||
else:
|
||||
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
|
||||
|
||||
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
|
||||
raise ValueError(
|
||||
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
|
||||
)
|
||||
|
||||
if isinstance(cond_input, (list, tuple)):
|
||||
if not self.do_classifier_free_guidance:
|
||||
return cond_input
|
||||
|
||||
if len(negative_cond_input) != len(cond_input):
|
||||
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
|
||||
prepared_input = []
|
||||
for neg_cond, cond in zip(negative_cond_input, cond_input):
|
||||
if neg_cond.shape[0] != cond.shape[0]:
|
||||
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
|
||||
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
|
||||
return prepared_input
|
||||
|
||||
elif isinstance(cond_input, torch.Tensor):
|
||||
if not self.do_classifier_free_guidance:
|
||||
return cond_input
|
||||
else:
|
||||
return torch.cat([negative_cond_input, cond_input], dim=0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(cond_input)}")
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if not self.do_classifier_free_guidance:
|
||||
return model_output
|
||||
|
||||
if latents is None:
|
||||
raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
|
||||
|
||||
sigma = self.scheduler.sigmas[self.scheduler.step_index]
|
||||
noise_pred = latents - sigma * model_output
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = self.normalized_guidance(
|
||||
noise_pred_text,
|
||||
noise_pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.adaptive_projected_guidance_rescale_factor,
|
||||
)
|
||||
noise_pred = (latents - noise_pred) / sigma
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
return noise_pred
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -55,7 +55,7 @@ class ModuleGroup:
|
||||
parameters: Optional[List[torch.nn.Parameter]] = None,
|
||||
buffers: Optional[List[torch.Tensor]] = None,
|
||||
non_blocking: bool = False,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
onload_self: bool = True,
|
||||
@@ -115,13 +115,8 @@ class ModuleGroup:
|
||||
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
|
||||
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
|
||||
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
|
||||
current_stream = torch.cuda.current_stream() if self.record_stream else None
|
||||
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
@@ -167,15 +162,9 @@ class ModuleGroup:
|
||||
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
if self.stream is not None:
|
||||
if not self.record_stream:
|
||||
torch_accelerator_module.current_stream().synchronize()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
@@ -440,10 +429,8 @@ def apply_group_offloading(
|
||||
if use_stream:
|
||||
if torch.cuda.is_available():
|
||||
stream = torch.cuda.Stream()
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
stream = torch.Stream()
|
||||
else:
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device.")
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
@@ -481,7 +468,7 @@ def _apply_group_offloading_block_level(
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
@@ -499,7 +486,7 @@ def _apply_group_offloading_block_level(
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
@@ -512,10 +499,7 @@ def _apply_group_offloading_block_level(
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
if stream is not None and num_blocks_per_group != 1:
|
||||
logger.warning(
|
||||
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
|
||||
)
|
||||
num_blocks_per_group = 1
|
||||
raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -585,7 +569,7 @@ def _apply_group_offloading_leaf_level(
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
) -> None:
|
||||
@@ -605,7 +589,7 @@ def _apply_group_offloading_leaf_level(
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
|
||||
stream (`torch.cuda.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
|
||||
@@ -54,14 +54,14 @@ if is_transformers_available():
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
|
||||
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
|
||||
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
|
||||
_import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
|
||||
_import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"]
|
||||
_import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora_pipeline"] = [
|
||||
_import_structure["single_file.single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora.lora_pipeline"] = [
|
||||
"AmusedLoraLoaderMixin",
|
||||
"StableDiffusionLoraLoaderMixin",
|
||||
"SD3LoraLoaderMixin",
|
||||
@@ -80,11 +80,10 @@ if is_torch_available():
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
_import_structure["ip_adapter.ip_adapter"] = [
|
||||
"IPAdapterMixin",
|
||||
"FluxIPAdapterMixin",
|
||||
"SD3IPAdapterMixin",
|
||||
"ModularIPAdapterMixin",
|
||||
]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
@@ -92,20 +91,14 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file_model import FromOriginalModelMixin
|
||||
from .transformer_flux import FluxTransformer2DLoadersMixin
|
||||
from .transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||
from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin
|
||||
from .single_file import FromOriginalModelMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
if is_transformers_available():
|
||||
from .ip_adapter import (
|
||||
FluxIPAdapterMixin,
|
||||
IPAdapterMixin,
|
||||
SD3IPAdapterMixin,
|
||||
ModularIPAdapterMixin,
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
|
||||
from .lora import (
|
||||
AmusedLoraLoaderMixin,
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
@@ -113,6 +106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraBaseMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
|
||||
+17
-1117
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,9 @@
|
||||
from ...utils.import_utils import is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .transformer_flux import FluxTransformer2DLoadersMixin
|
||||
from .transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||
|
||||
if is_transformers_available():
|
||||
from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
|
||||
@@ -0,0 +1,879 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from safetensors import safe_open
|
||||
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_detailed_type,
|
||||
_get_model_file,
|
||||
_is_valid_type,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..unet.unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
JointAttnProcessor2_0,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class IPAdapterMixin:
|
||||
"""Mixin for handling IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=self.dtype,
|
||||
).to(self.device)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into unet
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||
if extra_loras != {}:
|
||||
if not USE_PEFT_BACKEND:
|
||||
logger.warning("PEFT backend is required to load these weights.")
|
||||
else:
|
||||
# apply the IP Adapter Face ID LoRA weights
|
||||
peft_config = getattr(unet, "peft_config", {})
|
||||
for k, lora in extra_loras.items():
|
||||
if f"faceid_{k}" not in peft_config:
|
||||
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style block only
|
||||
scale = {
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style+layout blocks
|
||||
scale = {
|
||||
"down": {"block_2": [0.0, 1.0]},
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style and layout from 2 reference images
|
||||
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
|
||||
pipeline.set_ip_adapter_scale(scales)
|
||||
```
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale]
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
if isinstance(scale_config, dict):
|
||||
for k, s in scale_config.items():
|
||||
if attn_name.startswith(k):
|
||||
attn_processor.scale[i] = s
|
||||
else:
|
||||
attn_processor.scale[i] = scale_config
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = None
|
||||
|
||||
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
|
||||
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
|
||||
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
|
||||
self.unet.text_encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = "text_proj"
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.unet.attn_processors.items():
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||
for key in f.keys():
|
||||
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||
diffusers_name = ".".join(key.split(".")[1:])
|
||||
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||
diffusers_name = (
|
||||
".".join(key.split(".")[1:])
|
||||
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||
.replace("processor.", "")
|
||||
)
|
||||
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_pretrained_model_name_or_path is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||
image_encoder = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_pretrained_model_name_or_path,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
|
||||
def LinearStrengthModel(start, finish, size):
|
||||
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||
|
||||
|
||||
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
|
||||
scale_type = Union[int, float]
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
if len(scale) != num_ip_adapters:
|
||||
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
||||
|
||||
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
||||
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
||||
raise ValueError(
|
||||
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
||||
)
|
||||
|
||||
# Scalars are transformed to lists with length num_layers
|
||||
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
||||
|
||||
# Set scales. zip over scale_configs prevents going into single transformer layers
|
||||
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.transformer.encoder_hid_proj = None
|
||||
self.transformer.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Transformer attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.transformer.attn_processors.items():
|
||||
attn_processor_class = FluxAttnProcessor2_0()
|
||||
attn_procs[name] = (
|
||||
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||
)
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class SD3IPAdapterMixin:
|
||||
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||
|
||||
@property
|
||||
def is_ip_adapter_active(self) -> bool:
|
||||
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||
|
||||
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||
the image context is irrelevant.
|
||||
|
||||
Returns:
|
||||
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||
"""
|
||||
scales = [
|
||||
attn_proc.scale
|
||||
for attn_proc in self.transformer.attn_processors.values()
|
||||
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||
]
|
||||
|
||||
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# Load the main state dict first
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
# Commons args for loading image encoder and image processor
|
||||
kwargs = {
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# Load IP-Adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||
"""
|
||||
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
|
||||
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
scale (float):
|
||||
IP-Adapter scale to be set.
|
||||
|
||||
"""
|
||||
for attn_processor in self.transformer.attn_processors.values():
|
||||
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self) -> None:
|
||||
"""
|
||||
Unloads the IP Adapter weights.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# Remove image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=None)
|
||||
|
||||
# Remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=None)
|
||||
|
||||
# Remove image projection
|
||||
self.transformer.image_proj = None
|
||||
|
||||
# Restore original attention processors layers
|
||||
attn_procs = {
|
||||
name: (
|
||||
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||
)
|
||||
for name, value in self.transformer.attn_processors.items()
|
||||
}
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ...utils import is_accelerate_available, is_torch_version, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
@@ -0,0 +1,170 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict
|
||||
|
||||
from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ...models.embeddings import IPAdapterTimeImageProjection
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ...utils import is_accelerate_available, is_torch_version, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class SD3Transformer2DLoadersMixin:
|
||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(
|
||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
) -> Dict:
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# IP-Adapter cross attention parameters
|
||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
|
||||
|
||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
||||
for key, weights in state_dict.items():
|
||||
idx, name = key.split(".", maxsplit=1)
|
||||
layer_state_dict[int(idx)][name] = weights
|
||||
|
||||
# Create IP-Adapter attention processor & load state_dict
|
||||
attn_procs = {}
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for idx, name in enumerate(self.attn_processors.keys()):
|
||||
with init_context():
|
||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
||||
head_dim=self.config.attention_head_dim,
|
||||
timesteps_emb_dim=timesteps_emb_dim,
|
||||
)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
||||
)
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(
|
||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
) -> IPAdapterTimeImageProjection:
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
# Convert to diffusers
|
||||
updated_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
# InstantX/SD3.5-Large-IP-Adapter
|
||||
if key.startswith("layers."):
|
||||
idx = key.split(".")[1]
|
||||
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
|
||||
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
|
||||
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
|
||||
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
|
||||
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
|
||||
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
|
||||
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
|
||||
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
|
||||
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
|
||||
updated_state_dict[key] = value
|
||||
|
||||
# Image projetion parameters
|
||||
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
|
||||
output_dim = updated_state_dict["proj_out.weight"].shape[0]
|
||||
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
|
||||
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
num_queries = updated_state_dict["latents"].shape[1]
|
||||
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
# Image projection
|
||||
with init_context():
|
||||
image_proj = IPAdapterTimeImageProjection(
|
||||
embed_dim=embed_dim,
|
||||
output_dim=output_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
timestep_in_dim=timestep_in_dim,
|
||||
)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_proj.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_proj
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict`):
|
||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
||||
"image_proj", which contains parameters for image projection net.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
|
||||
@@ -0,0 +1,25 @@
|
||||
from ...utils import is_peft_available, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .lora_base import LoraBaseMixin
|
||||
|
||||
if is_transformers_available():
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
SanaLoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
WanLoraLoaderMixin,
|
||||
)
|
||||
@@ -0,0 +1,935 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
|
||||
from ...models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
"""
|
||||
Fuses LoRAs for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
"""
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
"""
|
||||
Unfuses LoRAs for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||
)
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||
|
||||
|
||||
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Disables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=False)
|
||||
|
||||
|
||||
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Enables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=True)
|
||||
|
||||
|
||||
def _remove_text_encoder_monkey_patch(text_encoder):
|
||||
recurse_remove_peft_layers(text_encoder)
|
||||
if getattr(text_encoder, "peft_config", None) is not None:
|
||||
del text_encoder.peft_config
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
def _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
text_encoder_name="text_encoder",
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
_lora_loadable_modules = []
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(self, **kwargs):
|
||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(cls, **kwargs):
|
||||
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
||||
|
||||
@classmethod
|
||||
def lora_state_dict(cls, **kwargs):
|
||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||
return _fetch_state_dict(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||
>>> pipeline.unload_lora_weights()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.unload_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
_remove_text_encoder_monkey_patch(model)
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = [],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
if "fuse_unet" in kwargs:
|
||||
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_transformer" in kwargs:
|
||||
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
# check if diffusers model
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
# handle transformers models.
|
||||
if issubclass(model.__class__, PreTrainedModel):
|
||||
fuse_text_encoder_lora(
|
||||
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
self.num_fused_loras += 1
|
||||
|
||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if "unfuse_unet" in kwargs:
|
||||
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_transformer" in kwargs:
|
||||
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
):
|
||||
if isinstance(adapter_weights, dict):
|
||||
components_passed = set(adapter_weights.keys())
|
||||
lora_components = set(self._lora_loadable_modules)
|
||||
|
||||
invalid_components = sorted(components_passed - lora_components)
|
||||
if invalid_components:
|
||||
logger.warning(
|
||||
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
|
||||
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
|
||||
"to the invalid components will be removed and ignored."
|
||||
)
|
||||
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
adapter_weights = copy.deepcopy(adapter_weights)
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
if not isinstance(adapter_weights, list):
|
||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(adapter_weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
}
|
||||
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component)
|
||||
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
component_adapter_weights = weights.pop(component, None)
|
||||
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
|
||||
logger.warning(
|
||||
(
|
||||
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
|
||||
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
|
||||
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
component_adapter_weights = weights
|
||||
|
||||
_component_adapter_weights.setdefault(component, [])
|
||||
_component_adapter_weights[component].append(component_adapter_weights)
|
||||
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.set_adapters(adapter_names, _component_adapter_weights[component])
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.disable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
disable_lora_for_text_encoder(model)
|
||||
|
||||
def enable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.enable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
enable_lora_for_text_encoder(model)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Args:
|
||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.delete_adapters(adapter_names)
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(model, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipeline.get_active_adapters()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None and issubclass(model.__class__, ModelMixin):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
set_adapters = {}
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if (
|
||||
model is not None
|
||||
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
|
||||
and hasattr(model, "peft_config")
|
||||
):
|
||||
set_adapters[component] = list(model.peft_config.keys())
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
module.lora_A[adapter_name].to(device)
|
||||
module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
||||
if adapter_name in module.lora_magnitude_vector:
|
||||
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
].to(device)
|
||||
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
|
||||
+1
-1
@@ -17,7 +17,7 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import is_peft_version, logging, state_dict_all_zero
|
||||
from ...utils import is_peft_version, logging, state_dict_all_zero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,927 +12,66 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
|
||||
from ..models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
from ..utils import deprecate
|
||||
from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401
|
||||
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
"""
|
||||
Fuses LoRAs for the text encoder.
|
||||
from .lora.lora_base import fuse_text_encoder_lora
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
"""
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead."
|
||||
deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message)
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
return fuse_text_encoder_lora(
|
||||
text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
"""
|
||||
Unfuses LoRAs for the text encoder.
|
||||
from .lora.lora_base import unfuse_text_encoder_lora
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead."
|
||||
deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message)
|
||||
|
||||
return unfuse_text_encoder_lora(text_encoder)
|
||||
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
adapter_names,
|
||||
text_encoder=None,
|
||||
text_encoder_weights=None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
from .lora.lora_base import set_adapters_for_text_encoder
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||
)
|
||||
deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead."
|
||||
deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message)
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||
|
||||
|
||||
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Disables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=False)
|
||||
|
||||
|
||||
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Enables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=True)
|
||||
|
||||
|
||||
def _remove_text_encoder_monkey_patch(text_encoder):
|
||||
recurse_remove_peft_layers(text_encoder)
|
||||
if getattr(text_encoder, "peft_config", None) is not None:
|
||||
del text_encoder.peft_config
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
def _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
return set_adapters_for_text_encoder(
|
||||
adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
def disable_lora_for_text_encoder(text_encoder=None):
|
||||
from .lora.lora_base import disable_lora_for_text_encoder
|
||||
|
||||
deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead."
|
||||
deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message)
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
text_encoder_name="text_encoder",
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
return disable_lora_for_text_encoder(text_encoder=text_encoder)
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
def enable_lora_for_text_encoder(text_encoder=None):
|
||||
from .lora.lora_base import enable_lora_for_text_encoder
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead."
|
||||
deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message)
|
||||
|
||||
# Safe prefix to check with.
|
||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||
return enable_lora_for_text_encoder(text_encoder=text_encoder)
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
_lora_loadable_modules = []
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(self, **kwargs):
|
||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(cls, **kwargs):
|
||||
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
||||
|
||||
@classmethod
|
||||
def lora_state_dict(cls, **kwargs):
|
||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||
return _fetch_state_dict(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||
>>> pipeline.unload_lora_weights()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.unload_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
_remove_text_encoder_monkey_patch(model)
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = [],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
if "fuse_unet" in kwargs:
|
||||
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_transformer" in kwargs:
|
||||
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
# check if diffusers model
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
# handle transformers models.
|
||||
if issubclass(model.__class__, PreTrainedModel):
|
||||
fuse_text_encoder_lora(
|
||||
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
self.num_fused_loras += 1
|
||||
|
||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if "unfuse_unet" in kwargs:
|
||||
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_transformer" in kwargs:
|
||||
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
):
|
||||
if isinstance(adapter_weights, dict):
|
||||
components_passed = set(adapter_weights.keys())
|
||||
lora_components = set(self._lora_loadable_modules)
|
||||
|
||||
invalid_components = sorted(components_passed - lora_components)
|
||||
if invalid_components:
|
||||
logger.warning(
|
||||
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
|
||||
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
|
||||
"to the invalid components will be removed and ignored."
|
||||
)
|
||||
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
adapter_weights = copy.deepcopy(adapter_weights)
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
if not isinstance(adapter_weights, list):
|
||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(adapter_weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
}
|
||||
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is None:
|
||||
logger.warning(f"Model {component} not found in pipeline.")
|
||||
continue
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
component_adapter_weights = weights.pop(component, None)
|
||||
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
|
||||
logger.warning(
|
||||
(
|
||||
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
|
||||
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
|
||||
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
component_adapter_weights = weights
|
||||
|
||||
_component_adapter_weights.setdefault(component, [])
|
||||
_component_adapter_weights[component].append(component_adapter_weights)
|
||||
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.set_adapters(adapter_names, _component_adapter_weights[component])
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.disable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
disable_lora_for_text_encoder(model)
|
||||
|
||||
def enable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.enable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
enable_lora_for_text_encoder(model)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Args:
|
||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.delete_adapters(adapter_names)
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(model, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipeline.get_active_adapters()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None and issubclass(model.__class__, ModelMixin):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
set_adapters = {}
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if (
|
||||
model is not None
|
||||
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
|
||||
and hasattr(model, "peft_config")
|
||||
):
|
||||
set_adapters[component] = list(model.peft_config.keys())
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
module.lora_A[adapter_name].to(device)
|
||||
module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
||||
if adapter_name in module.lora_magnitude_vector:
|
||||
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
].to(device)
|
||||
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
|
||||
class LoraBaseMixin(LoraBaseMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead."
|
||||
deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -35,8 +35,8 @@ from ..utils import (
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet.unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -99,7 +99,7 @@ class PeftAdapterMixin:
|
||||
_prepare_lora_hotswap_kwargs: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
# Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
@@ -11,42 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
_is_legacy_scheduler_kwargs,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
create_diffusers_t5_model_from_checkpoint,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
is_t5_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from ..utils import deprecate
|
||||
from .single_file.single_file import FromSingleFileMixin
|
||||
|
||||
|
||||
def load_single_file_sub_model(
|
||||
@@ -64,502 +30,30 @@ def load_single_file_sub_model(
|
||||
disable_mmap=False,
|
||||
**kwargs,
|
||||
):
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
from .single_file.single_file import load_single_file_sub_model
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead."
|
||||
deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
return load_single_file_sub_model(
|
||||
library_name,
|
||||
class_name,
|
||||
name,
|
||||
checkpoint,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
cached_model_config_path,
|
||||
original_config,
|
||||
local_files_only,
|
||||
torch_dtype,
|
||||
is_legacy_loading,
|
||||
disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
||||
|
||||
if is_diffusers_single_file_model:
|
||||
load_method = getattr(class_obj, "from_single_file")
|
||||
|
||||
# We cannot provide two different config options to the `from_single_file` method
|
||||
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
||||
if original_config:
|
||||
cached_model_config_path = None
|
||||
|
||||
loaded_sub_model = load_method(
|
||||
pretrained_model_link_or_path_or_dict=checkpoint,
|
||||
original_config=original_config,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
||||
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if not hasattr(class_obj, "from_pretrained"):
|
||||
raise ValueError(
|
||||
(
|
||||
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
||||
" a supported loading method."
|
||||
)
|
||||
)
|
||||
|
||||
loading_kwargs = {}
|
||||
loading_kwargs.update(
|
||||
{
|
||||
"pretrained_model_name_or_path": cached_model_config_path,
|
||||
"subfolder": name,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
)
|
||||
|
||||
# Schedulers and Tokenizers don't make use of torch_dtype
|
||||
# Skip passing it to those objects
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs.update({"torch_dtype": torch_dtype})
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, "from_pretrained")
|
||||
loaded_sub_model = load_method(**loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _map_component_types_to_config_dict(component_types):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
config_dict = {}
|
||||
component_types.pop("self", None)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
for component_name, component_value in component_types.items():
|
||||
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
||||
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
||||
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_transformers_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif is_scheduler_enum or is_scheduler:
|
||||
if is_scheduler_enum:
|
||||
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
||||
# if the type hint is a KarrassDiffusionSchedulers enum
|
||||
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
||||
|
||||
elif is_scheduler:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif (
|
||||
is_transformers_model or is_transformers_tokenizer
|
||||
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
||||
|
||||
else:
|
||||
config_dict[component_name] = [None, None]
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _infer_pipeline_config_dict(pipeline_class):
|
||||
parameters = inspect.signature(pipeline_class.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
component_types = pipeline_class._get_signature_types()
|
||||
|
||||
# Ignore parameters that are not required for the pipeline
|
||||
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
||||
config_dict = _map_component_types_to_config_dict(component_types)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _download_diffusers_model_config_from_hub(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir,
|
||||
revision,
|
||||
proxies,
|
||||
force_download=None,
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
||||
cached_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
|
||||
return cached_model_path
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
config (`str`, *optional*):
|
||||
Can be either:
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||
component configs in Diffusers format.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if original_config_file is not None:
|
||||
deprecation_message = (
|
||||
"`original_config_file` argument is deprecated and will be removed in future versions."
|
||||
"please use the `original_config` argument instead."
|
||||
)
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
if scaling_factor is not None:
|
||||
deprecation_message = (
|
||||
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
|
||||
"and will be ignored in future versions."
|
||||
)
|
||||
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
||||
|
||||
if original_config is not None:
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(cls, config=None)
|
||||
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
else:
|
||||
default_pretrained_model_config_name = config
|
||||
|
||||
if not os.path.isdir(default_pretrained_model_config_name):
|
||||
# Provided config is a repo_id
|
||||
if default_pretrained_model_config_name.count("/") > 1:
|
||||
raise ValueError(
|
||||
f'The provided config "{config}"'
|
||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||
)
|
||||
try:
|
||||
# Attempt to download the config files for the pipeline
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
except LocalEntryNotFoundError:
|
||||
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
||||
# If `original_config` is not provided, we need override `local_files_only` to False
|
||||
# to fetch the config files from the hub so that we have a way
|
||||
# to configure the pipeline components.
|
||||
|
||||
if original_config is None:
|
||||
logger.warning(
|
||||
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
||||
"Attempting to download the necessary config files for this pipeline.\n"
|
||||
)
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
else:
|
||||
# For backwards compatibility
|
||||
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
||||
logger.warning(
|
||||
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
||||
"This may lead to errors if the model components are not correctly inferred. \n"
|
||||
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
||||
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
||||
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
||||
"the necessary config files.\n"
|
||||
)
|
||||
is_legacy_loading = True
|
||||
cached_model_config_path = None
|
||||
|
||||
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
||||
config_dict["_class_name"] = pipeline_class.__name__
|
||||
|
||||
else:
|
||||
# Provided config is a path to a local directory attempt to load directly.
|
||||
cached_model_config_path = default_pretrained_model_config_name
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
from diffusers import pipelines
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for name, (library_name, class_name) in logging.tqdm(
|
||||
sorted(init_dict.items()), desc="Loading pipeline components..."
|
||||
):
|
||||
loaded_sub_model = None
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
|
||||
if name in passed_class_obj:
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
try:
|
||||
loaded_sub_model = load_single_file_sub_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
name=name,
|
||||
checkpoint=checkpoint,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
cached_model_config_path=cached_model_config_path,
|
||||
pipelines=pipelines,
|
||||
torch_dtype=torch_dtype,
|
||||
original_config=original_config,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
except SingleFileComponentError as e:
|
||||
raise SingleFileComponentError(
|
||||
(
|
||||
f"{e.message}\n"
|
||||
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
||||
f"\n"
|
||||
f"{name} = {class_name}.from_pretrained('...')\n"
|
||||
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
||||
f"\n"
|
||||
)
|
||||
)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# deprecated kwargs
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
||||
if load_safety_checker is not None:
|
||||
deprecation_message = (
|
||||
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
||||
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
||||
)
|
||||
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
||||
|
||||
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
||||
init_kwargs.update(safety_checker_components)
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
return pipe
|
||||
class FromSingleFileMixin(FromSingleFileMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead."
|
||||
deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from ...utils import is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .single_file_model import FromOriginalModelMixin
|
||||
|
||||
if is_transformers_available():
|
||||
from .single_file import FromSingleFileMixin
|
||||
@@ -0,0 +1,565 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from typing_extensions import Self
|
||||
|
||||
from ...utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
_is_legacy_scheduler_kwargs,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
create_diffusers_t5_model_from_checkpoint,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
is_t5_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
def load_single_file_sub_model(
|
||||
library_name,
|
||||
class_name,
|
||||
name,
|
||||
checkpoint,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
cached_model_config_path,
|
||||
original_config=None,
|
||||
local_files_only=False,
|
||||
torch_dtype=None,
|
||||
is_legacy_loading=False,
|
||||
disable_mmap=False,
|
||||
**kwargs,
|
||||
):
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
||||
|
||||
if is_diffusers_single_file_model:
|
||||
load_method = getattr(class_obj, "from_single_file")
|
||||
|
||||
# We cannot provide two different config options to the `from_single_file` method
|
||||
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
||||
if original_config:
|
||||
cached_model_config_path = None
|
||||
|
||||
loaded_sub_model = load_method(
|
||||
pretrained_model_link_or_path_or_dict=checkpoint,
|
||||
original_config=original_config,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
||||
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if not hasattr(class_obj, "from_pretrained"):
|
||||
raise ValueError(
|
||||
(
|
||||
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
||||
" a supported loading method."
|
||||
)
|
||||
)
|
||||
|
||||
loading_kwargs = {}
|
||||
loading_kwargs.update(
|
||||
{
|
||||
"pretrained_model_name_or_path": cached_model_config_path,
|
||||
"subfolder": name,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
)
|
||||
|
||||
# Schedulers and Tokenizers don't make use of torch_dtype
|
||||
# Skip passing it to those objects
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs.update({"torch_dtype": torch_dtype})
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, "from_pretrained")
|
||||
loaded_sub_model = load_method(**loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _map_component_types_to_config_dict(component_types):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
config_dict = {}
|
||||
component_types.pop("self", None)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
for component_name, component_value in component_types.items():
|
||||
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
||||
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
||||
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_transformers_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif is_scheduler_enum or is_scheduler:
|
||||
if is_scheduler_enum:
|
||||
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
||||
# if the type hint is a KarrassDiffusionSchedulers enum
|
||||
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
||||
|
||||
elif is_scheduler:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif (
|
||||
is_transformers_model or is_transformers_tokenizer
|
||||
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
||||
|
||||
else:
|
||||
config_dict[component_name] = [None, None]
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _infer_pipeline_config_dict(pipeline_class):
|
||||
parameters = inspect.signature(pipeline_class.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
component_types = pipeline_class._get_signature_types()
|
||||
|
||||
# Ignore parameters that are not required for the pipeline
|
||||
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
||||
config_dict = _map_component_types_to_config_dict(component_types)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _download_diffusers_model_config_from_hub(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir,
|
||||
revision,
|
||||
proxies,
|
||||
force_download=None,
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
||||
cached_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
|
||||
return cached_model_path
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
config (`str`, *optional*):
|
||||
Can be either:
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||
component configs in Diffusers format.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if original_config_file is not None:
|
||||
deprecation_message = (
|
||||
"`original_config_file` argument is deprecated and will be removed in future versions."
|
||||
"please use the `original_config` argument instead."
|
||||
)
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
if scaling_factor is not None:
|
||||
deprecation_message = (
|
||||
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
|
||||
"and will be ignored in future versions."
|
||||
)
|
||||
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
||||
|
||||
if original_config is not None:
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(cls, config=None)
|
||||
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
else:
|
||||
default_pretrained_model_config_name = config
|
||||
|
||||
if not os.path.isdir(default_pretrained_model_config_name):
|
||||
# Provided config is a repo_id
|
||||
if default_pretrained_model_config_name.count("/") > 1:
|
||||
raise ValueError(
|
||||
f'The provided config "{config}"'
|
||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||
)
|
||||
try:
|
||||
# Attempt to download the config files for the pipeline
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
except LocalEntryNotFoundError:
|
||||
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
||||
# If `original_config` is not provided, we need override `local_files_only` to False
|
||||
# to fetch the config files from the hub so that we have a way
|
||||
# to configure the pipeline components.
|
||||
|
||||
if original_config is None:
|
||||
logger.warning(
|
||||
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
||||
"Attempting to download the necessary config files for this pipeline.\n"
|
||||
)
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
else:
|
||||
# For backwards compatibility
|
||||
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
||||
logger.warning(
|
||||
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
||||
"This may lead to errors if the model components are not correctly inferred. \n"
|
||||
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
||||
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
||||
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
||||
"the necessary config files.\n"
|
||||
)
|
||||
is_legacy_loading = True
|
||||
cached_model_config_path = None
|
||||
|
||||
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
||||
config_dict["_class_name"] = pipeline_class.__name__
|
||||
|
||||
else:
|
||||
# Provided config is a path to a local directory attempt to load directly.
|
||||
cached_model_config_path = default_pretrained_model_config_name
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
from diffusers import pipelines
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for name, (library_name, class_name) in logging.tqdm(
|
||||
sorted(init_dict.items()), desc="Loading pipeline components..."
|
||||
):
|
||||
loaded_sub_model = None
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
|
||||
if name in passed_class_obj:
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
try:
|
||||
loaded_sub_model = load_single_file_sub_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
name=name,
|
||||
checkpoint=checkpoint,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
cached_model_config_path=cached_model_config_path,
|
||||
pipelines=pipelines,
|
||||
torch_dtype=torch_dtype,
|
||||
original_config=original_config,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
except SingleFileComponentError as e:
|
||||
raise SingleFileComponentError(
|
||||
(
|
||||
f"{e.message}\n"
|
||||
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
||||
f"\n"
|
||||
f"{name} = {class_name}.from_pretrained('...')\n"
|
||||
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
||||
f"\n"
|
||||
)
|
||||
)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# deprecated kwargs
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
||||
if load_safety_checker is not None:
|
||||
deprecation_message = (
|
||||
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
||||
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
||||
)
|
||||
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
||||
|
||||
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
||||
init_kwargs.update(safety_checker_components)
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
return pipe
|
||||
@@ -0,0 +1,440 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from ... import __version__
|
||||
from ...quantizers import DiffusersAutoQuantizer
|
||||
from ...utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, init_empty_weights
|
||||
|
||||
from ...models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"StableCascadeUNet": {
|
||||
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
},
|
||||
"UNet2DConditionModel": {
|
||||
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
||||
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
||||
"default_subfolder": "unet",
|
||||
"legacy_kwargs": {
|
||||
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
||||
},
|
||||
},
|
||||
"AutoencoderKL": {
|
||||
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
||||
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"ControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||
},
|
||||
"SD3Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"SparseControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"FluxTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"LTXVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLLTXVideo": {
|
||||
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AuraFlowTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Lumina2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||
|
||||
if issubclass(cls, loadable_class):
|
||||
return loadable_class_str
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
mapping_kwargs = {}
|
||||
for parameter in parameters:
|
||||
if parameter in kwargs:
|
||||
mapping_kwargs[parameter] = kwargs[parameter]
|
||||
|
||||
return mapping_kwargs
|
||||
|
||||
|
||||
class FromOriginalModelMixin:
|
||||
"""
|
||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.safetensors` or `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
||||
- A path to a local *file* containing the weights of the component model.
|
||||
- A state dict containing the component model weights.
|
||||
config (`str`, *optional*):
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
||||
on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
||||
configs in Diffusers format.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
original_config (`str`, *optional*):
|
||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||
If a dict is provided, it will be used to initialize the model configuration.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableCascadeUNet
|
||||
|
||||
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
|
||||
```
|
||||
"""
|
||||
|
||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
if mapping_class_name is None:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
|
||||
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
||||
if pretrained_model_link_or_path is not None:
|
||||
deprecation_message = (
|
||||
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
||||
)
|
||||
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
||||
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if config is not None and original_config is not None:
|
||||
raise ValueError(
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
if quantization_config is not None:
|
||||
user_agent["quant"] = quantization_config.quant_method.value
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path_or_dict,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config is not None:
|
||||
if "config_mapping_fn" in mapping_functions:
|
||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||
else:
|
||||
config_mapping_fn = None
|
||||
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(original_config, str):
|
||||
# If original_config is a URL or filepath fetch the original_config dict
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
||||
diffusers_model_config = config_mapping_fn(
|
||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
if config is not None:
|
||||
if isinstance(config, str):
|
||||
default_pretrained_model_config_name = config
|
||||
else:
|
||||
raise ValueError(
|
||||
(
|
||||
"Invalid `config` argument. Please provide a string representing a repo id"
|
||||
"or path to a local Diffusers model repo."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
|
||||
if "default_subfolder" in mapping_functions:
|
||||
subfolder = mapping_functions["default_subfolder"]
|
||||
|
||||
subfolder = subfolder or config.pop(
|
||||
"subfolder", None
|
||||
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
||||
|
||||
diffusers_model_config = cls.load_config(
|
||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=config_revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
# Map legacy kwargs to new kwargs
|
||||
if "legacy_kwargs" in mapping_functions:
|
||||
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
||||
for legacy_key, new_key in legacy_kwargs.items():
|
||||
if legacy_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(legacy_key)
|
||||
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
device_map = None
|
||||
if is_accelerate_available():
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
empty_state_dict = model.state_dict()
|
||||
unexpected_keys = [
|
||||
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
||||
]
|
||||
device_map = {"": param_device}
|
||||
load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
if device_map is not None:
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
return model
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,430 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
load_single_file_checkpoint,
|
||||
from ..utils import deprecate
|
||||
from .single_file.single_file_model import (
|
||||
SINGLE_FILE_LOADABLE_CLASSES, # noqa: F401
|
||||
FromOriginalModelMixin,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, init_empty_weights
|
||||
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"StableCascadeUNet": {
|
||||
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
},
|
||||
"UNet2DConditionModel": {
|
||||
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
||||
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
||||
"default_subfolder": "unet",
|
||||
"legacy_kwargs": {
|
||||
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
||||
},
|
||||
},
|
||||
"AutoencoderKL": {
|
||||
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
||||
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"ControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||
},
|
||||
"SD3Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"SparseControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"FluxTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"LTXVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLLTXVideo": {
|
||||
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AuraFlowTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Lumina2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||
|
||||
if issubclass(cls, loadable_class):
|
||||
return loadable_class_str
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
mapping_kwargs = {}
|
||||
for parameter in parameters:
|
||||
if parameter in kwargs:
|
||||
mapping_kwargs[parameter] = kwargs[parameter]
|
||||
|
||||
return mapping_kwargs
|
||||
|
||||
|
||||
class FromOriginalModelMixin:
|
||||
"""
|
||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.safetensors` or `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
||||
- A path to a local *file* containing the weights of the component model.
|
||||
- A state dict containing the component model weights.
|
||||
config (`str`, *optional*):
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
||||
on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
||||
configs in Diffusers format.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
original_config (`str`, *optional*):
|
||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||
If a dict is provided, it will be used to initialize the model configuration.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableCascadeUNet
|
||||
|
||||
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
|
||||
```
|
||||
"""
|
||||
|
||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
if mapping_class_name is None:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
|
||||
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
||||
if pretrained_model_link_or_path is not None:
|
||||
deprecation_message = (
|
||||
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
||||
)
|
||||
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
||||
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if config is not None and original_config is not None:
|
||||
raise ValueError(
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
if quantization_config is not None:
|
||||
user_agent["quant"] = quantization_config.quant_method.value
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path_or_dict,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config is not None:
|
||||
if "config_mapping_fn" in mapping_functions:
|
||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||
else:
|
||||
config_mapping_fn = None
|
||||
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(original_config, str):
|
||||
# If original_config is a URL or filepath fetch the original_config dict
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
||||
diffusers_model_config = config_mapping_fn(
|
||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
if config is not None:
|
||||
if isinstance(config, str):
|
||||
default_pretrained_model_config_name = config
|
||||
else:
|
||||
raise ValueError(
|
||||
(
|
||||
"Invalid `config` argument. Please provide a string representing a repo id"
|
||||
"or path to a local Diffusers model repo."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
|
||||
if "default_subfolder" in mapping_functions:
|
||||
subfolder = mapping_functions["default_subfolder"]
|
||||
|
||||
subfolder = subfolder or config.pop(
|
||||
"subfolder", None
|
||||
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
||||
|
||||
diffusers_model_config = cls.load_config(
|
||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=config_revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
# Map legacy kwargs to new kwargs
|
||||
if "legacy_kwargs" in mapping_functions:
|
||||
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
||||
for legacy_key, new_key in legacy_kwargs.items():
|
||||
if legacy_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(legacy_key)
|
||||
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
device_map = None
|
||||
if is_accelerate_available():
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
empty_state_dict = model.state_dict()
|
||||
unexpected_keys = [
|
||||
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
||||
]
|
||||
device_map = {"": param_device}
|
||||
load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
if device_map is not None:
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
return model
|
||||
class FromOriginalModelMixin(FromOriginalModelMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead."
|
||||
deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,170 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils import deprecate
|
||||
from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ..models.attention_processor import (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead."
|
||||
deprecate("diffusers.loaders.ip_adapter.FluxTransformer2DLoadersMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -11,160 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict
|
||||
|
||||
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
||||
from ..utils import deprecate
|
||||
from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class SD3Transformer2DLoadersMixin:
|
||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(
|
||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
) -> Dict:
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# IP-Adapter cross attention parameters
|
||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
|
||||
|
||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
||||
for key, weights in state_dict.items():
|
||||
idx, name = key.split(".", maxsplit=1)
|
||||
layer_state_dict[int(idx)][name] = weights
|
||||
|
||||
# Create IP-Adapter attention processor & load state_dict
|
||||
attn_procs = {}
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for idx, name in enumerate(self.attn_processors.keys()):
|
||||
with init_context():
|
||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
||||
head_dim=self.config.attention_head_dim,
|
||||
timesteps_emb_dim=timesteps_emb_dim,
|
||||
)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
||||
)
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(
|
||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
) -> IPAdapterTimeImageProjection:
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
# Convert to diffusers
|
||||
updated_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
# InstantX/SD3.5-Large-IP-Adapter
|
||||
if key.startswith("layers."):
|
||||
idx = key.split(".")[1]
|
||||
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
|
||||
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
|
||||
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
|
||||
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
|
||||
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
|
||||
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
|
||||
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
|
||||
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
|
||||
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
|
||||
updated_state_dict[key] = value
|
||||
|
||||
# Image projetion parameters
|
||||
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
|
||||
output_dim = updated_state_dict["proj_out.weight"].shape[0]
|
||||
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
|
||||
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
num_queries = updated_state_dict["latents"].shape[1]
|
||||
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
# Image projection
|
||||
with init_context():
|
||||
image_proj = IPAdapterTimeImageProjection(
|
||||
embed_dim=embed_dim,
|
||||
output_dim=output_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
timestep_in_dim=timestep_in_dim,
|
||||
)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_proj.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_proj
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict`):
|
||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
||||
"image_proj", which contains parameters for image projection net.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
|
||||
class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead."
|
||||
deprecate("diffusers.loaders.ip_adapter.SD3Transformer2DLoadersMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..models.embeddings import (
|
||||
from ...models.embeddings import (
|
||||
ImageProjection,
|
||||
IPAdapterFaceIDImageProjection,
|
||||
IPAdapterFaceIDPlusImageProjection,
|
||||
@@ -30,8 +30,8 @@ from ..models.embeddings import (
|
||||
IPAdapterPlusImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
||||
from ..utils import (
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_unet_state_dict_to_peft,
|
||||
@@ -43,9 +43,9 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from .lora_base import _func_optionally_disable_offloading
|
||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .utils import AttnProcsLayers
|
||||
from ..lora.lora_base import _func_optionally_disable_offloading
|
||||
from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from ..utils import AttnProcsLayers
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -247,7 +247,7 @@ class UNet2DConditionLoadersMixin:
|
||||
# Unsafe code />
|
||||
|
||||
def _process_custom_diffusion(self, state_dict):
|
||||
from ..models.attention_processor import CustomDiffusionAttnProcessor
|
||||
from ...models.attention_processor import CustomDiffusionAttnProcessor
|
||||
|
||||
attn_processors = {}
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
@@ -395,7 +395,7 @@ class UNet2DConditionLoadersMixin:
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
# Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
@@ -408,7 +408,6 @@ class UNet2DConditionLoadersMixin:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def save_attn_procs(
|
||||
@@ -452,7 +451,7 @@ class UNet2DConditionLoadersMixin:
|
||||
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
```
|
||||
"""
|
||||
from ..models.attention_processor import (
|
||||
from ...models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
@@ -514,7 +513,7 @@ class UNet2DConditionLoadersMixin:
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
def _get_custom_diffusion_state_dict(self):
|
||||
from ..models.attention_processor import (
|
||||
from ...models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
@@ -760,7 +759,7 @@ class UNet2DConditionLoadersMixin:
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ..models.attention_processor import (
|
||||
from ...models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
+2
-2
@@ -14,12 +14,12 @@
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
|
||||
from ..utils import logging
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# import here to avoid circular imports
|
||||
from ..models import UNet2DConditionModel
|
||||
from ...models import UNet2DConditionModel
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -17,8 +17,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
|
||||
@@ -17,7 +17,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
|
||||
@@ -1068,15 +1068,17 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
latent_sequence_length = hidden_states.shape[1]
|
||||
condition_sequence_length = encoder_hidden_states.shape[1]
|
||||
sequence_length = latent_sequence_length + condition_sequence_length
|
||||
attention_mask = torch.ones(
|
||||
attention_mask = torch.zeros(
|
||||
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N]
|
||||
|
||||
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
||||
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
||||
indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
|
||||
mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
|
||||
attention_mask = attention_mask.masked_fill(mask_indices, False)
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
|
||||
|
||||
for i in range(batch_size):
|
||||
attention_mask[i, : effective_sequence_length[i]] = True
|
||||
# [B, 1, 1, N], for broadcasting across attention heads
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -20,8 +20,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import LuminaFeedForward
|
||||
from ..attention_processor import Attention
|
||||
|
||||
@@ -19,8 +19,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
|
||||
@@ -19,8 +19,7 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import (
|
||||
|
||||
@@ -47,7 +47,6 @@ else:
|
||||
"AutoPipelineForInpainting",
|
||||
"AutoPipelineForText2Image",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["ModularPipeline"]
|
||||
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
|
||||
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
|
||||
_import_structure["ddim"] = ["DDIMPipeline"]
|
||||
@@ -330,8 +329,6 @@ else:
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"StableDiffusionXLAutoPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
@@ -481,7 +478,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .modular_pipeline import ModularPipeline
|
||||
from .pipeline_utils import (
|
||||
AudioPipelineOutput,
|
||||
DiffusionPipeline,
|
||||
@@ -706,9 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLModularPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableDiffusionXLAutoPipeline,
|
||||
)
|
||||
from .stable_video_diffusion import StableVideoDiffusionPipeline
|
||||
from .t2i_adapter import (
|
||||
|
||||
@@ -246,15 +246,14 @@ def _get_connected_pipeline(pipeline_cls):
|
||||
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
|
||||
|
||||
|
||||
def _get_model(pipeline_class_name):
|
||||
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
|
||||
for model_name, pipeline in task_mapping.items():
|
||||
if pipeline.__name__ == pipeline_class_name:
|
||||
return model_name
|
||||
|
||||
|
||||
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
|
||||
model_name = _get_model(pipeline_class_name)
|
||||
def get_model(pipeline_class_name):
|
||||
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
|
||||
for model_name, pipeline in task_mapping.items():
|
||||
if pipeline.__name__ == pipeline_class_name:
|
||||
return model_name
|
||||
|
||||
model_name = get_model(pipeline_class_name)
|
||||
|
||||
if model_name is not None:
|
||||
task_class = mapping.get(model_name, None)
|
||||
|
||||
@@ -1,609 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
from itertools import combinations
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
)
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
|
||||
from accelerate.state import PartialState
|
||||
from accelerate.utils import send_to_device
|
||||
from accelerate.utils.memory import clear_device_cache
|
||||
from accelerate.utils.modeling import convert_file_size_to_int
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
|
||||
def get_memory_footprint(self, return_buffers=True):
|
||||
r"""
|
||||
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to
|
||||
benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch
|
||||
discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
|
||||
|
||||
Arguments:
|
||||
return_buffers (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are
|
||||
tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm
|
||||
layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
|
||||
"""
|
||||
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
|
||||
if return_buffers:
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
|
||||
mem = mem + mem_bufs
|
||||
return mem
|
||||
|
||||
|
||||
class CustomOffloadHook(ModelHook):
|
||||
"""
|
||||
A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
|
||||
on the given device. Optionally offloads other models to the CPU before the forward pass is called.
|
||||
|
||||
Args:
|
||||
execution_device(`str`, `int` or `torch.device`, *optional*):
|
||||
The device on which the model should be executed. Will default to the MPS device if it's available, then
|
||||
GPU 0 if there is a GPU, and finally to the CPU.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_device: Optional[Union[str, int, torch.device]] = None,
|
||||
other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
|
||||
offload_strategy: Optional["AutoOffloadStrategy"] = None,
|
||||
):
|
||||
self.execution_device = execution_device if execution_device is not None else PartialState().default_device
|
||||
self.other_hooks = other_hooks
|
||||
self.offload_strategy = offload_strategy
|
||||
self.model_id = None
|
||||
|
||||
def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
|
||||
self.offload_strategy = offload_strategy
|
||||
|
||||
def add_other_hook(self, hook: "UserCustomOffloadHook"):
|
||||
"""
|
||||
Add a hook to the list of hooks to consider for offloading.
|
||||
"""
|
||||
if self.other_hooks is None:
|
||||
self.other_hooks = []
|
||||
self.other_hooks.append(hook)
|
||||
|
||||
def init_hook(self, module):
|
||||
return module.to("cpu")
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
if module.device != self.execution_device:
|
||||
if self.other_hooks is not None:
|
||||
hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
|
||||
# offload all other hooks
|
||||
start_time = time.perf_counter()
|
||||
if self.offload_strategy is not None:
|
||||
hooks_to_offload = self.offload_strategy(
|
||||
hooks=hooks_to_offload,
|
||||
model_id=self.model_id,
|
||||
model=module,
|
||||
execution_device=self.execution_device,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
|
||||
)
|
||||
|
||||
for hook in hooks_to_offload:
|
||||
logger.info(
|
||||
f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
|
||||
)
|
||||
hook.offload()
|
||||
|
||||
if hooks_to_offload:
|
||||
clear_device_cache()
|
||||
module.to(self.execution_device)
|
||||
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
||||
|
||||
|
||||
class UserCustomOffloadHook:
|
||||
"""
|
||||
A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
|
||||
the hook or remove it entirely.
|
||||
"""
|
||||
|
||||
def __init__(self, model_id, model, hook):
|
||||
self.model_id = model_id
|
||||
self.model = model
|
||||
self.hook = hook
|
||||
|
||||
def offload(self):
|
||||
self.hook.init_hook(self.model)
|
||||
|
||||
def attach(self):
|
||||
add_hook_to_module(self.model, self.hook)
|
||||
self.hook.model_id = self.model_id
|
||||
|
||||
def remove(self):
|
||||
remove_hook_from_module(self.model)
|
||||
self.hook.model_id = None
|
||||
|
||||
def add_other_hook(self, hook: "UserCustomOffloadHook"):
|
||||
self.hook.add_other_hook(hook)
|
||||
|
||||
|
||||
def custom_offload_with_hook(
|
||||
model_id: str,
|
||||
model: torch.nn.Module,
|
||||
execution_device: Union[str, int, torch.device] = None,
|
||||
offload_strategy: Optional["AutoOffloadStrategy"] = None,
|
||||
):
|
||||
hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
|
||||
user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
|
||||
user_hook.attach()
|
||||
return user_hook
|
||||
|
||||
|
||||
class AutoOffloadStrategy:
|
||||
"""
|
||||
Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
|
||||
the available memory on the device.
|
||||
"""
|
||||
|
||||
def __init__(self, memory_reserve_margin="3GB"):
|
||||
self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)
|
||||
|
||||
def __call__(self, hooks, model_id, model, execution_device):
|
||||
if len(hooks) == 0:
|
||||
return []
|
||||
|
||||
current_module_size = get_memory_footprint(model)
|
||||
|
||||
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
|
||||
min_memory_offload = current_module_size - mem_on_device
|
||||
logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")
|
||||
|
||||
# exlucde models that's not currently loaded on the device
|
||||
module_sizes = dict(
|
||||
sorted(
|
||||
{hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
def search_best_candidate(module_sizes, min_memory_offload):
|
||||
"""
|
||||
search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
|
||||
minimum memory offload size. the combination of models should add up to the smallest modulesize that is
|
||||
larger than `min_memory_offload`
|
||||
"""
|
||||
model_ids = list(module_sizes.keys())
|
||||
best_candidate = None
|
||||
best_size = float("inf")
|
||||
for r in range(1, len(model_ids) + 1):
|
||||
for candidate_model_ids in combinations(model_ids, r):
|
||||
candidate_size = sum(
|
||||
module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
|
||||
)
|
||||
if candidate_size < min_memory_offload:
|
||||
continue
|
||||
else:
|
||||
if best_candidate is None or candidate_size < best_size:
|
||||
best_candidate = candidate_model_ids
|
||||
best_size = candidate_size
|
||||
|
||||
return best_candidate
|
||||
|
||||
best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)
|
||||
|
||||
if best_offload_model_ids is None:
|
||||
# if no combination is found, meaning that we cannot meet the memory requirement, offload all models
|
||||
logger.warning("no combination of models to offload to cpu is found, offloading all models")
|
||||
hooks_to_offload = hooks
|
||||
else:
|
||||
hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]
|
||||
|
||||
return hooks_to_offload
|
||||
|
||||
|
||||
class ComponentsManager:
|
||||
def __init__(self):
|
||||
self.components = OrderedDict()
|
||||
self.added_time = OrderedDict() # Store when components were added
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
def add(self, name, component):
|
||||
if name in self.components:
|
||||
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
|
||||
self.components[name] = component
|
||||
self.added_time[name] = time.time()
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
def remove(self, name):
|
||||
if name not in self.components:
|
||||
logger.warning(f"Component '{name}' not found in ComponentsManager")
|
||||
return
|
||||
|
||||
self.components.pop(name)
|
||||
self.added_time.pop(name)
|
||||
|
||||
if self._auto_offload_enabled:
|
||||
self.enable_auto_cpu_offload(self._auto_offload_device)
|
||||
|
||||
# YiYi TODO: looking into improving the search pattern
|
||||
def get(self, names: Union[str, List[str]]):
|
||||
"""
|
||||
Get components by name with simple pattern matching.
|
||||
|
||||
Args:
|
||||
names: Component name(s) or pattern(s)
|
||||
Patterns:
|
||||
- "unet" : exact match
|
||||
- "!unet" : everything except exact match "unet"
|
||||
- "base_*" : everything starting with "base_"
|
||||
- "!base_*" : everything NOT starting with "base_"
|
||||
- "*unet*" : anything containing "unet"
|
||||
- "!*unet*" : anything NOT containing "unet"
|
||||
- "refiner|vae|unet" : anything containing any of these terms
|
||||
- "!refiner|vae|unet" : anything NOT containing any of these terms
|
||||
|
||||
Returns:
|
||||
Single component if names is str and matches one component,
|
||||
dict of components if names matches multiple components or is a list
|
||||
"""
|
||||
if isinstance(names, str):
|
||||
# Check if this is a "not" pattern
|
||||
is_not_pattern = names.startswith('!')
|
||||
if is_not_pattern:
|
||||
names = names[1:] # Remove the ! prefix
|
||||
|
||||
# Handle OR patterns (containing |)
|
||||
if '|' in names:
|
||||
terms = names.split('|')
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")
|
||||
|
||||
# Exact match
|
||||
elif names in self.components:
|
||||
if is_not_pattern:
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if name != names
|
||||
}
|
||||
logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting component: {names}")
|
||||
return self.components[names]
|
||||
|
||||
# Prefix match (ends with *)
|
||||
elif names.endswith('*'):
|
||||
prefix = names[:-1]
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
|
||||
|
||||
# Contains match (starts with *)
|
||||
elif names.startswith('*'):
|
||||
search = names[1:-1] if names.endswith('*') else names[1:]
|
||||
matches = {
|
||||
name: comp for name, comp in self.components.items()
|
||||
if (search in name) != is_not_pattern # Flip condition if not pattern
|
||||
}
|
||||
if is_not_pattern:
|
||||
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
|
||||
else:
|
||||
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Component '{names}' not found in ComponentsManager")
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"No components found matching pattern '{names}'")
|
||||
return matches if len(matches) > 1 else next(iter(matches.values()))
|
||||
|
||||
elif isinstance(names, list):
|
||||
results = {}
|
||||
for name in names:
|
||||
result = self.get(name)
|
||||
if isinstance(result, dict):
|
||||
results.update(result)
|
||||
else:
|
||||
results[name] = result
|
||||
logger.info(f"Getting multiple components: {list(results.keys())}")
|
||||
return results
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid type for names: {type(names)}")
|
||||
|
||||
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"):
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
remove_hook_from_module(component, recurse=True)
|
||||
|
||||
self.disable_auto_cpu_offload()
|
||||
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
all_hooks = []
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
|
||||
all_hooks.append(hook)
|
||||
|
||||
for hook in all_hooks:
|
||||
other_hooks = [h for h in all_hooks if h is not hook]
|
||||
for other_hook in other_hooks:
|
||||
if other_hook.hook.execution_device == hook.hook.execution_device:
|
||||
hook.add_other_hook(other_hook)
|
||||
|
||||
self.model_hooks = all_hooks
|
||||
self._auto_offload_enabled = True
|
||||
self._auto_offload_device = device
|
||||
|
||||
def disable_auto_cpu_offload(self):
|
||||
if self.model_hooks is None:
|
||||
self._auto_offload_enabled = False
|
||||
return
|
||||
|
||||
for hook in self.model_hooks:
|
||||
hook.offload()
|
||||
hook.remove()
|
||||
if self.model_hooks:
|
||||
clear_device_cache()
|
||||
self.model_hooks = None
|
||||
self._auto_offload_enabled = False
|
||||
|
||||
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Get comprehensive information about a component.
|
||||
|
||||
Args:
|
||||
name: Name of the component to get info for
|
||||
fields: Optional field(s) to return. Can be a string for single field or list of fields.
|
||||
If None, returns all fields.
|
||||
|
||||
Returns:
|
||||
Dictionary containing requested component metadata.
|
||||
If fields is specified, returns only those fields.
|
||||
If a single field is requested as string, returns just that field's value.
|
||||
"""
|
||||
if name not in self.components:
|
||||
raise ValueError(f"Component '{name}' not found in ComponentsManager")
|
||||
|
||||
component = self.components[name]
|
||||
|
||||
# Build complete info dict first
|
||||
info = {
|
||||
"model_id": name,
|
||||
"added_time": self.added_time[name],
|
||||
}
|
||||
|
||||
# Additional info for torch.nn.Module components
|
||||
if isinstance(component, torch.nn.Module):
|
||||
info.update({
|
||||
"class_name": component.__class__.__name__,
|
||||
"size_gb": get_memory_footprint(component) / (1024**3),
|
||||
"adapters": None, # Default to None
|
||||
})
|
||||
|
||||
# Get adapters if applicable
|
||||
if hasattr(component, "peft_config"):
|
||||
info["adapters"] = list(component.peft_config.keys())
|
||||
|
||||
# Check for IP-Adapter scales
|
||||
if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
|
||||
processors = copy.deepcopy(component.attn_processors)
|
||||
# First check if any processor is an IP-Adapter
|
||||
processor_types = [v.__class__.__name__ for v in processors.values()]
|
||||
if any("IPAdapter" in ptype for ptype in processor_types):
|
||||
# Then get scales only from IP-Adapter processors
|
||||
scales = {
|
||||
k: v.scale
|
||||
for k, v in processors.items()
|
||||
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
|
||||
}
|
||||
if scales:
|
||||
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
|
||||
|
||||
# If fields specified, filter info
|
||||
if fields is not None:
|
||||
if isinstance(fields, str):
|
||||
# Single field requested, return just that value
|
||||
return {fields: info.get(fields)}
|
||||
else:
|
||||
# List of fields requested, return dict with just those fields
|
||||
return {k: v for k, v in info.items() if k in fields}
|
||||
|
||||
return info
|
||||
|
||||
def __repr__(self):
|
||||
col_widths = {
|
||||
"id": max(15, max(len(id) for id in self.components.keys())),
|
||||
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
|
||||
"device": 10,
|
||||
"dtype": 15,
|
||||
"size": 10,
|
||||
}
|
||||
|
||||
# Create the header lines
|
||||
sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
|
||||
dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
|
||||
|
||||
output = "Components:\n" + sep_line
|
||||
|
||||
# Separate components into models and others
|
||||
models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
||||
others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}
|
||||
|
||||
# Models section
|
||||
if models:
|
||||
output += "Models:\n" + dash_line
|
||||
# Column headers
|
||||
output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
|
||||
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
|
||||
output += dash_line
|
||||
|
||||
# Model entries
|
||||
for name, component in models.items():
|
||||
info = self.get_model_info(name)
|
||||
device = str(getattr(component, "device", "N/A"))
|
||||
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
|
||||
output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
|
||||
output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
|
||||
output += dash_line
|
||||
|
||||
# Other components section
|
||||
if others:
|
||||
if models: # Add extra newline if we had models section
|
||||
output += "\n"
|
||||
output += "Other Components:\n" + dash_line
|
||||
# Column headers for other components
|
||||
output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
|
||||
output += dash_line
|
||||
|
||||
# Other component entries
|
||||
for name, component in others.items():
|
||||
output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
|
||||
output += dash_line
|
||||
|
||||
# Add additional component info
|
||||
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
|
||||
for name in self.components:
|
||||
info = self.get_model_info(name)
|
||||
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
|
||||
output += f"\n{name}:\n"
|
||||
if info.get("adapters") is not None:
|
||||
output += f" Adapters: {info['adapters']}\n"
|
||||
if info.get("ip_adapter"):
|
||||
output += f" IP-Adapter: Enabled\n"
|
||||
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
|
||||
|
||||
return output
|
||||
|
||||
def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Load components from a pretrained model and add them to the manager.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (str): The path or identifier of the pretrained model
|
||||
prefix (str, optional): Prefix to add to all component names loaded from this model.
|
||||
If provided, components will be named as "{prefix}_{component_name}"
|
||||
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
|
||||
"""
|
||||
from ..pipelines.pipeline_utils import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
for name, component in pipe.components.items():
|
||||
|
||||
if component is None:
|
||||
continue
|
||||
|
||||
# Add prefix if specified
|
||||
component_name = f"{prefix}_{name}" if prefix else name
|
||||
|
||||
if component_name not in self.components:
|
||||
self.add(component_name, component)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
|
||||
f"1. remove the existing component with remove('{component_name}')\n"
|
||||
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
|
||||
)
|
||||
|
||||
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Summarizes a dictionary by finding common prefixes that share the same value.
|
||||
|
||||
For a dictionary with dot-separated keys like:
|
||||
{
|
||||
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
|
||||
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
|
||||
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
|
||||
}
|
||||
|
||||
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
|
||||
{
|
||||
'down_blocks': [0.6],
|
||||
'up_blocks': [0.3]
|
||||
}
|
||||
"""
|
||||
# First group by values - convert lists to tuples to make them hashable
|
||||
value_to_keys = {}
|
||||
for key, value in d.items():
|
||||
value_tuple = tuple(value) if isinstance(value, list) else value
|
||||
if value_tuple not in value_to_keys:
|
||||
value_to_keys[value_tuple] = []
|
||||
value_to_keys[value_tuple].append(key)
|
||||
|
||||
def find_common_prefix(keys: List[str]) -> str:
|
||||
"""Find the shortest common prefix among a list of dot-separated keys."""
|
||||
if not keys:
|
||||
return ""
|
||||
if len(keys) == 1:
|
||||
return keys[0]
|
||||
|
||||
# Split all keys into parts
|
||||
key_parts = [k.split('.') for k in keys]
|
||||
|
||||
# Find how many initial parts are common
|
||||
common_length = 0
|
||||
for parts in zip(*key_parts):
|
||||
if len(set(parts)) == 1: # All parts at this position are the same
|
||||
common_length += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if common_length == 0:
|
||||
return ""
|
||||
|
||||
# Return the common prefix
|
||||
return '.'.join(key_parts[0][:common_length])
|
||||
|
||||
# Create summary by finding common prefixes for each value group
|
||||
summary = {}
|
||||
for value_tuple, keys in value_to_keys.items():
|
||||
prefix = find_common_prefix(keys)
|
||||
if prefix: # Only add if we found a common prefix
|
||||
# Convert tuple back to list if it was originally a list
|
||||
value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
|
||||
summary[prefix] = value
|
||||
else:
|
||||
summary[""] = value # Use empty string if no common prefix
|
||||
|
||||
return summary
|
||||
@@ -912,6 +912,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -925,11 +931,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -867,6 +867,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -880,11 +886,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -609,6 +609,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -622,11 +628,6 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -75,11 +75,6 @@ class OnnxRuntimeModel:
|
||||
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
|
||||
provider = "CPUExecutionProvider"
|
||||
|
||||
if provider_options is None:
|
||||
provider_options = []
|
||||
elif not isinstance(provider_options, list):
|
||||
provider_options = [provider_options]
|
||||
|
||||
return ort.InferenceSession(
|
||||
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
|
||||
)
|
||||
@@ -179,10 +174,7 @@ class OnnxRuntimeModel:
|
||||
# load model from local directory
|
||||
if os.path.isdir(model_id):
|
||||
model = OnnxRuntimeModel.load_model(
|
||||
Path(model_id, model_file_name).as_posix(),
|
||||
provider=provider,
|
||||
sess_options=sess_options,
|
||||
provider_options=kwargs.pop("provider_options"),
|
||||
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
|
||||
)
|
||||
kwargs["model_save_dir"] = Path(model_id)
|
||||
# load model from hub
|
||||
@@ -198,12 +190,7 @@ class OnnxRuntimeModel:
|
||||
)
|
||||
kwargs["model_save_dir"] = Path(model_cache_path).parent
|
||||
kwargs["latest_model_name"] = Path(model_cache_path).name
|
||||
model = OnnxRuntimeModel.load_model(
|
||||
model_cache_path,
|
||||
provider=provider,
|
||||
sess_options=sess_options,
|
||||
provider_options=kwargs.pop("provider_options"),
|
||||
)
|
||||
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
|
||||
return cls(model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -917,6 +917,12 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -930,11 +936,6 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -707,6 +707,12 @@ class StableDiffusionXLPAGImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -720,11 +726,6 @@ class StableDiffusionXLPAGImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
@@ -412,7 +412,7 @@ def _get_pipeline_class(
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
|
||||
if class_obj.__name__ != "DiffusionPipeline":
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
|
||||
@@ -58,7 +58,6 @@ from ..utils import (
|
||||
_is_valid_type,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_hpu_available,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
is_transformers_version,
|
||||
@@ -427,7 +426,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
|
||||
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
@@ -444,7 +443,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
|
||||
)
|
||||
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
|
||||
@@ -452,20 +450,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
|
||||
# Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
|
||||
if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
|
||||
os.environ["PT_HPU_GPU_MIGRATION"] = "1"
|
||||
logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
|
||||
|
||||
import habana_frameworks.torch # noqa: F401
|
||||
|
||||
# HPU hardware check
|
||||
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
|
||||
raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
|
||||
|
||||
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
|
||||
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
|
||||
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
@@ -1120,11 +1104,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||
automatically detect the available accelerator and use.
|
||||
"""
|
||||
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
||||
|
||||
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
||||
@@ -1248,7 +1230,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
self.remove_all_hooks()
|
||||
|
||||
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
||||
|
||||
@@ -29,18 +29,6 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_xl_modular"] = [
|
||||
"StableDiffusionXLControlNetDenoiseStep",
|
||||
"StableDiffusionXLDecodeLatentsStep",
|
||||
"StableDiffusionXLDenoiseStep",
|
||||
"StableDiffusionXLInputStep",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"StableDiffusionXLPrepareAdditionalConditioningStep",
|
||||
"StableDiffusionXLPrepareLatentsStep",
|
||||
"StableDiffusionXLSetTimestepsStep",
|
||||
"StableDiffusionXLTextEncoderStep",
|
||||
"StableDiffusionXLAutoPipeline",
|
||||
]
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
@@ -60,18 +48,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
|
||||
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
|
||||
from .pipeline_stable_diffusion_xl_modular import (
|
||||
StableDiffusionXLControlNetDenoiseStep,
|
||||
StableDiffusionXLDecodeLatentsStep,
|
||||
StableDiffusionXLDenoiseStep,
|
||||
StableDiffusionXLInputStep,
|
||||
StableDiffusionXLModularPipeline,
|
||||
StableDiffusionXLPrepareAdditionalConditioningStep,
|
||||
StableDiffusionXLPrepareLatentsStep,
|
||||
StableDiffusionXLSetTimestepsStep,
|
||||
StableDiffusionXLTextEncoderStep,
|
||||
StableDiffusionXLAutoPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
|
||||
@@ -695,6 +695,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -708,11 +714,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -71,7 +71,6 @@ from .import_utils import (
|
||||
is_gguf_version,
|
||||
is_google_colab,
|
||||
is_hf_hub_version,
|
||||
is_hpu_available,
|
||||
is_inflect_available,
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
|
||||
@@ -1388,21 +1388,6 @@ class LDMSuperResolutionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PNDMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2432,21 +2432,6 @@ class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -353,10 +353,6 @@ def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
|
||||
def is_hpu_available():
|
||||
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -352,7 +352,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
@@ -403,7 +403,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
@@ -486,7 +486,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
@@ -541,7 +541,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
@@ -590,7 +590,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
@@ -653,7 +653,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
@@ -668,7 +668,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_lora_unload_with_parameter_expanded_shapes(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
@@ -734,7 +734,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
|
||||
+3
-3
@@ -1017,7 +1017,7 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
|
||||
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
|
||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_base")
|
||||
logger.setLevel(30)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
|
||||
@@ -1824,7 +1824,7 @@ class PeftLoraLoaderMixinTests:
|
||||
elif lora_module == "text_encoder_2":
|
||||
prefix = "text_encoder_2"
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_base")
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
@@ -1925,7 +1925,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -22,7 +22,6 @@ from parameterized import parameterized
|
||||
from diffusers import AsymmetricAutoencoderKL
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
@@ -135,32 +134,18 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
# fmt: off
|
||||
[
|
||||
33,
|
||||
Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
|
||||
("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
|
||||
("mps", None): torch.tensor(
|
||||
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
|
||||
),
|
||||
}
|
||||
),
|
||||
[-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
|
||||
[-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
|
||||
],
|
||||
[
|
||||
47,
|
||||
Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
|
||||
("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
|
||||
("mps", None): torch.tensor(
|
||||
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
|
||||
),
|
||||
}
|
||||
),
|
||||
[0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
|
||||
[-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
|
||||
],
|
||||
# fmt: on
|
||||
]
|
||||
)
|
||||
def test_stable_diffusion(self, seed, expected_slices):
|
||||
def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
|
||||
model = self.get_sd_vae_model()
|
||||
image = self.get_sd_image(seed)
|
||||
generator = self.get_generator(seed)
|
||||
@@ -171,9 +156,9 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
assert sample.shape == image.shape
|
||||
|
||||
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
|
||||
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
assert torch_all_close(output_slice, expected_slice, atol=5e-3)
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
|
||||
@@ -17,14 +17,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
is_torch_compile,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
@@ -96,21 +89,6 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
@@ -179,21 +157,6 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
@@ -260,21 +223,6 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
@@ -342,18 +290,3 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, u
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch._dynamo.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@@ -11,12 +11,10 @@ from diffusers import (
|
||||
UNet2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
@@ -170,17 +168,17 @@ class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
|
||||
generator = torch.manual_seed(seed)
|
||||
@@ -266,19 +264,11 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
# Ensure usage of flash attention in torch 2.0
|
||||
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
image = pipe(**inputs).images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
|
||||
("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]),
|
||||
("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
|
||||
}
|
||||
)
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_accelerator,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
@@ -210,8 +210,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
|
||||
|
||||
|
||||
@nightly
|
||||
@require_big_accelerator
|
||||
@pytest.mark.big_accelerator
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
class FluxControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipeline_class = FluxControlNetPipeline
|
||||
|
||||
|
||||
@@ -33,11 +33,10 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -396,17 +395,17 @@ class MarigoldIntrinsicsPipelineFastTests(MarigoldIntrinsicsPipelineTesterMixin,
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _test_marigold_intrinsics(
|
||||
self,
|
||||
@@ -425,7 +424,7 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
from_pretrained_kwargs["torch_dtype"] = torch.float16
|
||||
|
||||
pipe = MarigoldIntrinsicsPipeline.from_pretrained(model_id, **from_pretrained_kwargs)
|
||||
if device in ["cuda", "xpu"]:
|
||||
if device == "cuda":
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -465,10 +464,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=False,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.62127, 0.61906, 0.61687, 0.61946, 0.61903, 0.61961, 0.61808, 0.62099, 0.62894]),
|
||||
num_inference_steps=1,
|
||||
@@ -478,10 +477,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.62109, 0.61914, 0.61719, 0.61963, 0.61914, 0.61963, 0.61816, 0.62109, 0.62891]),
|
||||
num_inference_steps=1,
|
||||
@@ -491,10 +490,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=2024,
|
||||
expected_slice=np.array([0.64111, 0.63916, 0.63623, 0.63965, 0.63916, 0.63965, 0.6377, 0.64062, 0.64941]),
|
||||
num_inference_steps=1,
|
||||
@@ -504,10 +503,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.60254, 0.60059, 0.59961, 0.60156, 0.60107, 0.60205, 0.60254, 0.60449, 0.61133]),
|
||||
num_inference_steps=2,
|
||||
@@ -517,10 +516,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.64551, 0.64453, 0.64404, 0.64502, 0.64844, 0.65039, 0.64502, 0.65039, 0.65332]),
|
||||
num_inference_steps=1,
|
||||
@@ -530,10 +529,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]),
|
||||
num_inference_steps=1,
|
||||
@@ -544,10 +543,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]),
|
||||
num_inference_steps=1,
|
||||
@@ -558,10 +557,10 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
|
||||
match_input_resolution=True,
|
||||
)
|
||||
|
||||
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
|
||||
def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
|
||||
self._test_marigold_intrinsics(
|
||||
is_fp16=True,
|
||||
device=torch_device,
|
||||
device="cuda",
|
||||
generator_seed=0,
|
||||
expected_slice=np.array([0.65332, 0.64697, 0.64648, 0.64844, 0.64697, 0.64111, 0.64941, 0.64209, 0.65332]),
|
||||
num_inference_steps=1,
|
||||
|
||||
@@ -24,15 +24,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
floats_tensor,
|
||||
nightly,
|
||||
require_accelerator,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, nightly, require_accelerator, require_torch_gpu, torch_device
|
||||
|
||||
|
||||
class SafeDiffusionPipelineFastTests(unittest.TestCase):
|
||||
@@ -40,13 +32,13 @@ class SafeDiffusionPipelineFastTests(unittest.TestCase):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
@@ -270,19 +262,19 @@ class SafeDiffusionPipelineFastTests(unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_harm_safe_stable_diffusion(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
@@ -316,14 +308,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
|
||||
("cuda", 7): [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176],
|
||||
("cuda", 8): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
@@ -350,15 +335,6 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
|
||||
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
|
||||
("cuda", 7): [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719],
|
||||
("cuda", 8): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
|
||||
}
|
||||
)
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
@@ -389,14 +365,8 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): [0.3244, 0.3355, 0.3260, 0.3123, 0.3246, 0.3426, 0.3109, 0.3471, 0.4001],
|
||||
("cuda", 7): [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297],
|
||||
("cuda", 8): [0.3605, 0.3684, 0.3712, 0.3624, 0.3675, 0.3726, 0.3494, 0.3748, 0.4044],
|
||||
}
|
||||
)
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
@@ -419,16 +389,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): [0.6178, 0.6260, 0.6194, 0.6435, 0.6265, 0.6461, 0.6567, 0.6576, 0.6444],
|
||||
("cuda", 7): [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443],
|
||||
("cuda", 8): [0.5892, 0.5959, 0.5914, 0.6123, 0.5982, 0.6141, 0.6180, 0.6262, 0.6171],
|
||||
}
|
||||
)
|
||||
|
||||
print(f"image_slice: {image_slice.flatten()}")
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
@@ -484,14 +445,7 @@ class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slices = Expectations(
|
||||
{
|
||||
("xpu", 3): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
|
||||
("cuda", 7): np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561]),
|
||||
("cuda", 8): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
|
||||
}
|
||||
)
|
||||
expected_slice = expected_slices.get_expectation()
|
||||
|
||||
expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -1485,8 +1485,8 @@ class PipelineTesterMixin:
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
self.assertTrue(all(device == torch_device for device in model_devices))
|
||||
|
||||
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
|
||||
output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
|
||||
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
@@ -1677,11 +1677,11 @@ class PipelineTesterMixin:
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_offload_twice = pipe(**inputs)[0]
|
||||
|
||||
@@ -2226,7 +2226,7 @@ class PipelineTesterMixin:
|
||||
|
||||
def enable_group_offload_on_component(pipe, group_offloading_kwargs):
|
||||
# We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
|
||||
# tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
|
||||
# tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of
|
||||
# the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
|
||||
# warmup forward pass (even with dummy small inputs) is recommended.
|
||||
for component_name in [
|
||||
|
||||
@@ -22,13 +22,13 @@ from diffusers import (
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
torch_device,
|
||||
)
|
||||
@@ -577,24 +577,24 @@ class UniDiffuserPipelineFastTests(
|
||||
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
|
||||
|
||||
@unittest.skip(
|
||||
"Test not supported because it has a bunch of direct configs at init and also, this pipeline isn't used that much now."
|
||||
"Test not supported becauseit has a bunch of direct configs at init and also, this pipeline isn't used that much now."
|
||||
)
|
||||
def test_encode_prompt_works_in_isolation():
|
||||
pass
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, seed=0, generate_latents=False):
|
||||
generator = torch.manual_seed(seed)
|
||||
@@ -705,17 +705,17 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_torch_gpu
|
||||
class UniDiffuserPipelineNightlyTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, seed=0, generate_latents=False):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
@@ -5,7 +5,7 @@ import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
|
||||
@@ -18,9 +18,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
)
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
|
||||
@@ -16,9 +16,7 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
)
|
||||
from diffusers import AutoencoderKLWan
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
|
||||
@@ -18,13 +18,11 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_big_accelerator,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
@@ -62,7 +60,7 @@ class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
@require_big_accelerator
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_torch_accelerator
|
||||
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
|
||||
@@ -3,9 +3,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
SanaTransformer2DModel,
|
||||
)
|
||||
from diffusers import SanaTransformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
|
||||
@@ -8,7 +8,7 @@ from diffusers import (
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
T2IAdapter,
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
|
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
|
||||
@@ -98,7 +98,7 @@ if __name__ == "__main__":
|
||||
},
|
||||
"LoRA Mixins": {
|
||||
"doc_path": "docs/source/en/api/loaders/lora.md",
|
||||
"src_path": "src/diffusers/loaders/lora_pipeline.py",
|
||||
"src_path": "src/diffusers/loaders/lora/lora_pipeline.py",
|
||||
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
|
||||
"src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user