Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cb8c90b683 | |||
| 278c16aa05 | |||
| cfbfce4e90 | |||
| b8a61d6fb9 |
@@ -111,21 +111,3 @@ jobs:
|
||||
-s -v \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora/
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_models_lora_${{ matrix.config.report }} \
|
||||
tests/models/ -k "lora"
|
||||
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
@@ -189,17 +189,12 @@ jobs:
|
||||
-s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
|
||||
--make-reports=tests_peft_cuda \
|
||||
tests/lora/
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \
|
||||
--make-reports=tests_peft_cuda_models_lora \
|
||||
tests/models/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_peft_cuda_stats.txt
|
||||
cat reports/tests_peft_cuda_failures_short.txt
|
||||
cat reports/tests_peft_cuda_models_lora_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
nvidia-smi
|
||||
|
||||
- name: Tailscale # In order to be able to SSH when a test fails
|
||||
uses: huggingface/tailscale-action@main
|
||||
uses: huggingface/tailscale-action@v1
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
|
||||
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
|
||||
|
||||
@@ -237,19 +237,17 @@
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/transformer2d
|
||||
title: Transformer2DModel
|
||||
title: Transformer2D
|
||||
- local: api/models/pixart_transformer2d
|
||||
title: PixArtTransformer2DModel
|
||||
title: PixArtTransformer2D
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
title: DiTTransformer2D
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
title: Transformer Temporal
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
title: Prior Transformer
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
title: ControlNet
|
||||
title: Models
|
||||
isExpanded: false
|
||||
- sections:
|
||||
@@ -291,8 +289,6 @@
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/pix2pix
|
||||
@@ -461,4 +457,4 @@
|
||||
title: Video Processor
|
||||
title: Internal classes
|
||||
isExpanded: false
|
||||
title: API
|
||||
title: API
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNetModel
|
||||
# ControlNet
|
||||
|
||||
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DiTTransformer2DModel
|
||||
# DiTTransformer2D
|
||||
|
||||
A Transformer model for image-like data from [DiT](https://huggingface.co/papers/2212.09748).
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HunyuanDiT2DModel
|
||||
|
||||
A Diffusion Transformer model for 2D data from [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT).
|
||||
|
||||
## HunyuanDiT2DModel
|
||||
|
||||
[[autodoc]] HunyuanDiT2DModel
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# PixArtTransformer2DModel
|
||||
# PixArtTransformer2D
|
||||
|
||||
A Transformer model for image-like data from [PixArt-Alpha](https://huggingface.co/papers/2310.00426) and [PixArt-Sigma](https://huggingface.co/papers/2403.04692).
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# PriorTransformer
|
||||
# Prior Transformer
|
||||
|
||||
The Prior Transformer was originally introduced in [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) by Ramesh et al. It is used to predict CLIP image embeddings from CLIP text embeddings; image embeddings are predicted through a denoising diffusion process.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Transformer2DModel
|
||||
# Transformer2D
|
||||
|
||||
A Transformer model for image-like data from [CompVis](https://huggingface.co/CompVis) that is based on the [Vision Transformer](https://huggingface.co/papers/2010.11929) introduced by Dosovitskiy et al. The [`Transformer2DModel`] accepts discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# TransformerTemporalModel
|
||||
# Transformer Temporal
|
||||
|
||||
A Transformer model for video-like data.
|
||||
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Hunyuan-DiT
|
||||

|
||||
|
||||
[Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding](https://arxiv.org/abs/2405.08748) from Tencent Hunyuan.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present Hunyuan-DiT, a text-to-image diffusion transformer with fine-grained understanding of both English and Chinese. To construct Hunyuan-DiT, we carefully design the transformer structure, text encoder, and positional encoding. We also build from scratch a whole data pipeline to update and evaluate data for iterative model optimization. For fine-grained language understanding, we train a Multimodal Large Language Model to refine the captions of the images. Finally, Hunyuan-DiT can perform multi-turn multimodal dialogue with users, generating and refining images according to the context. Through our holistic human evaluation protocol with more than 50 professional human evaluators, Hunyuan-DiT sets a new state-of-the-art in Chinese-to-image generation compared with other open-source models.*
|
||||
|
||||
|
||||
You can find the original codebase at [Tencent/HunyuanDiT](https://github.com/Tencent/HunyuanDiT) and all the available checkpoints at [Tencent-Hunyuan](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT).
|
||||
|
||||
**Highlights**: HunyuanDiT supports Chinese/English-to-image, multi-resolution generation.
|
||||
|
||||
HunyuanDiT has the following components:
|
||||
* It uses a diffusion transformer as the backbone
|
||||
* It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder
|
||||
|
||||
|
||||
## Memory optimization
|
||||
|
||||
By loading the T5 text encoder in 8 bits, you can run the pipeline in just under 6 GBs of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.
|
||||
|
||||
## HunyuanDiTPipeline
|
||||
|
||||
[[autodoc]] HunyuanDiTPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
+141
-17
@@ -22,14 +22,17 @@ import torch
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from ..models.modeling_utils import load_state_dict
|
||||
from .. import __version__
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
@@ -116,10 +119,13 @@ class LoraLoaderMixin:
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
@@ -130,6 +136,7 @@ class LoraLoaderMixin:
|
||||
if not hasattr(self, "text_encoder")
|
||||
else self.text_encoder,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
@@ -186,8 +193,16 @@ class LoraLoaderMixin:
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
weight_name (`str`, *optional*, defaults to None):
|
||||
Name of the serialized state dict file.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# UNet and text encoder or both.
|
||||
@@ -368,7 +383,9 @@ class LoraLoaderMixin:
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
|
||||
@@ -378,11 +395,14 @@ class LoraLoaderMixin:
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
@@ -390,18 +410,94 @@ class LoraLoaderMixin:
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
|
||||
|
||||
if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
|
||||
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_attn_procs(
|
||||
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
|
||||
)
|
||||
|
||||
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
|
||||
state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
|
||||
network_alphas = {
|
||||
k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
else:
|
||||
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
||||
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
||||
if not USE_PEFT_BACKEND:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warning(warn_message)
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
|
||||
)
|
||||
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if network_alphas is not None:
|
||||
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
|
||||
# `convert_unet_state_dict_to_peft` method.
|
||||
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(unet)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
unet.load_attn_procs(
|
||||
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_text_encoder(
|
||||
@@ -411,6 +507,7 @@ class LoraLoaderMixin:
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
):
|
||||
@@ -430,6 +527,11 @@ class LoraLoaderMixin:
|
||||
lora_scale (`float`):
|
||||
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
||||
lora layer.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
@@ -439,6 +541,8 @@ class LoraLoaderMixin:
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
@@ -521,7 +625,9 @@ class LoraLoaderMixin:
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -534,12 +640,19 @@ class LoraLoaderMixin:
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
@@ -733,11 +846,22 @@ class LoraLoaderMixin:
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.unload_lora()
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
logger.warning(
|
||||
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
|
||||
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
|
||||
)
|
||||
|
||||
for _, module in unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
else:
|
||||
recurse_remove_peft_layers(unet)
|
||||
if hasattr(unet, "peft_config"):
|
||||
del unet.peft_config
|
||||
|
||||
# Safe to call the following regardless of LoRA.
|
||||
self._remove_text_encoder_monkey_patch()
|
||||
|
||||
+269
-272
@@ -33,32 +33,34 @@ from ..models.embeddings import (
|
||||
IPAdapterPlusImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_version,
|
||||
is_torch_version,
|
||||
logging,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
TEXT_ENCODER_NAME = "text_encoder"
|
||||
UNET_NAME = "unet"
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
@@ -77,8 +79,7 @@ class UNet2DConditionLoadersMixin:
|
||||
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
|
||||
defined in
|
||||
[`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
|
||||
and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
|
||||
`peft`: `pip install -U peft`.
|
||||
and be a `torch.nn.Module` class.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
@@ -109,20 +110,20 @@ class UNet2DConditionLoadersMixin:
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
adapter_name (`str`, *optional*, defaults to None):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
weight_name (`str`, *optional*, defaults to None):
|
||||
Name of the serialized state dict file.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -138,6 +139,9 @@ class UNet2DConditionLoadersMixin:
|
||||
)
|
||||
```
|
||||
"""
|
||||
from ..models.attention_processor import CustomDiffusionAttnProcessor
|
||||
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
@@ -148,9 +152,15 @@ class UNet2DConditionLoadersMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
adapter_name = kwargs.pop("adapter_name", None)
|
||||
_pipeline = kwargs.pop("_pipeline", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
|
||||
_pipeline = kwargs.pop("_pipeline", None)
|
||||
|
||||
is_network_alphas_none = network_alphas is None
|
||||
|
||||
allow_pickle = False
|
||||
|
||||
if use_safetensors is None:
|
||||
@@ -206,196 +216,198 @@ class UNet2DConditionLoadersMixin:
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
# fill attn processors
|
||||
lora_layers_list = []
|
||||
|
||||
if is_custom_diffusion:
|
||||
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
|
||||
elif is_lora:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
|
||||
state_dict=state_dict,
|
||||
unet_identifier_key=self.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
)
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
# correct keys
|
||||
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
|
||||
|
||||
if network_alphas is not None:
|
||||
network_alphas_keys = list(network_alphas.keys())
|
||||
used_network_alphas_keys = set()
|
||||
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
mapped_network_alphas = {}
|
||||
|
||||
all_keys = list(state_dict.keys())
|
||||
for key in all_keys:
|
||||
value = state_dict.pop(key)
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
lora_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
|
||||
if network_alphas is not None:
|
||||
for k in network_alphas_keys:
|
||||
if k.replace(".alpha", "") in key:
|
||||
mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
|
||||
used_network_alphas_keys.add(k)
|
||||
|
||||
if not is_network_alphas_none:
|
||||
if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
|
||||
raise ValueError(
|
||||
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
|
||||
)
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(
|
||||
f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key, value_dict in lora_grouped_dict.items():
|
||||
attn_processor = self
|
||||
for sub_key in key.split("."):
|
||||
attn_processor = getattr(attn_processor, sub_key)
|
||||
|
||||
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
|
||||
# or add_{k,v,q,out_proj}_proj_lora layers.
|
||||
rank = value_dict["lora.down.weight"].shape[0]
|
||||
|
||||
if isinstance(attn_processor, LoRACompatibleConv):
|
||||
in_features = attn_processor.in_channels
|
||||
out_features = attn_processor.out_channels
|
||||
kernel_size = attn_processor.kernel_size
|
||||
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
lora = LoRAConv2dLayer(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
rank=rank,
|
||||
kernel_size=kernel_size,
|
||||
stride=attn_processor.stride,
|
||||
padding=attn_processor.padding,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
elif isinstance(attn_processor, LoRACompatibleLinear):
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
lora = LoRALinearLayer(
|
||||
attn_processor.in_features,
|
||||
attn_processor.out_features,
|
||||
rank,
|
||||
mapped_network_alphas.get(key),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
|
||||
|
||||
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
|
||||
lora_layers_list.append((attn_processor, lora))
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(value_dict.values())).device
|
||||
dtype = next(iter(value_dict.values())).dtype
|
||||
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
lora.load_state_dict(value_dict)
|
||||
|
||||
elif is_custom_diffusion:
|
||||
attn_processors = {}
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
if len(value) == 0:
|
||||
custom_diffusion_grouped_dict[key] = {}
|
||||
else:
|
||||
if "to_out" in key:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
else:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
||||
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
for key, value_dict in custom_diffusion_grouped_dict.items():
|
||||
if len(value_dict) == 0:
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
||||
)
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
||||
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
||||
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=True,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
elif USE_PEFT_BACKEND:
|
||||
# In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
|
||||
# on the Unet
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
|
||||
# Now we remove any existing hooks to `_pipeline`.
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
|
||||
if is_custom_diffusion and _pipeline is not None:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
# only custom diffusion needs to set attn processors
|
||||
self.set_attn_processor(attn_processors)
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
def _process_custom_diffusion(self, state_dict):
|
||||
from ..models.attention_processor import CustomDiffusionAttnProcessor
|
||||
|
||||
attn_processors = {}
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
if len(value) == 0:
|
||||
custom_diffusion_grouped_dict[key] = {}
|
||||
else:
|
||||
if "to_out" in key:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
else:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
||||
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
for key, value_dict in custom_diffusion_grouped_dict.items():
|
||||
if len(value_dict) == 0:
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
||||
)
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
||||
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
||||
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=True,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
|
||||
return attn_processors
|
||||
|
||||
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
|
||||
# This method does the following things:
|
||||
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
|
||||
# format. For legacy format no filtering is applied.
|
||||
# 2. Converts the `state_dict` to the `peft` compatible format.
|
||||
# 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
|
||||
# `LoraConfig` specs.
|
||||
# 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
|
||||
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
|
||||
unet_state_dict = {
|
||||
k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
|
||||
}
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
|
||||
network_alphas = {
|
||||
k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
|
||||
|
||||
if len(state_dict_to_be_used) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
|
||||
)
|
||||
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
|
||||
|
||||
if network_alphas is not None:
|
||||
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
|
||||
# `convert_unet_state_dict_to_peft` method.
|
||||
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
# only custom diffusion needs to set attn processors
|
||||
if is_custom_diffusion:
|
||||
self.set_attn_processor(attn_processors)
|
||||
|
||||
# set lora layers
|
||||
for target_module, lora_layer in lora_layers_list:
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
|
||||
is_new_lora_format = all(
|
||||
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
||||
)
|
||||
if is_new_lora_format:
|
||||
# Strip the `"unet"` prefix.
|
||||
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
||||
if is_text_encoder_present:
|
||||
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
||||
logger.warning(warn_message)
|
||||
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
||||
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
||||
|
||||
# change processor format to 'pure' LoRACompatibleLinear format
|
||||
if any("processor" in k.split(".") for k in state_dict.keys()):
|
||||
|
||||
def format_to_lora_compatible(key):
|
||||
if "processor" not in key.split("."):
|
||||
return key
|
||||
return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
|
||||
|
||||
state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
|
||||
|
||||
if network_alphas is not None:
|
||||
network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
|
||||
return state_dict, network_alphas
|
||||
|
||||
def save_attn_procs(
|
||||
self,
|
||||
@@ -448,23 +460,6 @@ class UNet2DConditionLoadersMixin:
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
is_custom_diffusion = any(
|
||||
isinstance(
|
||||
x,
|
||||
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
|
||||
)
|
||||
for (_, x) in self.attn_processors.items()
|
||||
)
|
||||
if is_custom_diffusion:
|
||||
state_dict = self._get_custom_diffusion_state_dict()
|
||||
else:
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
|
||||
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
state_dict = get_peft_model_state_dict(self)
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
@@ -476,6 +471,36 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
is_custom_diffusion = any(
|
||||
isinstance(
|
||||
x,
|
||||
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
|
||||
)
|
||||
for (_, x) in self.attn_processors.items()
|
||||
)
|
||||
if is_custom_diffusion:
|
||||
model_to_save = AttnProcsLayers(
|
||||
{
|
||||
y: x
|
||||
for (y, x) in self.attn_processors.items()
|
||||
if isinstance(
|
||||
x,
|
||||
(
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
}
|
||||
)
|
||||
state_dict = model_to_save.state_dict()
|
||||
for name, attn in self.attn_processors.items():
|
||||
if len(attn.state_dict()) == 0:
|
||||
state_dict[name] = {}
|
||||
else:
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
||||
@@ -487,84 +512,56 @@ class UNet2DConditionLoadersMixin:
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
def _get_custom_diffusion_state_dict(self):
|
||||
from ..models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
model_to_save = AttnProcsLayers(
|
||||
{
|
||||
y: x
|
||||
for (y, x) in self.attn_processors.items()
|
||||
if isinstance(
|
||||
x,
|
||||
(
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
}
|
||||
)
|
||||
state_dict = model_to_save.state_dict()
|
||||
for name, attn in self.attn_processors.items():
|
||||
if len(attn.state_dict()) == 0:
|
||||
state_dict[name] = {}
|
||||
|
||||
return state_dict
|
||||
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `fuse_lora()`.")
|
||||
|
||||
self.lora_scale = lora_scale
|
||||
self._safe_fusing = safe_fusing
|
||||
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
|
||||
|
||||
def _fuse_lora_apply(self, module, adapter_names=None):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
if not USE_PEFT_BACKEND:
|
||||
if hasattr(module, "_fuse_lora"):
|
||||
module._fuse_lora(self.lora_scale, self._safe_fusing)
|
||||
|
||||
merge_kwargs = {"safe_merge": self._safe_fusing}
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if self.lora_scale != 1.0:
|
||||
module.scale_layer(self.lora_scale)
|
||||
|
||||
# For BC with prevous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
if adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
|
||||
" to the latest version of PEFT. `pip install -U peft`"
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch"
|
||||
" to PEFT backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
merge_kwargs = {"safe_merge": self._safe_fusing}
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if self.lora_scale != 1.0:
|
||||
module.scale_layer(self.lora_scale)
|
||||
|
||||
# For BC with prevous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
|
||||
" to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
def unfuse_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `unfuse_lora()`.")
|
||||
self.apply(self._unfuse_lora_apply)
|
||||
|
||||
def _unfuse_lora_apply(self, module):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
def unload_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `unload_lora()`.")
|
||||
if hasattr(module, "_unfuse_lora"):
|
||||
module._unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
from ..utils import recurse_remove_peft_layers
|
||||
|
||||
recurse_remove_peft_layers(self)
|
||||
if hasattr(self, "peft_config"):
|
||||
del self.peft_config
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
|
||||
@@ -176,7 +176,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, generator, sample, mask).sample
|
||||
dec = self.decode(z, sample, mask).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
@@ -71,22 +71,18 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
|
||||
|
||||
def _fetch_remapped_cls_from_config(config, old_class):
|
||||
previous_class_name = old_class.__name__
|
||||
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
|
||||
remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"])
|
||||
|
||||
# Details:
|
||||
# https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
|
||||
if remapped_class_name:
|
||||
# load diffusers library to import compatible and original scheduler
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
remapped_class = getattr(diffusers_library, remapped_class_name)
|
||||
logger.info(
|
||||
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
|
||||
f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
|
||||
" DOESN'T affect the final results."
|
||||
)
|
||||
return remapped_class
|
||||
else:
|
||||
return old_class
|
||||
# load diffusers library to import compatible and original scheduler
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
remapped_class = getattr(diffusers_library, remapped_class_name)
|
||||
logger.info(
|
||||
f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
|
||||
"This is because `previous_class_name` is scheduled to be deprecated in a future version. Note that this"
|
||||
" DOESN'T affect the final results."
|
||||
)
|
||||
|
||||
return remapped_class
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
|
||||
from ..attention_processor import Attention, HunyuanAttnProcessor2_0
|
||||
from ..embeddings import (
|
||||
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
||||
PatchEmbed,
|
||||
@@ -166,7 +166,6 @@ class HunyuanDiTBlock(nn.Module):
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
@@ -322,110 +321,6 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
self.set_attn_processor(HunyuanAttnProcessor2_0())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -530,45 +425,3 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self):
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
@@ -903,6 +903,17 @@ class UNet2DConditionModel(
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def unload_lora(self):
|
||||
"""Unloads LoRA weights."""
|
||||
deprecate(
|
||||
"unload_lora",
|
||||
"0.28.0",
|
||||
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
|
||||
)
|
||||
for module in self.modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
|
||||
def get_time_embed(
|
||||
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -546,6 +546,18 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
def unload_lora(self):
|
||||
"""Unloads LoRA weights."""
|
||||
deprecate(
|
||||
"unload_lora",
|
||||
"0.28.0",
|
||||
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
|
||||
)
|
||||
for module in self.modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -52,9 +52,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import HunyuanDiTPipeline
|
||||
|
||||
>>> pipe = HunyuanDiTPipeline.from_pretrained(
|
||||
... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT", torch_dtype=torch.float16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> # You may also use English prompt as HunyuanDiT supports both English and Chinese
|
||||
@@ -228,22 +226,16 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.default_sample_size = self.transformer.config.sample_size
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
device: torch.device = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
@@ -287,17 +279,6 @@ class HunyuanDiTPipeline(DiffusionPipeline):
|
||||
text_encoder_index (`int`, *optional*):
|
||||
Index of the text encoder to use. `0` for clip and `1` for T5.
|
||||
"""
|
||||
if dtype is None:
|
||||
if self.text_encoder_2 is not None:
|
||||
dtype = self.text_encoder_2.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
|
||||
@@ -51,8 +51,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import StableUnCLIPImg2ImgPipeline
|
||||
|
||||
>>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=torch.float16
|
||||
... )
|
||||
... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
|
||||
... ) # TODO update model path
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
@@ -63,7 +63,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
>>> images = pipe(init_image, prompt).images
|
||||
>>> images = pipe(prompt, init_image).images
|
||||
>>> images[0].save("fantasy_landscape.png")
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -243,13 +243,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
ramp = torch.linspace(0, 1, self.num_inference_steps)
|
||||
ramp = np.linspace(0, 1, self.num_inference_steps)
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
|
||||
sigmas = sigmas.to(dtype=torch.float32, device=device)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
@@ -283,6 +283,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
|
||||
|
||||
@@ -16,6 +16,7 @@ import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -209,13 +210,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
ramp = torch.linspace(0, 1, self.num_inference_steps)
|
||||
ramp = np.linspace(0, 1, self.num_inference_steps)
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
|
||||
sigmas = sigmas.to(dtype=torch.float32, device=device)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
@@ -233,6 +234,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
|
||||
return sigmas
|
||||
|
||||
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
|
||||
@@ -37,9 +37,7 @@ from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
is_peft_available,
|
||||
load_hf_numpy,
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
require_torch_accelerator_with_training,
|
||||
@@ -53,38 +51,11 @@ from diffusers.utils.testing_utils import (
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def get_unet_lora_config():
|
||||
rank = 4
|
||||
unet_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=rank,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return unet_lora_config
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Checks if the LoRA layers are correctly set with peft
|
||||
"""
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
@@ -1034,65 +1005,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(
|
||||
non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
|
||||
), "LoRA injected UNet should produce different results."
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora_serialization(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(
|
||||
non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
|
||||
), "LoRA injected UNet should produce different results."
|
||||
assert torch.allclose(
|
||||
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
|
||||
), "Loading from a saved checkpoint should produce identical results."
|
||||
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
@@ -228,60 +228,6 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_no_chunking = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_chunking = image[0, -3:, -3:, -1]
|
||||
|
||||
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["return_dict"] = False
|
||||
image = pipe(**inputs)[0]
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["return_dict"] = False
|
||||
image_fused = pipe(**inputs)[0]
|
||||
image_slice_fused = image_fused[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["return_dict"] = False
|
||||
image_disabled = pipe(**inputs)[0]
|
||||
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user