Compare commits

..

11 Commits

Author SHA1 Message Date
Patrick von Platen 29f15673ed Release: v0.21.0 2023-09-13 15:58:24 +02:00
Patrick von Platen 1037287e2b examples fix t2i training (#5001)
* examples fix t2i training

* make style
2023-09-12 23:52:41 +02:00
Steven Liu 6ea95b7a90 Fix PR template (#4984)
fix template
2023-09-12 19:36:38 +02:00
Patrick von Platen 0e0db625d0 Fix safety checker seq offload (#4998)
* fix safety checker

* fix safety checker

* fix safety checker
2023-09-12 18:56:35 +02:00
dg845 1f948109b8 [docs] Fix DiffusionPipeline.enable_sequential_cpu_offload docstring (#4952)
* Fix an unmatched backtick and make description more general for DiffusionPipeline.enable_sequential_cpu_offload.

* make style

* _exclude_from_cpu_offload -> self._exclude_from_cpu_offload

* make style

* apply suggestions from review

* make style
2023-09-12 08:58:47 -07:00
Patrick von Platen 37cb819df5 [Lora] Speed up lora loading (#4994)
* speed up lora loading

* Apply suggestions from code review

* up

* up

* Fix more

* Correct more

* Apply suggestions from code review

* up

* Fix more

* Fix more -

* up

* up
2023-09-12 17:51:15 +02:00
Dhruv Nair f64d52dbca fix custom diffusion tests (#4996) 2023-09-12 17:50:47 +02:00
Dhruv Nair 4d897aaff5 fix image variation slow test (#4995)
fix image variation tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-09-12 17:45:47 +02:00
Patrick von Platen b1105269b7 make style 2023-09-12 14:55:27 +00:00
Kashif Rasul 5d28d2217f [Wuerstchen] fix combined pipeline's num_images_per_prompt (#4989)
* fix encode_prompt

* added prompt_embeds and negative_prompt_embeds

* prompt_embeds for the prior only
2023-09-12 16:55:13 +02:00
Kashif Rasul 73bf620dec fix E721 Do not compare types, use isinstance() (#4992) 2023-09-12 16:52:25 +02:00
65 changed files with 325 additions and 207 deletions
+1 -1
View File
@@ -41,7 +41,7 @@ Core library:
- Schedulers: @williamberman and @patrickvonplaten
- Pipelines: @patrickvonplaten and @sayakpaul
- Training examples: @sayakpaul and @patrickvonplaten
- Docs: @stevenliu and @yiyixu
- Docs: @stevhliu and @yiyixuxu
- JAX and MPS: @pcuenca
- Audio: @sanchit-gandhi
- General functionalities: @patrickvonplaten and @sayakpaul
@@ -1138,7 +1138,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
@@ -701,7 +701,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 10.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
+1 -1
View File
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
+1 -1
View File
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+1 -1
View File
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -1060,7 +1060,9 @@ def main(args):
)
# Prepare everything with our `accelerator`.
t2iadapter, optimizer, lr_scheduler = accelerator.prepare(t2iadapter, optimizer, lr_scheduler)
t2iadapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
t2iadapter, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -244,7 +244,7 @@ install_requires = [
setup(
name="diffusers",
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.21.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.21.0.dev0"
__version__ = "0.21.0"
from typing import TYPE_CHECKING
@@ -76,7 +76,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
return x_in * self.stds[key] + self.means[key]
def to_torch(self, x_in):
if type(x_in) is dict:
if isinstance(x_in, dict):
return {k: self.to_torch(v) for k, v in x_in.items()}
elif torch.is_tensor(x_in):
return x_in.to(self.unet.device)
+116 -87
View File
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import re
import warnings
@@ -27,6 +26,7 @@ import torch
from huggingface_hub import hf_hub_download, model_info
from torch import nn
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
@@ -46,7 +46,6 @@ if is_transformers_available():
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__)
@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
self.w_down = None
def forward(self, input):
# print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
if self.lora_scale is None:
self.lora_scale = 1.0
if self.lora_linear_layer is None:
@@ -274,6 +272,11 @@ class UNet2DConditionLoadersMixin:
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
@@ -300,6 +303,7 @@ class UNet2DConditionLoadersMixin:
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
@@ -316,6 +320,15 @@ class UNet2DConditionLoadersMixin:
"framework": "pytorch",
}
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
@@ -370,6 +383,10 @@ class UNet2DConditionLoadersMixin:
# correct keys
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
if network_alphas is not None:
network_alphas_keys = list(network_alphas.keys())
used_network_alphas_keys = set()
lora_grouped_dict = defaultdict(dict)
mapped_network_alphas = {}
@@ -381,13 +398,13 @@ class UNet2DConditionLoadersMixin:
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
network_alphas_ = copy.deepcopy(network_alphas)
for k in network_alphas_:
for k in network_alphas_keys:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
used_network_alphas_keys.add(k)
if not is_network_alphas_none:
if len(network_alphas) > 0:
if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
@@ -411,29 +428,38 @@ class UNet2DConditionLoadersMixin:
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
lora_layers_list.append((attn_processor, lora))
if low_cpu_mem_usage:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -470,13 +496,12 @@ class UNet2DConditionLoadersMixin:
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)
# set correct dtype & device
lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -999,13 +1024,18 @@ class LoraLoaderMixin:
recurive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recurive)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
self.load_lora_into_unet(
state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
)
# Offload back.
@@ -1065,6 +1095,11 @@ class LoraLoaderMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1305,7 +1340,7 @@ class LoraLoaderMixin:
return new_state_dict
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1318,7 +1353,13 @@ class LoraLoaderMixin:
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
@@ -1343,11 +1384,12 @@ class LoraLoaderMixin:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
# load loras into unet
unet.load_attn_procs(state_dict, network_alphas=network_alphas)
unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
@classmethod
def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
def load_lora_into_text_encoder(
cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1364,7 +1406,13 @@ class LoraLoaderMixin:
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
@@ -1447,6 +1495,7 @@ class LoraLoaderMixin:
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
# set correct dtype & device
@@ -1454,12 +1503,23 @@ class LoraLoaderMixin:
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
if len(load_state_dict_results.unexpected_keys) != 0:
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
@@ -1492,11 +1552,21 @@ class LoraLoaderMixin:
rank: Union[Dict[str, int], int] = 4,
dtype=None,
patch_mlp=False,
low_cpu_mem_usage=False,
):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
lora_parameters.extend(model.lora_linear_layer.parameters())
return model
# First, remove any monkey-patch that might have been applied before
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
@@ -1515,45 +1585,18 @@ class LoraLoaderMixin:
else:
current_rank = rank
q_linear_layer = (
attn_module.q_proj.regular_linear_layer
if isinstance(attn_module.q_proj, PatchedLoraProjection)
else attn_module.q_proj
attn_module.q_proj = create_patched_linear_lora(
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
)
attn_module.q_proj = PatchedLoraProjection(
q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
attn_module.k_proj = create_patched_linear_lora(
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
)
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
k_linear_layer = (
attn_module.k_proj.regular_linear_layer
if isinstance(attn_module.k_proj, PatchedLoraProjection)
else attn_module.k_proj
attn_module.v_proj = create_patched_linear_lora(
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
)
attn_module.k_proj = PatchedLoraProjection(
k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
attn_module.out_proj = create_patched_linear_lora(
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
)
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
v_linear_layer = (
attn_module.v_proj.regular_linear_layer
if isinstance(attn_module.v_proj, PatchedLoraProjection)
else attn_module.v_proj
)
attn_module.v_proj = PatchedLoraProjection(
v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
)
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
out_linear_layer = (
attn_module.out_proj.regular_linear_layer
if isinstance(attn_module.out_proj, PatchedLoraProjection)
else attn_module.out_proj
)
attn_module.out_proj = PatchedLoraProjection(
out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
if patch_mlp:
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
@@ -1563,25 +1606,12 @@ class LoraLoaderMixin:
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
fc1_linear_layer = (
mlp_module.fc1.regular_linear_layer
if isinstance(mlp_module.fc1, PatchedLoraProjection)
else mlp_module.fc1
mlp_module.fc1 = create_patched_linear_lora(
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
)
mlp_module.fc1 = PatchedLoraProjection(
fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
mlp_module.fc2 = create_patched_linear_lora(
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
)
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
fc2_linear_layer = (
mlp_module.fc2.regular_linear_layer
if isinstance(mlp_module.fc2, PatchedLoraProjection)
else mlp_module.fc2
)
mlp_module.fc2 = PatchedLoraProjection(
fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
if is_network_alphas_populated and len(network_alphas) > 0:
raise ValueError(
@@ -2375,8 +2405,7 @@ class FromOriginalVAEMixin:
vae = AutoencoderKL(**vae_config)
if is_accelerate_available():
for param_name, param in converted_vae_checkpoint.items():
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
else:
vae.load_state_dict(converted_vae_checkpoint)
+32 -22
View File
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
)
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
unexpected_keys = []
empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue
if empty_state_dict[param_name].shape != param.shape:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
return unexpected_keys
def _load_state_dict_into_model(model_to_load, state_dict):
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = []
empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
accepts_dtype = "dtype" in set(
inspect.signature(set_module_tensor_to_device).parameters.keys()
)
if param_name not in empty_state_dict:
unexpected_keys.append(param_name)
continue
if empty_state_dict[param_name].shape != param.shape:
raise ValueError(
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if accepts_dtype:
set_module_tensor_to_device(
model, param_name, param_device, value=param, dtype=torch_dtype
)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param)
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_name_or_path,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
@@ -100,6 +100,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -127,6 +127,7 @@ class AltDiffusionImg2ImgPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -178,7 +178,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(steps)
step_generator = step_generator or generator
# For backwards compatibility
if type(self.unet.config.sample_size) == int:
if isinstance(self.unet.config.sample_size, int):
self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
if noise is None:
noise = randn_tensor(
@@ -125,6 +125,7 @@ class StableDiffusionControlNetPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -149,6 +149,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -273,6 +273,7 @@ class StableDiffusionControlNetInpaintPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
+4 -4
View File
@@ -1293,10 +1293,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
Note that offloading happens on a submodule basis. Memory savings are higher than with
Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
method called. Offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
@@ -101,6 +101,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -191,6 +191,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -272,6 +272,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -124,6 +124,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
"""
_optional_components = ["safety_checker", "feature_extractor"]
model_cpu_offload_seq = "text_encoder->unet->vae"
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -182,6 +182,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -66,6 +66,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
# we should give a descriptive message if the pipeline doesn't have one.
_optional_components = ["safety_checker"]
model_cpu_offload_seq = "image_encoder->unet->vae"
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -129,6 +129,7 @@ class StableDiffusionImg2ImgPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -194,6 +194,7 @@ class StableDiffusionInpaintPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -117,6 +117,7 @@ class StableDiffusionInpaintPipelineLegacy(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -91,6 +91,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -82,6 +82,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -114,6 +114,7 @@ class StableDiffusionLDM3DPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -68,6 +68,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -80,6 +80,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -95,6 +95,7 @@ class StableDiffusionParadigmsPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -315,6 +315,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
"caption_processor",
"inverse_scheduler",
]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -119,6 +119,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -92,6 +92,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["watermarker", "safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
self,
@@ -810,7 +810,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
@@ -885,7 +885,7 @@ class StableDiffusionXLImg2ImgPipeline(
# 5. Prepare timesteps
def denoising_value_valid(dnv):
return type(denoising_end) == float and 0 < dnv < 1
return isinstance(denoising_end, float) and 0 < dnv < 1
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
@@ -1120,7 +1120,7 @@ class StableDiffusionXLInpaintPipeline(
# 4. set timesteps
def denoising_value_valid(dnv):
return type(denoising_end) == float and 0 < dnv < 1
return isinstance(denoising_end, float) and 0 < dnv < 1
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
@@ -837,7 +837,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
# 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
@@ -886,7 +886,7 @@ class StableDiffusionXLAdapterPipeline(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
@@ -330,7 +330,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
# 2. Encode caption
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt,
device,
image_embeddings.size(0) * num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
text_encoder_hidden_states = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
@@ -154,6 +154,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -165,10 +167,17 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
The prompt or prompts to guide the image generation for the prior and decoder.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
@@ -221,13 +230,15 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
prior_outputs = self.prior_pipe(
prompt=prompt,
prompt=prompt if prompt_embeds is None else None,
height=height,
width=width,
num_inference_steps=prior_num_inference_steps,
timesteps=prior_timesteps,
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
@@ -150,41 +150,57 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
prompt=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
uncond_text_encoder_hidden_states = None
if do_classifier_free_guidance:
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask.to(device)
)
prompt_embeds = text_encoder_output.last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if negative_prompt_embeds is None and do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
@@ -215,17 +231,17 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
)
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# done duplicates
return text_encoder_hidden_states, uncond_text_encoder_hidden_states
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
@@ -264,13 +280,15 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt: Optional[Union[str, List[str]]] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 60,
timesteps: List[float] = None,
guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -304,6 +322,13 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `decoder_guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -345,7 +370,13 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
# 2. Encode caption
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# For classifier free guidance, we need to do two forward passes.
@@ -785,8 +785,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
new_model.to(torch_device)
with torch.no_grad():
new_sample = new_model(**inputs_dict).sample
@@ -193,7 +193,7 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
return inputs
def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
if type(device) == str:
if isinstance(device, str):
device = torch.device(device)
generator = torch.Generator(device=device).manual_seed(seed)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -923,9 +923,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipe = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
@@ -35,6 +35,8 @@ from diffusers.utils.testing_utils import (
load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
print_tensor_test,
require_torch_gpu,
slow,
torch_device,
@@ -182,7 +184,7 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "numpy",
"output_type": "np",
}
return inputs
@@ -193,13 +195,17 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
generator_device = "cpu"
inputs = self.get_inputs(generator_device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.84491, 0.90789, 0.75708, 0.78734, 0.83485, 0.70099, 0.66938, 0.68727, 0.61379])
assert np.abs(image_slice - expected_slice).max() < 6e-3
expected_slice = np.array([0.8449, 0.9079, 0.7571, 0.7873, 0.8348, 0.7010, 0.6694, 0.6873, 0.6138])
print_tensor_test(image_slice)
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 1e-4
def test_stable_diffusion_img_variation_intermediate_state(self):
number_of_steps = 0
@@ -212,31 +218,36 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[-0.1621, 0.2837, -0.7979, -0.1221, -1.3057, 0.7681, -2.1191, 0.0464, 1.6309]
)
expected_slice = np.array([-0.7974, -0.4343, -1.087, 0.04785, -1.327, 0.855, -2.148, -0.1725, 1.439])
max_diff = numpy_cosine_similarity_distance(latents_slice.flatten(), expected_slice)
assert max_diff < 1e-3
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array([0.6299, 1.7500, 1.1992, -2.1582, -1.8994, 0.7334, -0.7090, 1.0137, 1.5273])
expected_slice = np.array([0.3232, 0.004883, 0.913, -1.084, 0.6143, -1.6875, -2.463, -0.439, -0.419])
max_diff = numpy_cosine_similarity_distance(latents_slice.flatten(), expected_slice)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
assert max_diff < 1e-3
callback_fn.has_been_called = False
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
"fusing/sd-image-variations-diffusers",
"lambdalabs/sd-image-variations-diffusers",
safety_checker=None,
torch_dtype=torch.float16,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
inputs = self.get_inputs(torch_device, dtype=torch.float16)
generator_device = "cpu"
inputs = self.get_inputs(generator_device, dtype=torch.float16)
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == inputs["num_inference_steps"]
@@ -246,9 +257,8 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
model_id = "fusing/sd-image-variations-diffusers"
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
model_id, safety_checker=None, torch_dtype=torch.float16
"lambdalabs/sd-image-variations-diffusers", safety_checker=None, torch_dtype=torch.float16
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -109,7 +109,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return inputs
def get_fixed_latents(self, device, seed=0):
if type(device) == str:
if isinstance(device, str):
device = torch.device(device)
generator = torch.Generator(device=device).manual_seed(seed)
# Hardcode the shapes for now.
@@ -545,7 +545,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
return inputs
def get_fixed_latents(self, device, seed=0):
if type(device) == str:
if isinstance(device, str):
device = torch.device(device)
latent_device = torch.device("cpu")
generator = torch.Generator(device=latent_device).manual_seed(seed)
@@ -648,7 +648,7 @@ class UniDiffuserPipelineNightlyTests(unittest.TestCase):
return inputs
def get_fixed_latents(self, device, seed=0):
if type(device) == str:
if isinstance(device, str):
device = torch.device(device)
latent_device = torch.device("cpu")
generator = torch.Generator(device=latent_device).manual_seed(seed)