Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 29f15673ed | |||
| 1037287e2b | |||
| 6ea95b7a90 | |||
| 0e0db625d0 | |||
| 1f948109b8 | |||
| 37cb819df5 | |||
| f64d52dbca | |||
| 4d897aaff5 | |||
| b1105269b7 | |||
| 5d28d2217f | |||
| 73bf620dec |
@@ -41,7 +41,7 @@ Core library:
|
||||
- Schedulers: @williamberman and @patrickvonplaten
|
||||
- Pipelines: @patrickvonplaten and @sayakpaul
|
||||
- Training examples: @sayakpaul and @patrickvonplaten
|
||||
- Docs: @stevenliu and @yiyixu
|
||||
- Docs: @stevhliu and @yiyixuxu
|
||||
- JAX and MPS: @pcuenca
|
||||
- Audio: @sanchit-gandhi
|
||||
- General functionalities: @patrickvonplaten and @sayakpaul
|
||||
|
||||
@@ -1138,7 +1138,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -701,7 +701,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 10.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1060,7 +1060,9 @@ def main(args):
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
t2iadapter, optimizer, lr_scheduler = accelerator.prepare(t2iadapter, optimizer, lr_scheduler)
|
||||
t2iadapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
t2iadapter, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
@@ -53,7 +53,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.21.0.dev0")
|
||||
check_min_version("0.21.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.21.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.21.0.dev0"
|
||||
__version__ = "0.21.0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
return x_in * self.stds[key] + self.means[key]
|
||||
|
||||
def to_torch(self, x_in):
|
||||
if type(x_in) is dict:
|
||||
if isinstance(x_in, dict):
|
||||
return {k: self.to_torch(v) for k, v in x_in.items()}
|
||||
elif torch.is_tensor(x_in):
|
||||
return x_in.to(self.unet.device)
|
||||
|
||||
+116
-87
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
@@ -27,6 +26,7 @@ import torch
|
||||
from huggingface_hub import hf_hub_download, model_info
|
||||
from torch import nn
|
||||
|
||||
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
@@ -46,7 +46,6 @@ if is_transformers_available():
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, input):
|
||||
# print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
|
||||
if self.lora_scale is None:
|
||||
self.lora_scale = 1.0
|
||||
if self.lora_linear_layer is None:
|
||||
@@ -274,6 +272,11 @@ class UNet2DConditionLoadersMixin:
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
@@ -300,6 +303,7 @@ class UNet2DConditionLoadersMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
@@ -316,6 +320,15 @@ class UNet2DConditionLoadersMixin:
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
@@ -370,6 +383,10 @@ class UNet2DConditionLoadersMixin:
|
||||
# correct keys
|
||||
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
|
||||
|
||||
if network_alphas is not None:
|
||||
network_alphas_keys = list(network_alphas.keys())
|
||||
used_network_alphas_keys = set()
|
||||
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
mapped_network_alphas = {}
|
||||
|
||||
@@ -381,13 +398,13 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
|
||||
if network_alphas is not None:
|
||||
network_alphas_ = copy.deepcopy(network_alphas)
|
||||
for k in network_alphas_:
|
||||
for k in network_alphas_keys:
|
||||
if k.replace(".alpha", "") in key:
|
||||
mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
|
||||
mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
|
||||
used_network_alphas_keys.add(k)
|
||||
|
||||
if not is_network_alphas_none:
|
||||
if len(network_alphas) > 0:
|
||||
if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
|
||||
raise ValueError(
|
||||
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
|
||||
)
|
||||
@@ -411,29 +428,38 @@ class UNet2DConditionLoadersMixin:
|
||||
out_features = attn_processor.out_channels
|
||||
kernel_size = attn_processor.kernel_size
|
||||
|
||||
lora = LoRAConv2dLayer(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
rank=rank,
|
||||
kernel_size=kernel_size,
|
||||
stride=attn_processor.stride,
|
||||
padding=attn_processor.padding,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
lora = LoRAConv2dLayer(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
rank=rank,
|
||||
kernel_size=kernel_size,
|
||||
stride=attn_processor.stride,
|
||||
padding=attn_processor.padding,
|
||||
network_alpha=mapped_network_alphas.get(key),
|
||||
)
|
||||
elif isinstance(attn_processor, LoRACompatibleLinear):
|
||||
lora = LoRALinearLayer(
|
||||
attn_processor.in_features,
|
||||
attn_processor.out_features,
|
||||
rank,
|
||||
mapped_network_alphas.get(key),
|
||||
)
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
lora = LoRALinearLayer(
|
||||
attn_processor.in_features,
|
||||
attn_processor.out_features,
|
||||
rank,
|
||||
mapped_network_alphas.get(key),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
|
||||
|
||||
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
|
||||
lora.load_state_dict(value_dict)
|
||||
lora_layers_list.append((attn_processor, lora))
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(value_dict.values())).device
|
||||
dtype = next(iter(value_dict.values())).dtype
|
||||
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
lora.load_state_dict(value_dict)
|
||||
elif is_custom_diffusion:
|
||||
attn_processors = {}
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
@@ -470,13 +496,12 @@ class UNet2DConditionLoadersMixin:
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
|
||||
|
||||
# set lora layers
|
||||
for target_module, lora_layer in lora_layers_list:
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
|
||||
is_new_lora_format = all(
|
||||
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
||||
@@ -999,13 +1024,18 @@ class LoraLoaderMixin:
|
||||
recurive = is_sequential_cpu_offload
|
||||
remove_hook_from_module(component, recurse=recurive)
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
|
||||
self.load_lora_into_unet(
|
||||
state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
@@ -1065,6 +1095,11 @@ class LoraLoaderMixin:
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
@@ -1305,7 +1340,7 @@ class LoraLoaderMixin:
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet):
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
|
||||
@@ -1318,7 +1353,13 @@ class LoraLoaderMixin:
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
@@ -1343,11 +1384,12 @@ class LoraLoaderMixin:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
warnings.warn(warn_message)
|
||||
|
||||
# load loras into unet
|
||||
unet.load_attn_procs(state_dict, network_alphas=network_alphas)
|
||||
unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
|
||||
def load_lora_into_text_encoder(
|
||||
cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
|
||||
@@ -1364,7 +1406,13 @@ class LoraLoaderMixin:
|
||||
lora_scale (`float`):
|
||||
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
||||
lora layer.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
@@ -1447,6 +1495,7 @@ class LoraLoaderMixin:
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
@@ -1454,12 +1503,23 @@ class LoraLoaderMixin:
|
||||
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
for k, v in text_encoder_lora_state_dict.items()
|
||||
}
|
||||
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
|
||||
if len(load_state_dict_results.unexpected_keys) != 0:
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
@@ -1492,11 +1552,21 @@ class LoraLoaderMixin:
|
||||
rank: Union[Dict[str, int], int] = 4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
    """
    Wrap a linear layer in a fresh `PatchedLoraProjection` and record its LoRA parameters.

    If `model` is already a `PatchedLoraProjection`, its `regular_linear_layer` is re-wrapped
    instead of stacking a second patch on top. The new projection is constructed under
    `init_empty_weights` when the enclosing `low_cpu_mem_usage` is truthy, otherwise under
    `nullcontext`. The parameters of the new `lora_linear_layer` are appended to
    `lora_parameters` in place, and the new projection is returned.
    """
    if isinstance(model, PatchedLoraProjection):
        base_layer = model.regular_linear_layer
    else:
        base_layer = model

    if low_cpu_mem_usage:
        init_context = init_empty_weights
    else:
        init_context = nullcontext

    with init_context():
        patched = PatchedLoraProjection(base_layer, lora_scale, network_alpha, rank, dtype=dtype)

    lora_parameters.extend(patched.lora_linear_layer.parameters())
    return patched
|
||||
|
||||
# First, remove any monkey-patch that might have been applied before
|
||||
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
|
||||
|
||||
@@ -1515,45 +1585,18 @@ class LoraLoaderMixin:
|
||||
else:
|
||||
current_rank = rank
|
||||
|
||||
q_linear_layer = (
|
||||
attn_module.q_proj.regular_linear_layer
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection)
|
||||
else attn_module.q_proj
|
||||
attn_module.q_proj = create_patched_linear_lora(
|
||||
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.q_proj = PatchedLoraProjection(
|
||||
q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
|
||||
attn_module.k_proj = create_patched_linear_lora(
|
||||
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
|
||||
|
||||
k_linear_layer = (
|
||||
attn_module.k_proj.regular_linear_layer
|
||||
if isinstance(attn_module.k_proj, PatchedLoraProjection)
|
||||
else attn_module.k_proj
|
||||
attn_module.v_proj = create_patched_linear_lora(
|
||||
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.k_proj = PatchedLoraProjection(
|
||||
k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
|
||||
attn_module.out_proj = create_patched_linear_lora(
|
||||
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
|
||||
|
||||
v_linear_layer = (
|
||||
attn_module.v_proj.regular_linear_layer
|
||||
if isinstance(attn_module.v_proj, PatchedLoraProjection)
|
||||
else attn_module.v_proj
|
||||
)
|
||||
attn_module.v_proj = PatchedLoraProjection(
|
||||
v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
|
||||
|
||||
out_linear_layer = (
|
||||
attn_module.out_proj.regular_linear_layer
|
||||
if isinstance(attn_module.out_proj, PatchedLoraProjection)
|
||||
else attn_module.out_proj
|
||||
)
|
||||
attn_module.out_proj = PatchedLoraProjection(
|
||||
out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
|
||||
|
||||
if patch_mlp:
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
@@ -1563,25 +1606,12 @@ class LoraLoaderMixin:
|
||||
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
|
||||
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
|
||||
|
||||
fc1_linear_layer = (
|
||||
mlp_module.fc1.regular_linear_layer
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection)
|
||||
else mlp_module.fc1
|
||||
mlp_module.fc1 = create_patched_linear_lora(
|
||||
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
|
||||
)
|
||||
mlp_module.fc1 = PatchedLoraProjection(
|
||||
fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
|
||||
mlp_module.fc2 = create_patched_linear_lora(
|
||||
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
|
||||
|
||||
fc2_linear_layer = (
|
||||
mlp_module.fc2.regular_linear_layer
|
||||
if isinstance(mlp_module.fc2, PatchedLoraProjection)
|
||||
else mlp_module.fc2
|
||||
)
|
||||
mlp_module.fc2 = PatchedLoraProjection(
|
||||
fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
|
||||
)
|
||||
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
|
||||
|
||||
if is_network_alphas_populated and len(network_alphas) > 0:
|
||||
raise ValueError(
|
||||
@@ -2375,8 +2405,7 @@ class FromOriginalVAEMixin:
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
for param_name, param in converted_vae_checkpoint.items():
|
||||
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
|
||||
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
|
||||
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
)
|
||||
|
||||
|
||||
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
    """
    Load the tensors of `state_dict` into `model` one entry at a time via `set_module_tensor_to_device`.

    Args:
        model: Module whose `state_dict()` supplies the set of accepted keys and their expected shapes.
        state_dict: Mapping of parameter name to tensor to load.
        device: Target device handed to `set_module_tensor_to_device`. Falls back to CPU when falsy.
        dtype: Target dtype; only forwarded when the installed `set_module_tensor_to_device` has a
            `dtype` parameter. Falls back to `torch.float32` when falsy.
        model_name_or_path: Optional identifier, used only as a prefix in the shape-mismatch error.

    Returns:
        The list of keys from `state_dict` that are absent from `model.state_dict()`. Those entries
        are skipped rather than loaded.

    Raises:
        ValueError: If a tensor's shape differs from the shape of the same-named entry in the model.
    """
    device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    # Resolved lazily, once, on the first tensor that is actually loaded: the signature check is
    # loop-invariant, but calls that load nothing must not touch `set_module_tensor_to_device`.
    accepts_dtype = None

    unexpected_keys = []
    empty_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue

        expected_shape = empty_state_dict[param_name].shape
        if expected_shape != param.shape:
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            # Report the expected *shape*, not the whole tensor held by the model.
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {expected_shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )

        if accepts_dtype is None:
            accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters

        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
    return unexpected_keys
|
||||
|
||||
|
||||
def _load_state_dict_into_model(model_to_load, state_dict):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
unexpected_keys = []
|
||||
|
||||
empty_state_dict = model.state_dict()
|
||||
for param_name, param in state_dict.items():
|
||||
accepts_dtype = "dtype" in set(
|
||||
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
||||
)
|
||||
|
||||
if param_name not in empty_state_dict:
|
||||
unexpected_keys.append(param_name)
|
||||
continue
|
||||
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
raise ValueError(
|
||||
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
)
|
||||
|
||||
if accepts_dtype:
|
||||
set_module_tensor_to_device(
|
||||
model, param_name, param_device, value=param, dtype=torch_dtype
|
||||
)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
device=param_device,
|
||||
dtype=torch_dtype,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
)
|
||||
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
|
||||
@@ -100,6 +100,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -127,6 +127,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -178,7 +178,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
|
||||
self.scheduler.set_timesteps(steps)
|
||||
step_generator = step_generator or generator
|
||||
# For backwards compatibility
|
||||
if type(self.unet.config.sample_size) == int:
|
||||
if isinstance(self.unet.config.sample_size, int):
|
||||
self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
|
||||
if noise is None:
|
||||
noise = randn_tensor(
|
||||
|
||||
@@ -125,6 +125,7 @@ class StableDiffusionControlNetPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -149,6 +149,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -273,6 +273,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1293,10 +1293,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
|
||||
dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
|
||||
and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
|
||||
method called. Offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
|
||||
@@ -101,6 +101,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
+1
@@ -191,6 +191,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -272,6 +272,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -124,6 +124,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
+1
@@ -182,6 +182,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -66,6 +66,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
# we should give a descriptive message if the pipeline doesn't have one.
|
||||
_optional_components = ["safety_checker"]
|
||||
model_cpu_offload_seq = "image_encoder->unet->vae"
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -129,6 +129,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -194,6 +194,7 @@ class StableDiffusionInpaintPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -117,6 +117,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -91,6 +91,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -82,6 +82,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -114,6 +114,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -68,6 +68,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -80,6 +80,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -95,6 +95,7 @@ class StableDiffusionParadigmsPipeline(
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -315,6 +315,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
"caption_processor",
|
||||
"inverse_scheduler",
|
||||
]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -119,6 +119,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -92,6 +92,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
"""
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["watermarker", "safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -810,7 +810,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -885,7 +885,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
|
||||
# 5. Prepare timesteps
|
||||
def denoising_value_valid(dnv):
|
||||
return type(denoising_end) == float and 0 < dnv < 1
|
||||
return isinstance(denoising_end, float) and 0 < dnv < 1
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
|
||||
# 4. set timesteps
|
||||
def denoising_value_valid(dnv):
|
||||
return type(denoising_end) == float and 0 < dnv < 1
|
||||
return isinstance(denoising_end, float) and 0 < dnv < 1
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
|
||||
+1
-1
@@ -837,7 +837,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -886,7 +886,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7.1 Apply denoising_end
|
||||
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
|
||||
@@ -330,7 +330,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
|
||||
# 2. Encode caption
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
prompt,
|
||||
device,
|
||||
image_embeddings.size(0) * num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
text_encoder_hidden_states = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
|
||||
|
||||
@@ -154,6 +154,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
decoder_timesteps: Optional[List[float]] = None,
|
||||
decoder_guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
@@ -165,10 +167,17 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
The prompt or prompts to guide the image generation for the prior and decoder.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
@@ -221,13 +230,15 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt,
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=prior_num_inference_steps,
|
||||
timesteps=prior_timesteps,
|
||||
guidance_scale=prior_guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
|
||||
@@ -150,41 +150,57 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
uncond_text_encoder_hidden_states = None
|
||||
if do_classifier_free_guidance:
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask.to(device)
|
||||
)
|
||||
prompt_embeds = text_encoder_output.last_hidden_state
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
@@ -215,17 +231,17 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
|
||||
)
|
||||
|
||||
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_text_encoder_hidden_states.shape[1]
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
# done duplicates
|
||||
|
||||
return text_encoder_hidden_states, uncond_text_encoder_hidden_states
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
@@ -264,13 +280,15 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
num_inference_steps: int = 60,
|
||||
timesteps: List[float] = None,
|
||||
guidance_scale: float = 8.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
@@ -304,6 +322,13 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `decoder_guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -345,7 +370,13 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
|
||||
# 2. Encode caption
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -785,8 +785,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict).sample
|
||||
|
||||
@@ -193,7 +193,7 @@ class ConsistencyModelPipelineSlowTests(unittest.TestCase):
|
||||
return inputs
|
||||
|
||||
def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
|
||||
if type(device) == str:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
@@ -923,9 +923,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
|
||||
@@ -35,6 +35,8 @@ from diffusers.utils.testing_utils import (
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
print_tensor_test,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -182,7 +184,7 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -193,13 +195,17 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_inputs(generator_device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.84491, 0.90789, 0.75708, 0.78734, 0.83485, 0.70099, 0.66938, 0.68727, 0.61379])
|
||||
assert np.abs(image_slice - expected_slice).max() < 6e-3
|
||||
expected_slice = np.array([0.8449, 0.9079, 0.7571, 0.7873, 0.8348, 0.7010, 0.6694, 0.6873, 0.6138])
|
||||
print_tensor_test(image_slice)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_stable_diffusion_img_variation_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
@@ -212,31 +218,36 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.1621, 0.2837, -0.7979, -0.1221, -1.3057, 0.7681, -2.1191, 0.0464, 1.6309]
|
||||
)
|
||||
expected_slice = np.array([-0.7974, -0.4343, -1.087, 0.04785, -1.327, 0.855, -2.148, -0.1725, 1.439])
|
||||
max_diff = numpy_cosine_similarity_distance(latents_slice.flatten(), expected_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||
elif step == 2:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.6299, 1.7500, 1.1992, -2.1582, -1.8994, 0.7334, -0.7090, 1.0137, 1.5273])
|
||||
expected_slice = np.array([0.3232, 0.004883, 0.913, -1.084, 0.6143, -1.6875, -2.463, -0.439, -0.419])
|
||||
max_diff = numpy_cosine_similarity_distance(latents_slice.flatten(), expected_slice)
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||
assert max_diff < 1e-3
|
||||
|
||||
callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
||||
"fusing/sd-image-variations-diffusers",
|
||||
"lambdalabs/sd-image-variations-diffusers",
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
|
||||
inputs = self.get_inputs(torch_device, dtype=torch.float16)
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_inputs(generator_device, dtype=torch.float16)
|
||||
pipe(**inputs, callback=callback_fn, callback_steps=1)
|
||||
assert callback_fn.has_been_called
|
||||
assert number_of_steps == inputs["num_inference_steps"]
|
||||
@@ -246,9 +257,8 @@ class StableDiffusionImageVariationPipelineSlowTests(unittest.TestCase):
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
model_id = "fusing/sd-image-variations-diffusers"
|
||||
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
||||
model_id, safety_checker=None, torch_dtype=torch.float16
|
||||
"lambdalabs/sd-image-variations-diffusers", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -109,7 +109,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
return inputs
|
||||
|
||||
def get_fixed_latents(self, device, seed=0):
|
||||
if type(device) == str:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
# Hardcode the shapes for now.
|
||||
@@ -545,7 +545,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
return inputs
|
||||
|
||||
def get_fixed_latents(self, device, seed=0):
|
||||
if type(device) == str:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
latent_device = torch.device("cpu")
|
||||
generator = torch.Generator(device=latent_device).manual_seed(seed)
|
||||
@@ -648,7 +648,7 @@ class UniDiffuserPipelineNightlyTests(unittest.TestCase):
|
||||
return inputs
|
||||
|
||||
def get_fixed_latents(self, device, seed=0):
|
||||
if type(device) == str:
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
latent_device = torch.device("cpu")
|
||||
generator = torch.Generator(device=latent_device).manual_seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user