Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick von Platen a99b0eb517 Update pr_test_peft_backend.yml to use 1 process for testing 2024-01-17 15:23:53 +02:00
24 changed files with 428 additions and 1891 deletions
@@ -33,9 +33,6 @@ model = AutoencoderKL.from_single_file(url)
## AutoencoderKL
[[autodoc]] AutoencoderKL
- decode
- encode
- all
## AutoencoderKLOutput
@@ -1279,7 +1279,7 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1288,7 +1288,7 @@ def main(args):
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1725,19 +1725,19 @@ def main(args):
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
# flag used for textual inversion
pivoted = False
for epoch in range(first_epoch, args.num_train_epochs):
# if performing any kind of optimization of text_encoder params
if args.train_text_encoder or args.train_text_encoder_ti:
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
# this flag is used to reset the optimizer to optimize only on unet params
pivoted = True
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
else:
# still optimizing the text encoder
# still optimizng the text encoder
text_encoder_one.train()
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
@@ -1747,12 +1747,6 @@ def main(args):
unet.train()
for step, batch in enumerate(train_dataloader):
if pivoted:
# stopping optimization of text_encoder params
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
with accelerator.accumulate(unet):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
@@ -1891,7 +1885,8 @@ def main(args):
# every step, we reset the embeddings to the original embeddings.
if args.train_text_encoder_ti:
embedding_handler.retract_embeddings()
for idx, text_encoder in enumerate(text_encoders):
embedding_handler.retract_embeddings()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
+1 -28
View File
@@ -58,8 +58,6 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -2992,7 +2990,7 @@ pipe = DiffusionPipeline.from_pretrained(
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
)
pipe.enable_vae_slicing()
@@ -3411,32 +3409,7 @@ images = pipe(
pipe.disable_style_aligned()
```
### AnimateDiff Image-To-Video Pipeline
This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.
```py
import torch
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
image = load_image("snail.png")
output = pipe(
image=image,
prompt="A snail moving on the ground",
strength=0.8,
latent_interpolation_method="slerp", # can be lerp, slerp, or your own callback
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
### IP Adapter Face ID
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
@@ -1,989 +0,0 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
>>> from diffusers.utils import export_to_gif, load_image
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
>>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
>>> image = load_image("snail.png")
>>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
>>> frames = output.frames[0]
>>> export_to_gif(frames, "animation.gif")
```
"""
def lerp(
v0: torch.Tensor,
v1: torch.Tensor,
t: Union[float, torch.Tensor],
) -> torch.Tensor:
r"""
Linear Interpolation between two tensors.
Args:
v0 (`torch.Tensor`): First tensor.
v1 (`torch.Tensor`): Second tensor.
t: (`float` or `torch.Tensor`): Interpolation factor.
"""
t_is_float = False
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
if isinstance(t, torch.Tensor):
t = t.cpu().numpy()
else:
t_is_float = True
t = np.array([t], dtype=v0.dtype)
t = t[..., None]
v0 = v0[None, ...]
v1 = v1[None, ...]
v2 = (1 - t) * v0 + t * v1
if t_is_float and v0.ndim > 1:
assert v2.shape[0] == 1
v2 = np.squeeze(v2, axis=0)
v2 = torch.from_numpy(v2).to(input_device)
return v2
def slerp(
v0: torch.Tensor,
v1: torch.Tensor,
t: Union[float, torch.Tensor],
DOT_THRESHOLD: float = 0.9995,
) -> torch.Tensor:
r"""
Spherical Linear Interpolation between two tensors.
Args:
v0 (`torch.Tensor`): First tensor.
v1 (`torch.Tensor`): Second tensor.
t: (`float` or `torch.Tensor`): Interpolation factor.
DOT_THRESHOLD (`float`):
Dot product threshold exceeding which linear interpolation will be used
because input tensors are close to parallel.
"""
t_is_float = False
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
if isinstance(t, torch.Tensor):
t = t.cpu().numpy()
else:
t_is_float = True
t = np.array([t], dtype=v0.dtype)
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
# v0 and v1 are close to parallel, so use linear interpolation instead
v2 = lerp(v0, v1, t)
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
s0 = s0[..., None]
s1 = s1[..., None]
v0 = v0[None, ...]
v1 = v1[None, ...]
v2 = s0 * v0 + s1 * v1
if t_is_float and v0.ndim > 1:
assert v2.shape[0] == 1
v2 = np.squeeze(v2, axis=0)
v2 = torch.from_numpy(v2).to(input_device)
return v2
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
@dataclass
class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`CLIPTokenizer`):
A [`~transformers.CLIPTokenizer`] to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
motion_adapter ([`MotionAdapter`]):
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
motion_adapter=motion_adapter,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = (
image[None, :]
.reshape(
(
batch_size,
num_frames,
-1,
)
+ image.shape[2:]
)
.permute(0, 2, 1, 3, 4)
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
latent_interpolation_method=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if latent_interpolation_method is not None:
if latent_interpolation_method not in ["lerp", "slerp"] and not isinstance(
latent_interpolation_method, FunctionType
):
raise ValueError(
"`latent_interpolation_method` must be one of `lerp`, `slerp` or a Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]"
)
def prepare_latents(
self,
image,
strength,
batch_size,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
latents=None,
latent_interpolation_method="slerp",
):
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
latents = image
else:
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
init_latents = init_latents.to(dtype)
init_latents = self.vae.config.scaling_factor * init_latents
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = latents * self.scheduler.init_noise_sigma
if latent_interpolation_method == "lerp":
def latent_cls(v0, v1, index):
return lerp(v0, v1, index / num_frames * (1 - strength))
elif latent_interpolation_method == "slerp":
def latent_cls(v0, v1, index):
return slerp(v0, v1, index / num_frames * (1 - strength))
else:
latent_cls = latent_interpolation_method
for i in range(num_frames):
latents[:, :, i, :, :] = latent_cls(latents[:, :, i, :, :], init_latents, i)
else:
if shape != latents.shape:
# [B, C, F, H, W]
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
latents = latents.to(device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(
self,
image: PipelineImageInput,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 16,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 7.5,
strength: float = 0.8,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
latent_interpolation_method: Union[str, Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]] = "slerp",
):
r"""
The call function to the pipeline for generation.
Args:
image (`PipelineImageInput`):
The input image to condition the generation on.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_frames (`int`, *optional*, defaults to 16):
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
amounts to 2 seconds of video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original image and generated video.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AnimateDiffImgToVideoPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
latent_interpolation_method (`str` or `Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]`, *optional*):
Must be one of "lerp", "slerp" or a callable that takes in a random noisy latent, image latent and a frame index
as input and returns an initial latent for sampling.
Examples:
Returns:
[`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_steps=callback_steps,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
latent_interpolation_method=latent_interpolation_method,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
)
if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Preprocess image
image = self.image_processor.preprocess(image, height=height, width=width)
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
image=image,
strength=strength,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
num_frames=num_frames,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
latent_interpolation_method=latent_interpolation_method,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
return AnimateDiffImgToVideoPipelineOutput(frames=latents)
# 10. Post-processing
video_tensor = self.decode_latents(latents)
if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
# 11. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AnimateDiffImgToVideoPipelineOutput(frames=video)
@@ -55,9 +55,6 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0")
@@ -70,57 +67,6 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
def log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
edited_images = []
# Run inference
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
# Save validation images
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
logger_name = "test" if is_final_validation else "validation"
tracker.log({logger_name: wandb_table})
def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
@@ -501,6 +447,11 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1160,6 +1111,11 @@ def main():
### BEGIN: Perform validation every `validation_epochs` steps
if global_step % args.validation_steps == 0:
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1179,16 +1135,44 @@ def main():
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
)
# run inference
# Save validation images
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
):
edited_images = []
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1203,6 +1187,7 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
@@ -1213,11 +1198,10 @@ def main():
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
vae=vae,
unet=unwrap_model(unet),
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
@@ -1228,15 +1212,30 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=True,
)
if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})
accelerator.end_training()
-60
View File
@@ -183,66 +183,6 @@ The above command will also run inference as fine-tuning progresses and log the
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
### Using DeepSpeed
Using DeepSpeed one can reduce the consumption of GPU memory, enabling the training of models on GPUs with smaller memory sizes. DeepSpeed is capable of offloading model parameters to the machine's memory, or it can distribute parameters, gradients, and optimizer states across multiple GPUs. This allows for the training of larger models under the same hardware configuration.
First, you need to use the `accelerate config` command to choose to use DeepSpeed, or manually use the accelerate config file to set up DeepSpeed.
Here is an example of a config file for using DeepSpeed. For more detailed explanations of the configuration, you can refer to this [link](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
```yaml
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
You need to save the mentioned configuration as an `accelerate_config.yaml` file. Then, you need to input the path of your `accelerate_config.yaml` file into the `ACCELERATE_CONFIG_FILE` parameter. This way you can use DeepSpeed to train your SDXL model in LoRA. Additionally, you can use DeepSpeed to train other SD models in this way.
```shell
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export ACCELERATE_CONFIG_FILE="your accelerate_config.yaml"
accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_model_name_or_path=$VAE_NAME \
--dataset_name=$DATASET_NAME --caption_column="text" \
--resolution=1024 \
--train_batch_size=1 \
--num_train_epochs=2 \
--checkpointing_steps=2 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--max_train_steps=20 \
--validation_epochs=20 \
--seed=1234 \
--output_dir="sd-pokemon-model-lora-sdxl" \
--validation_prompt="cute dragon creature"
```
### Finetuning the text encoder and UNet
The script also allows you to finetune the `text_encoder` along with the `unet`.
@@ -652,13 +652,13 @@ def main(args):
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(unwrap_model(model), type(unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
@@ -666,8 +666,7 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
weights.pop()
StableDiffusionXLPipeline.save_lora_weights(
output_dir,
-75
View File
@@ -249,81 +249,6 @@ def get_down_block(
raise ValueError(f"{down_block_type} does not exist.")
def get_mid_block(
mid_block_type: str,
temb_channels: int,
in_channels: int,
resnet_eps: float,
resnet_act_fn: str,
resnet_groups: int,
output_scale_factor: float = 1.0,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
mid_block_only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
cross_attention_norm: Optional[str] = None,
attention_head_dim: Optional[int] = 1,
dropout: float = 0.0,
):
if mid_block_type == "UNetMidBlock2DCrossAttn":
return UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
temb_channels=temb_channels,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=output_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
resnet_groups=resnet_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
return UNetMidBlock2DSimpleCrossAttn(
in_channels=in_channels,
temb_channels=temb_channels,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=output_scale_factor,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type == "UNetMidBlock2D":
return UNetMidBlock2D(
in_channels=in_channels,
temb_channels=temb_channels,
dropout=dropout,
num_layers=0,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=output_scale_factor,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
elif mid_block_type is None:
return None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
def get_up_block(
up_block_type: str,
num_layers: int,
+308 -402
View File
@@ -44,8 +44,10 @@ from .embeddings import (
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
get_down_block,
get_mid_block,
get_up_block,
)
@@ -237,18 +239,44 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
self._check_config(
down_block_types=down_block_types,
up_block_types=up_block_types,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
)
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_padding = (conv_in_kernel - 1) // 2
@@ -257,13 +285,23 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
)
# time
time_embed_dim, timestep_input_dim = self._set_time_proj(
time_embedding_type,
block_out_channels=block_out_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
time_embedding_dim=time_embedding_dim,
)
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
@@ -273,33 +311,96 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
cond_proj_dim=time_cond_proj_dim,
)
self._set_encoder_hid_proj(
encoder_hid_dim_type,
cross_attention_dim=cross_attention_dim,
encoder_hid_dim=encoder_hid_dim,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
self._set_class_embedding(
class_embed_type,
act_fn=act_fn,
num_class_embeds=num_class_embeds,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
time_embed_dim=time_embed_dim,
timestep_input_dim=timestep_input_dim,
)
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
self._set_add_embedding(
addition_embed_type,
addition_embed_type_num_heads=addition_embed_type_num_heads,
addition_time_embed_dim=addition_time_embed_dim,
cross_attention_dim=cross_attention_dim,
encoder_hid_dim=encoder_hid_dim,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
time_embed_dim=time_embed_dim,
)
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type == "image":
# Kandinsky 2.2
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
if time_embedding_act_fn is None:
self.time_embed_act = None
@@ -309,7 +410,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
# set or unroll configs
if isinstance(only_cross_attention, bool):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = only_cross_attention
@@ -378,28 +478,57 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type,
temb_channels=blocks_time_embed_dim,
in_channels=block_out_channels[-1],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
output_scale_factor=mid_block_scale_factor,
transformer_layers_per_block=transformer_layers_per_block[-1],
num_attention_heads=num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
mid_block_only_cross_attention=mid_block_only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[-1],
dropout=dropout,
)
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type == "UNetMidBlock2D":
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
dropout=dropout,
num_layers=0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
elif mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the images
self.num_upsamplers = 0
@@ -466,7 +595,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = get_activation(act_fn)
else:
self.conv_norm_out = None
self.conv_act = None
@@ -476,206 +607,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
layers_per_block: [int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
reverse_transformer_layers_per_block: bool,
attention_head_dim: int,
num_attention_heads: Optional[Union[int, Tuple[int]]],
):
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
def _set_time_proj(
self,
time_embedding_type: str,
block_out_channels: int,
flip_sin_to_cos: bool,
freq_shift: float,
time_embedding_dim: int,
) -> Tuple[int, int]:
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
return time_embed_dim, timestep_input_dim
def _set_encoder_hid_proj(
self,
encoder_hid_dim_type: Optional[str],
cross_attention_dim: Union[int, Tuple[int]],
encoder_hid_dim: Optional[int],
):
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
def _set_class_embedding(
self,
class_embed_type: Optional[str],
act_fn: str,
num_class_embeds: Optional[int],
projection_class_embeddings_input_dim: Optional[int],
time_embed_dim: int,
timestep_input_dim: int,
):
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
def _set_add_embedding(
self,
addition_embed_type: str,
addition_embed_type_num_heads: int,
addition_time_embed_dim: Optional[int],
flip_sin_to_cos: bool,
freq_shift: float,
cross_attention_dim: Optional[int],
encoder_hid_dim: Optional[int],
projection_class_embeddings_input_dim: Optional[int],
time_embed_dim: int,
):
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type == "image":
# Kandinsky 2.2
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
if attention_type in ["gated", "gated-text-image"]:
positive_len = 768
if isinstance(cross_attention_dim, int):
@@ -909,130 +840,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
def get_time_embed(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
return t_emb
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
class_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
return class_emb
def get_aug_embed(
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
) -> Optional[torch.Tensor]:
aug_emb = None
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image":
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb = self.add_embedding(image_embs, hint)
return aug_emb
def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
# Kadinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
return encoder_hidden_states
def forward(
self,
sample: torch.FloatTensor,
@@ -1145,31 +952,130 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
sample = 2 * sample - 1.0
# 1. time
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
if class_emb is not None:
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
aug_emb = self.get_aug_embed(
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
)
if self.config.addition_embed_type == "image_hint":
aug_emb, hint = aug_emb
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image":
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb, hint = self.add_embedding(image_embs, hint)
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
encoder_hidden_states = self.process_encoder_hidden_states(
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
# Kadinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
# 2. pre-process
sample = self.conv_in(sample)
@@ -683,11 +683,9 @@ class StableDiffusionControlNetInpaintPipeline(
self,
prompt,
image,
mask_image,
height,
width,
callback_steps,
output_type,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -695,7 +693,6 @@ class StableDiffusionControlNetInpaintPipeline(
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -739,19 +736,6 @@ class StableDiffusionControlNetInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
@@ -878,6 +862,7 @@ class StableDiffusionControlNetInpaintPipeline(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
self,
image,
@@ -887,14 +872,10 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt,
device,
dtype,
crops_coords,
resize_mode,
do_classifier_free_guidance=False,
guess_mode=False,
):
image = self.control_image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
).to(dtype=torch.float32)
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -1093,7 +1074,6 @@ class StableDiffusionControlNetInpaintPipeline(
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
@@ -1150,12 +1130,6 @@ class StableDiffusionControlNetInpaintPipeline(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contain information inreleant for inpainging, such as background.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
@@ -1266,11 +1240,9 @@ class StableDiffusionControlNetInpaintPipeline(
self.check_inputs(
prompt,
control_image,
mask_image,
height,
width,
callback_steps,
output_type,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
@@ -1278,7 +1250,6 @@ class StableDiffusionControlNetInpaintPipeline(
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
@@ -1293,14 +1264,6 @@ class StableDiffusionControlNetInpaintPipeline(
else:
batch_size = prompt_embeds.shape[0]
if padding_mask_crop is not None:
height, width = self.image_processor.get_default_height_width(image, height, width)
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
device = self._execution_device
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
@@ -1352,8 +1315,6 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1369,8 +1330,6 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1382,15 +1341,10 @@ class StableDiffusionControlNetInpaintPipeline(
assert False
# 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (mask < 0.5)
_, _, height, width = init_image.shape
@@ -1580,9 +1534,6 @@ class StableDiffusionControlNetInpaintPipeline(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
if padding_mask_crop is not None:
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
# Offload all models
self.maybe_free_model_hooks()
@@ -557,11 +557,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt,
prompt_2,
image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
@@ -572,7 +570,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -635,19 +632,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
@@ -761,14 +745,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt,
device,
dtype,
crops_coords,
resize_mode,
do_classifier_free_guidance=False,
guess_mode=False,
):
image = self.control_image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
).to(dtype=torch.float32)
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -1086,7 +1066,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 0.9999,
num_inference_steps: int = 50,
denoising_start: Optional[float] = None,
@@ -1142,12 +1121,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contain information inreleant for inpainging, such as background.
strength (`float`, *optional*, defaults to 0.9999):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1317,11 +1290,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt,
prompt_2,
control_image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
@@ -1332,7 +1303,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
@@ -1400,18 +1370,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
# 5.1 Prepare init image
if padding_mask_crop is not None:
height, width = self.image_processor.get_default_height_width(image, height, width)
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
# 5.2 Prepare control images
@@ -1424,8 +1383,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1441,8 +1398,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1454,9 +1409,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
raise ValueError(f"{controlnet.__class__} is not supported.")
# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (mask < 0.5)
_, _, height, width = init_image.shape
@@ -1731,9 +1684,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
image = self.image_processor.postprocess(image, output_type=output_type)
if padding_mask_crop is not None:
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
# Offload all models
self.maybe_free_model_hooks()
@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
return objs
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -1317,9 +1317,6 @@ def download_from_original_stable_diffusion_ckpt(
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
original_config = yaml.safe_load(original_config_file)
@@ -642,7 +642,6 @@ class StableDiffusionInpaintPipeline(
width,
strength,
callback_steps,
output_type,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -694,6 +693,11 @@ class StableDiffusionInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if self.unet.config.in_channels != 4:
raise ValueError(
f"The UNet should have 4 input channels for inpainting mask crop, but has"
f" {self.unet.config.in_channels} input channels."
)
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
@@ -703,8 +707,6 @@ class StableDiffusionInpaintPipeline(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
def prepare_latents(
self,
@@ -1164,7 +1166,6 @@ class StableDiffusionInpaintPipeline(
width,
strength,
callback_steps,
output_type,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
@@ -744,19 +744,15 @@ class StableDiffusionXLInpaintPipeline(
self,
prompt,
prompt_2,
image,
mask_image,
height,
width,
strength,
callback_steps,
output_type,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -814,18 +810,6 @@ class StableDiffusionXLInpaintPipeline(
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
def prepare_latents(
self,
@@ -1241,7 +1225,6 @@ class StableDiffusionXLInpaintPipeline(
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 0.9999,
num_inference_steps: int = 50,
timesteps: List[int] = None,
@@ -1304,12 +1287,6 @@ class StableDiffusionXLInpaintPipeline(
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contain information inreleant for inpainging, such as background.
strength (`float`, *optional*, defaults to 0.9999):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1472,19 +1449,15 @@ class StableDiffusionXLInpaintPipeline(
self.check_inputs(
prompt,
prompt_2,
image,
mask_image,
height,
width,
strength,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
@@ -1554,22 +1527,10 @@ class StableDiffusionXLInpaintPipeline(
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
if padding_mask_crop is not None:
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
if masked_image_latents is not None:
masked_image = masked_image_latents
@@ -1830,9 +1791,6 @@ class StableDiffusionXLInpaintPipeline(
image = self.image_processor.postprocess(image, output_type=output_type)
if padding_mask_crop is not None:
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
# Offload all models
self.maybe_free_model_hooks()
@@ -128,9 +128,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
`lambda(t)`.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -168,16 +165,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -215,11 +207,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
raise ValueError(
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
)
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -280,24 +267,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
@@ -851,9 +831,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final
or (self.config.lower_order_final and len(self.timesteps) < 15)
or self.config.final_sigmas_type == "zero"
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
@@ -165,10 +165,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -787,6 +783,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self._step_index = step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
def step(
self,
model_output: torch.FloatTensor,
@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
@@ -122,9 +122,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -153,14 +150,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
solver_type: str = "midpoint",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
if algorithm_type == "dpmsolver":
deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -197,11 +189,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
raise ValueError(
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
)
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
@@ -280,18 +267,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -305,12 +285,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
)
self.register_to_config(lower_order_final=True)
if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
logger.warn(
" `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
)
self.register_to_config(lower_order_final=True)
self.order_list = self.get_order_list(num_inference_steps)
# add an index counter for schedulers that allow duplicated timesteps
+4 -4
View File
@@ -519,10 +519,10 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = Non
def load_hf_numpy(path) -> np.ndarray:
base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
if not path.startswith("http://") and not path.startswith("https://"):
path = os.path.join(base_url, urllib.parse.quote(path))
if not path.startswith("http://") or path.startswith("https://"):
path = Path(
"https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
).as_posix()
return load_numpy(path)
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import os
import random
import tempfile
@@ -1663,11 +1662,6 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase):
@deprecate_after_peft_backend
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
+6 -13
View File
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import importlib
import os
import tempfile
@@ -1206,11 +1205,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"latent_channels": 4,
}
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@slow
@require_torch_gpu
def test_integration_move_lora_cpu(self):
@@ -1440,11 +1434,6 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
"sample_size": 128,
}
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@slow
@require_torch_gpu
@@ -1479,9 +1468,11 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
}
def tearDown(self):
super().tearDown()
import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_dreambooth_old_format(self):
generator = torch.Generator("cpu").manual_seed(0)
@@ -1766,9 +1757,11 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
}
def tearDown(self):
super().tearDown()
import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_sdxl_0_9_lora_one(self):
generator = torch.Generator().manual_seed(0)
@@ -97,7 +97,7 @@ class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase):
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.9506951, 0.9527786, 0.95309967, 0.9511477, 0.952523, 0.9515326, 0.9511933, 0.9480397, 0.94930184]
[0.9389532, 0.9408587, 0.9394901, 0.939082, 0.9402114, 0.9382007, 0.93737566, 0.9346897, 0.9324472]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -32,7 +32,6 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
"euler_at_final": False,
"lambda_min_clipped": -float("inf"),
"variance_type": None,
"final_sigmas_type": "sigma_min",
}
config.update(**kwargs)
@@ -30,7 +30,6 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
"solver_type": "midpoint",
"lambda_min_clipped": -float("inf"),
"variance_type": None,
"final_sigmas_type": "sigma_min",
}
config.update(**kwargs)