Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2461933857 | |||
| 325f6c53ed | |||
| 43979c2890 | |||
| 9ea6ac1b07 | |||
| 2c34c7d6dd | |||
| bffadde126 | |||
| 11190ed09a | |||
| 35a969d297 | |||
| c5ff469d0e | |||
| bcecfbc873 | |||
| 6269045c5b | |||
| 6ca9c4af05 | |||
| 0532cece97 | |||
| 22b45304bf |
@@ -1,6 +1,6 @@
|
||||
diffusers==0.20.1
|
||||
accelerate==0.23.0
|
||||
transformers==4.34.0
|
||||
transformers==4.36.0
|
||||
peft==0.5.0
|
||||
torch==2.0.1
|
||||
torchvision>=0.16
|
||||
|
||||
@@ -22,7 +22,6 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -436,22 +435,6 @@ DATASET_NAME_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
|
||||
"""
|
||||
Returns:
|
||||
a state dict containing just the attention processor parameters.
|
||||
"""
|
||||
attn_processors = unet.attn_processors
|
||||
|
||||
attn_processors_state_dict = {}
|
||||
|
||||
for attn_processor_key, attn_processor in attn_processors.items():
|
||||
for parameter_key, parameter in attn_processor.state_dict().items():
|
||||
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
|
||||
|
||||
return attn_processors_state_dict
|
||||
|
||||
|
||||
def tokenize_prompt(tokenizer, prompt):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
@@ -640,6 +623,17 @@ def main(args):
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1187,6 +1181,9 @@ def main(args):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Final inference
|
||||
# Make sure vae.dtype is consistent with the unet.dtype
|
||||
if args.mixed_precision == "fp16":
|
||||
vae.to(weight_dtype)
|
||||
# Load previous pipeline
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
|
||||
@@ -32,7 +32,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
|
||||
@@ -1,703 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_condition import UNet2DConditionModel, UNetMidBlock2DCrossAttn
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlockMotion,
|
||||
DownBlockMotion,
|
||||
get_down_block,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseControlNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`ControlNetModel`].
|
||||
|
||||
Args:
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_embedding_channels: int,
|
||||
conditioning_channels: int = 3,
|
||||
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
self.blocks = nn.ModuleList([])
|
||||
|
||||
for i in range(len(block_out_channels) - 1):
|
||||
channel_in = block_out_channels[i]
|
||||
channel_out = block_out_channels[i + 1]
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
||||
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
||||
|
||||
self.conv_out = zero_module(
|
||||
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
def forward(self, conditioning):
|
||||
batch_size, channels, num_frames, height, width = conditioning.shape
|
||||
conditioning = conditioning.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
embedding = self.conv_in(conditioning)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
for block in self.blocks:
|
||||
embedding = block(embedding)
|
||||
embedding = F.silu(embedding)
|
||||
|
||||
embedding = self.conv_out(embedding)
|
||||
embedding = embedding.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
class SparseControlNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concate_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
set_noisy_sample_input_to_zero: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
if concate_conditioning_mask:
|
||||
conditioning_channels = conditioning_channels + 1
|
||||
|
||||
self.concate_conditioning_mask = concate_conditioning_mask
|
||||
self.controlnet_cond_embedding = nn.Conv2d(
|
||||
conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
dual_cross_attention=False,
|
||||
temporal_num_attention_heads=motion_num_attention_heads,
|
||||
temporal_max_seq_length=motion_max_seq_length,
|
||||
temporal_double_self_attention=False,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 3,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||
mid_block_type=unet.config.mid_block_type,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.FloatTensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
conditioning_mask: Optional[torch.FloatTensor] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||
"""
|
||||
The [`ControlNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`torch.FloatTensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape
|
||||
sample = torch.zeros_like(sample).to(sample.device)
|
||||
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
t_emb = self.time_proj(timesteps)
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
# 2. pre-process
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0)
|
||||
|
||||
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
batch_frames, channels, height, width = sample.shape
|
||||
sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width)
|
||||
|
||||
if self.concate_conditioning_mask:
|
||||
controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
|
||||
|
||||
batch_size, channels, num_frames, height, width = controlnet_cond.shape
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape(
|
||||
batch_size * num_frames, channels, height, width
|
||||
)
|
||||
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
||||
batch_frames, channels, height, width = controlnet_cond.shape
|
||||
controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width)
|
||||
|
||||
sample = sample + controlnet_cond
|
||||
|
||||
batch_size, num_frames, channels, height, width = sample.shape
|
||||
sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
num_frames=sample_num_frames,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=sample_num_frames)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return SparseControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
@@ -0,0 +1,318 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .upsampling import upfirdn2d_native
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""A 1D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirDownsample2D(nn.Module):
|
||||
"""A 2D FIR downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
|
||||
if self.use_conv:
|
||||
_, _, convH, convW = weight.shape
|
||||
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
||||
stride_value = [factor, factor]
|
||||
upfirdn_input = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
def downsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`)
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
+15
-699
@@ -23,562 +23,23 @@ import torch.nn.functional as F
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .downsampling import ( # noqa
|
||||
Downsample1D,
|
||||
Downsample2D,
|
||||
FirDownsample2D,
|
||||
KDownsample2D,
|
||||
downsample_2d,
|
||||
)
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .normalization import AdaGroupNorm
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
|
||||
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
outputs = self.conv(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""A 1D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""A 2D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
output_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirUpsample2D(nn.Module):
|
||||
"""A 2D FIR upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`, optional):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _upsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*): Integer upsampling factor (default: 2).
|
||||
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
||||
datatype as `hidden_states`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
|
||||
# Setup filter kernel.
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
|
||||
if self.use_conv:
|
||||
convH = weight.shape[2]
|
||||
convW = weight.shape[3]
|
||||
inC = weight.shape[1]
|
||||
|
||||
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
output_shape = (
|
||||
(hidden_states.shape[2] - 1) * factor + convH,
|
||||
(hidden_states.shape[3] - 1) * factor + convW,
|
||||
)
|
||||
output_padding = (
|
||||
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = hidden_states.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
inverse_conv = F.conv_transpose2d(
|
||||
hidden_states,
|
||||
weight,
|
||||
stride=stride,
|
||||
output_padding=output_padding,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
output = upfirdn2d_native(
|
||||
inverse_conv,
|
||||
torch.tensor(kernel, device=inverse_conv.device),
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
||||
)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return height
|
||||
|
||||
|
||||
class FirDownsample2D(nn.Module):
|
||||
"""A 2D FIR downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
|
||||
if self.use_conv:
|
||||
_, _, convH, convW = weight.shape
|
||||
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
||||
stride_value = [factor, factor]
|
||||
upfirdn_input = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
r"""A 2D K-upsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
from .upsampling import ( # noqa
|
||||
FirUpsample2D,
|
||||
KUpsample2D,
|
||||
Upsample1D,
|
||||
Upsample2D,
|
||||
upfirdn2d_native,
|
||||
upsample_2d,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
@@ -894,151 +355,6 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
return out + self.residual_conv(inputs)
|
||||
|
||||
|
||||
def upsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
||||
a: multiple of the upsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer upsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def downsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`)
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
tensor: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
up: int = 1,
|
||||
down: int = 1,
|
||||
pad: Tuple[int, int] = (0, 0),
|
||||
) -> torch.Tensor:
|
||||
up_x = up_y = up
|
||||
down_x = down_y = down
|
||||
pad_x0 = pad_y0 = pad[0]
|
||||
pad_x1 = pad_y1 = pad[1]
|
||||
|
||||
_, channel, in_h, in_w = tensor.shape
|
||||
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = tensor.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out.to(tensor.device) # Move back to mps if necessary
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
class TemporalConvLayer(nn.Module):
|
||||
"""
|
||||
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
||||
|
||||
@@ -54,7 +54,6 @@ def get_down_block(
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
temporal_num_attention_heads: int = 8,
|
||||
temporal_double_self_attention=True,
|
||||
temporal_max_seq_length: int = 32,
|
||||
transformer_layers_per_block: int = 1,
|
||||
) -> Union[
|
||||
@@ -113,7 +112,6 @@ def get_down_block(
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
temporal_num_attention_heads=temporal_num_attention_heads,
|
||||
temporal_max_seq_length=temporal_max_seq_length,
|
||||
temporal_double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlockMotion":
|
||||
if cross_attention_dim is None:
|
||||
@@ -137,7 +135,6 @@ def get_down_block(
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
temporal_num_attention_heads=temporal_num_attention_heads,
|
||||
temporal_max_seq_length=temporal_max_seq_length,
|
||||
temporal_double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
elif down_block_type == "DownBlockSpatioTemporal":
|
||||
# added for SDV
|
||||
@@ -949,7 +946,6 @@ class DownBlockMotion(nn.Module):
|
||||
temporal_num_attention_heads: int = 1,
|
||||
temporal_cross_attention_dim: Optional[int] = None,
|
||||
temporal_max_seq_length: int = 32,
|
||||
temporal_double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -982,7 +978,6 @@ class DownBlockMotion(nn.Module):
|
||||
positional_embeddings="sinusoidal",
|
||||
num_positional_embeddings=temporal_max_seq_length,
|
||||
attention_head_dim=out_channels // temporal_num_attention_heads,
|
||||
double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1085,7 +1080,6 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
temporal_cross_attention_dim: Optional[int] = None,
|
||||
temporal_num_attention_heads: int = 8,
|
||||
temporal_max_seq_length: int = 32,
|
||||
temporal_double_self_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -1150,7 +1144,6 @@ class CrossAttnDownBlockMotion(nn.Module):
|
||||
positional_embeddings="sinusoidal",
|
||||
num_positional_embeddings=temporal_max_seq_length,
|
||||
attention_head_dim=out_channels // temporal_num_attention_heads,
|
||||
double_self_attention=temporal_double_self_attention,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,426 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
|
||||
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
outputs = self.conv(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""A 2D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
output_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirUpsample2D(nn.Module):
|
||||
"""A 2D FIR upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`, optional):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _upsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*): Integer upsampling factor (default: 2).
|
||||
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
||||
datatype as `hidden_states`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
|
||||
# Setup filter kernel.
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
|
||||
if self.use_conv:
|
||||
convH = weight.shape[2]
|
||||
convW = weight.shape[3]
|
||||
inC = weight.shape[1]
|
||||
|
||||
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
output_shape = (
|
||||
(hidden_states.shape[2] - 1) * factor + convH,
|
||||
(hidden_states.shape[3] - 1) * factor + convW,
|
||||
)
|
||||
output_padding = (
|
||||
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = hidden_states.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
inverse_conv = F.conv_transpose2d(
|
||||
hidden_states,
|
||||
weight,
|
||||
stride=stride,
|
||||
output_padding=output_padding,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
output = upfirdn2d_native(
|
||||
inverse_conv,
|
||||
torch.tensor(kernel, device=inverse_conv.device),
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
||||
)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return height
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
r"""A 2D K-upsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
tensor: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
up: int = 1,
|
||||
down: int = 1,
|
||||
pad: Tuple[int, int] = (0, 0),
|
||||
) -> torch.Tensor:
|
||||
up_x = up_y = up
|
||||
down_x = down_y = down
|
||||
pad_x0 = pad_y0 = pad[0]
|
||||
pad_x1 = pad_y1 = pad[1]
|
||||
|
||||
_, channel, in_h, in_w = tensor.shape
|
||||
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = tensor.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out.to(tensor.device) # Move back to mps if necessary
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
def upsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
||||
a: multiple of the upsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer upsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
@@ -179,12 +179,7 @@ else:
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"CLIPImageProjection",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
@@ -193,13 +188,18 @@ else:
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
_import_structure["stable_diffusion_gligen"] = [
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
]
|
||||
_import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
|
||||
_import_structure["stable_diffusion_xl"].extend(
|
||||
[
|
||||
@@ -209,6 +209,7 @@ else:
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
_import_structure["t2i_adapter"] = [
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionXLAdapterPipeline",
|
||||
@@ -268,7 +269,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
_import_structure["stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -420,11 +421,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CLIPImageProjection,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
@@ -433,12 +430,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
@@ -498,7 +498,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
|
||||
# convert to new format
|
||||
output_dict = {}
|
||||
for module_name, params in conv_state_dict.items():
|
||||
if type(params) is not torch.Tensor:
|
||||
continue
|
||||
output_dict.update({f"unet.{module_name}": params})
|
||||
|
||||
save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
|
||||
@@ -1,51 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--use_motion_mid_block", action="store_true")
|
||||
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
adapter = MotionAdapter(
|
||||
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
|
||||
)
|
||||
# skip loading position embeddings
|
||||
adapter.load_state_dict(conv_state_dict, strict=False)
|
||||
adapter.save_pretrained(args.output_path)
|
||||
adapter.save_pretrained(args.output_path, variant="fp16")
|
||||
@@ -1,49 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models import SparseControlNetModel
|
||||
|
||||
|
||||
def convert_sparse_cntrl_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_sparse_cntrl_module(state_dict)
|
||||
controlnet = SparseControlNetModel()
|
||||
|
||||
# skip loading position embeddings
|
||||
controlnet.load_state_dict(conv_state_dict, strict=False)
|
||||
controlnet.save_pretrained(args.output_path)
|
||||
controlnet.save_pretrained(args.output_path, variant="fp16")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -44,7 +44,6 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
|
||||
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
|
||||
@@ -67,37 +66,19 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
)
|
||||
|
||||
_dummy_objects.update(
|
||||
{
|
||||
"StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
|
||||
"StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline,
|
||||
"StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
|
||||
}
|
||||
)
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]
|
||||
try:
|
||||
if not (
|
||||
is_torch_available()
|
||||
and is_transformers_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import (
|
||||
dummy_torch_and_transformers_and_k_diffusion_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -139,13 +120,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionPipelineOutput,
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from .pipeline_stable_diffusion_attend_and_excite import (
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
|
||||
from .pipeline_stable_diffusion_gligen_text_image import (
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
from .pipeline_stable_diffusion_instruct_pix2pix import (
|
||||
@@ -156,7 +130,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
|
||||
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
|
||||
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
||||
from .pipeline_stable_unclip import StableUnCLIPPipeline
|
||||
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
|
||||
@@ -181,29 +154,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
)
|
||||
else:
|
||||
from .pipeline_stable_diffusion_depth2img import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
|
||||
try:
|
||||
if not (
|
||||
is_torch_available()
|
||||
and is_transformers_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_k_diffusion import (
|
||||
StableDiffusionKDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_onnx_available()):
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+2
-2
@@ -37,8 +37,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+2
-2
@@ -40,8 +40,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
|
||||
from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+2
-2
@@ -36,8 +36,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
+3
-3
@@ -35,9 +35,9 @@ from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .clip_image_project_model import CLIPImageProjection
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -0,0 +1,60 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (
|
||||
is_transformers_available()
|
||||
and is_torch_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (
|
||||
is_transformers_available()
|
||||
and is_torch_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+1
-1
@@ -27,7 +27,7 @@ from ...schedulers import LMSDiscreteScheduler
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+2
-2
@@ -34,8 +34,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
@@ -24,6 +25,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
@@ -1983,10 +1985,26 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
fused_te_2_state_dict = pipe.text_encoder_2.state_dict()
|
||||
unet_state_dict = pipe.unet.state_dict()
|
||||
|
||||
peft_ge_070 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0")
|
||||
|
||||
def remap_key(key, sd):
|
||||
# some keys have moved around for PEFT >= 0.7.0, but they should still be loaded correctly
|
||||
if (key in sd) or (not peft_ge_070):
|
||||
return key
|
||||
|
||||
# instead of linear.weight, we now have linear.base_layer.weight, etc.
|
||||
if key.endswith(".weight"):
|
||||
key = key[:-7] + ".base_layer.weight"
|
||||
elif key.endswith(".bias"):
|
||||
key = key[:-5] + ".base_layer.bias"
|
||||
return key
|
||||
|
||||
for key, value in text_encoder_1_sd.items():
|
||||
key = remap_key(key, fused_te_state_dict)
|
||||
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
|
||||
|
||||
for key, value in text_encoder_2_sd.items():
|
||||
key = remap_key(key, fused_te_2_state_dict)
|
||||
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
|
||||
|
||||
for key, value in unet_state_dict.items():
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import OnnxStableDiffusionInpaintPipelineLegacy
|
||||
from diffusers.utils.testing_utils import (
|
||||
is_onnx_available,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_onnxruntime,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
|
||||
if is_onnx_available():
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
@nightly
|
||||
@require_onnxruntime
|
||||
@require_torch_gpu
|
||||
class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||
@property
|
||||
def gpu_provider(self):
|
||||
return (
|
||||
"CUDAExecutionProvider",
|
||||
{
|
||||
"gpu_mem_limit": "15000000000", # 15GB
|
||||
"arena_extend_strategy": "kSameAsRequested",
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def gpu_options(self):
|
||||
options = ort.SessionOptions()
|
||||
options.enable_mem_pattern = False
|
||||
return options
|
||||
|
||||
def test_inference(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/red_cat_sitting_on_a_park_bench_onnx.npy"
|
||||
)
|
||||
|
||||
# using the PNDM scheduler by default
|
||||
pipe = OnnxStableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
revision="onnx",
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
provider=self.gpu_provider,
|
||||
sess_options=self.gpu_options,
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A red cat sitting on a park bench"
|
||||
|
||||
generator = np.random.RandomState(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.75,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=15,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-2
|
||||
Reference in New Issue
Block a user