Compare commits
12 Commits
lora-tests
...
sd-tests
| Author | SHA1 | Date | |
|---|---|---|---|
| f8c53ee022 | |||
| d1272550d6 | |||
| 75001f620e | |||
| fee93c81eb | |||
| 5308cce994 | |||
| 318556b20e | |||
| 6620eda357 | |||
| 1f0705adcf | |||
| 5e96333cb2 | |||
| da95a28ff6 | |||
| d66d554dc2 | |||
| c7df846dec |
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
|
||||
|
||||
## Quickstart
|
||||
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 16000+ checkpoints):
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
- https://github.com/deep-floyd/IF
|
||||
- https://github.com/bentoml/BentoML
|
||||
- https://github.com/bmaltais/kohya_ss
|
||||
- +7000 other amazing GitHub repositories 💪
|
||||
- +8000 other amazing GitHub repositories 💪
|
||||
|
||||
Thank you for using us ❤️.
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta
|
||||
|
||||
## FromOriginalVAEMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalVAEMixin
|
||||
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
|
||||
|
||||
## FromOriginalControlnetMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
|
||||
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNetMotionModel
|
||||
|
||||
## UNet3DConditionOutput
|
||||
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
|
||||
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet1DOutput
|
||||
[[autodoc]] models.unet_1d.UNet1DOutput
|
||||
[[autodoc]] models.unets.unet_1d.UNet1DOutput
|
||||
|
||||
@@ -22,10 +22,10 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet2DConditionModel
|
||||
|
||||
## UNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
|
||||
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
|
||||
|
||||
## FlaxUNet2DConditionModel
|
||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
|
||||
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
|
||||
|
||||
## FlaxUNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
||||
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet2DModel
|
||||
|
||||
## UNet2DOutput
|
||||
[[autodoc]] models.unet_2d.UNet2DOutput
|
||||
[[autodoc]] models.unets.unet_2d.UNet2DOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet3DConditionModel
|
||||
|
||||
## UNet3DConditionOutput
|
||||
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
|
||||
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
|
||||
|
||||
@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unet_motion_model import MotionAdapter
|
||||
from diffusers.models.unets.unet_motion_model import MotionAdapter
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import (
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from diffusers import StableDiffusionControlNetPipeline
|
||||
from diffusers.models import ControlNetModel
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import logging
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
|
||||
@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import (
|
||||
UpBlock2D,
|
||||
Upsample2D,
|
||||
)
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
|
||||
|
||||
@@ -740,6 +740,10 @@ def main(args):
|
||||
# Resize.
|
||||
combined_im = train_resize(combined_im)
|
||||
|
||||
# Flipping.
|
||||
if not args.no_flip and random.random() < 0.5:
|
||||
combined_im = train_flip(combined_im)
|
||||
|
||||
# Cropping.
|
||||
if not args.random_crop:
|
||||
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
|
||||
@@ -749,11 +753,6 @@ def main(args):
|
||||
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
|
||||
combined_im = crop(combined_im, y1, x1, h, w)
|
||||
|
||||
# Flipping.
|
||||
if random.random() < 0.5:
|
||||
x1 = combined_im.shape[2] - x1
|
||||
combined_im = train_flip(combined_im)
|
||||
|
||||
crop_top_left = (y1, x1)
|
||||
crop_top_lefts.append(crop_top_left)
|
||||
combined_im = normalize(combined_im)
|
||||
|
||||
@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import VQModel
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.uvit_2d import UVit2DModel
|
||||
from diffusers.models.unets.uvit_2d import UVit2DModel
|
||||
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
|
||||
from diffusers.schedulers import AmusedScheduler
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from tqdm import tqdm
|
||||
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
|
||||
from diffusers.models.autoencoders.vae import Encoder
|
||||
from diffusers.models.embeddings import TimestepEmbedding
|
||||
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
|
||||
|
||||
args = ArgumentParser()
|
||||
|
||||
@@ -153,6 +153,7 @@ else:
|
||||
"LCMScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"ScoreSdeVeScheduler",
|
||||
"UnCLIPScheduler",
|
||||
@@ -381,7 +382,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
||||
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
|
||||
_import_structure["schedulers"].extend(
|
||||
@@ -530,6 +531,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LCMScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
@@ -709,7 +711,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
from .pipelines import FlaxDiffusionPipeline
|
||||
from .schedulers import (
|
||||
|
||||
@@ -16,7 +16,7 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from ...models.unet_1d import UNet1DModel
|
||||
from ...models.unets.unet_1d import UNet1DModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -54,12 +54,13 @@ if is_transformers_available():
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
|
||||
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
|
||||
|
||||
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"].extend(["FromSingleFileMixin"])
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
@@ -69,7 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
|
||||
from .autoencoder import FromOriginalVAEMixin
|
||||
from .controlnet import FromOriginalControlNetMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
return vae
|
||||
@@ -0,0 +1,127 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_controlnet_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalControlNetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
upcast_attention = kwargs.pop("upcast_attention", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
|
||||
component = create_diffusers_controlnet_model_from_ldm(
|
||||
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
|
||||
)
|
||||
controlnet = component["controlnet"]
|
||||
if torch_dtype is not None:
|
||||
controlnet = controlnet.to(torch_dtype)
|
||||
|
||||
return controlnet
|
||||
@@ -11,39 +11,132 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
create_diffusers_unet_model_from_ldm,
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
create_scheduler_from_ldm,
|
||||
create_text_encoders_and_tokenizers_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
infer_model_type,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
pass
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Pipelines that support the SDXL Refiner checkpoint
|
||||
REFINER_PIPELINES = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
def build_sub_model_components(
|
||||
pipeline_components,
|
||||
pipeline_class_name,
|
||||
component_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
local_files_only=False,
|
||||
load_safety_checker=False,
|
||||
model_type=None,
|
||||
image_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if component_name in pipeline_components:
|
||||
return {}
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size
|
||||
)
|
||||
return vae_components
|
||||
|
||||
if component_name == "scheduler":
|
||||
scheduler_type = kwargs.get("scheduler_type", "ddim")
|
||||
prediction_type = kwargs.get("prediction_type", None)
|
||||
|
||||
scheduler_components = create_scheduler_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
scheduler_type=scheduler_type,
|
||||
prediction_type=prediction_type,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
return scheduler_components
|
||||
|
||||
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
|
||||
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return text_encoder_components
|
||||
|
||||
if component_name == "safety_checker":
|
||||
if load_safety_checker:
|
||||
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
safety_checker = None
|
||||
return {"safety_checker": safety_checker}
|
||||
|
||||
if component_name == "feature_extractor":
|
||||
if load_safety_checker:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
feature_extractor = None
|
||||
return {"feature_extractor": feature_extractor}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
}
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, *args, **kwargs):
|
||||
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
||||
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
||||
return cls.from_single_file(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -58,8 +151,7 @@ class FromSingleFileMixin:
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -85,42 +177,6 @@ class FromSingleFileMixin:
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`):
|
||||
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
||||
higher quality images for inference. Non-EMA weights are usually better for continuing finetuning.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
prediction_type (`str`, *optional*):
|
||||
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
||||
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
||||
num_in_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of input channels. If `None`, it is automatically inferred.
|
||||
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
||||
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
||||
"ddim"]`.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTextModel` to use, specifically the
|
||||
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
|
||||
parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
|
||||
vae (`AutoencoderKL`, *optional*, defaults to `None`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
|
||||
this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
|
||||
of `CLIPTokenizer` by itself if needed.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
|
||||
automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
@@ -143,484 +199,80 @@ class FromSingleFileMixin:
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_files = kwargs.pop("config_files", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
||||
vae = kwargs.pop("vae", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
adapter = kwargs.pop("adapter", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
tokenizer_2 = kwargs.pop("tokenizer_2", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
class_name = cls.__name__
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
model_type = None
|
||||
|
||||
if pipeline_name in [
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
]:
|
||||
from ..models.controlnet import ControlNetModel
|
||||
from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# list/tuple or a single instance of ControlNetModel or MultiControlNetModel
|
||||
if not (
|
||||
isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
|
||||
or isinstance(controlnet, (list, tuple))
|
||||
and isinstance(controlnet[0], ControlNetModel)
|
||||
):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
pass
|
||||
elif pipeline_name == "StableUnCLIPPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "txt2img"
|
||||
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "img2img"
|
||||
elif pipeline_name == "PaintByExamplePipeline":
|
||||
model_type = "PaintByExample"
|
||||
elif pipeline_name == "LDMTextToImagePipeline":
|
||||
model_type = "LDMTextToImage"
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
has_valid_url_prefix = False
|
||||
valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||
for prefix in valid_url_prefixes:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
has_valid_url_prefix = True
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
if not has_valid_url_prefix:
|
||||
raise ValueError(
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
|
||||
)
|
||||
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
controlnet=controlnet,
|
||||
adapter=adapter,
|
||||
from_safetensors=from_safetensors,
|
||||
extract_ema=extract_ema,
|
||||
image_size=image_size,
|
||||
scheduler_type=scheduler_type,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config=None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
model_type = kwargs.pop("model_type", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
|
||||
passed_class_obj.get("safety_checker", None) is not None
|
||||
)
|
||||
|
||||
init_kwargs = {}
|
||||
for name in expected_modules:
|
||||
if name in passed_class_obj:
|
||||
init_kwargs[name] = passed_class_obj[name]
|
||||
else:
|
||||
components = build_sub_model_components(
|
||||
init_kwargs,
|
||||
class_name,
|
||||
name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
image_size=image_size,
|
||||
load_safety_checker=load_safety_checker,
|
||||
local_files_only=local_files_only,
|
||||
**kwargs,
|
||||
)
|
||||
if not components:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
init_kwargs.update(passed_pipe_kwargs)
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
)
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = yaml.safe_load(config_file)
|
||||
|
||||
# default to sd-v1-5
|
||||
image_size = image_size or 512
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if scaling_factor is None:
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config["model"]
|
||||
and "scale_factor" in original_config["model"]["params"]
|
||||
):
|
||||
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
vae.to(dtype=torch_dtype)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
class FromOriginalControlnetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
use_linear_projection = kwargs.pop("use_linear_projection", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
image_size = image_size or 512
|
||||
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
original_config_file=config_file,
|
||||
image_size=image_size,
|
||||
extract_ema=extract_ema,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
from_safetensors=from_safetensors,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
controlnet.to(dtype=torch_dtype)
|
||||
|
||||
return controlnet
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -39,19 +39,19 @@ if is_torch_available():
|
||||
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
|
||||
|
||||
@@ -73,19 +73,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
from .unets import (
|
||||
Kandinsky3UNet,
|
||||
MotionAdapter,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UNet3DConditionModel,
|
||||
UNetMotionModel,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
)
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .unets import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
|
||||
else:
|
||||
|
||||
@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
|
||||
@@ -17,13 +17,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -162,7 +161,7 @@ class TemporalDecoder(nn.Module):
|
||||
return sample
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
@@ -242,7 +241,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -266,7 +265,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -31,7 +31,7 @@ from ..attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_2d import UNet2DModel
|
||||
from ..unets.unet_2d import UNet2DModel
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SpatialNorm
|
||||
from ..unet_2d_blocks import (
|
||||
from ..unets.unet_2d_blocks import (
|
||||
AutoencoderTinyBlock,
|
||||
UNetMidBlock2D,
|
||||
get_down_block,
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..loaders import FromOriginalControlNetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -30,8 +30,14 @@ from .attention_processor import (
|
||||
)
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -102,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
from .unets.unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxDownBlock2D,
|
||||
FlaxUNetMidBlock2DCrossAttn,
|
||||
@@ -329,14 +329,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
|
||||
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
channel_order = self.controlnet_conditioning_channel_order
|
||||
|
||||
@@ -120,7 +120,7 @@ class DualTransformer2DModel(nn.Module):
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
|
||||
@@ -32,6 +32,7 @@ from .. import __version__
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
@@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
else:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
|
||||
@@ -167,7 +167,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -191,7 +191,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -226,7 +226,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -286,7 +286,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -149,7 +149,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -12,244 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_1d import UNet1DModel, UNet1DOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet1DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet1DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
class UNet1DOutput(UNet1DOutput):
|
||||
deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead."
|
||||
deprecate("UNet1DOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
extra_in_channels (`int`, *optional*, defaults to 0):
|
||||
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
||||
input data has more channels than what the model was initially designed for.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
||||
Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
||||
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
||||
downsample_each_block (`int`, *optional*, defaults to `False`):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 65536,
|
||||
sample_rate: Optional[int] = None,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.out_block = None
|
||||
|
||||
# down
|
||||
output_channel = in_channels
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
for downsample_block in self.down_blocks:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet1DOutput(sample=sample)
|
||||
class UNet1DModel(UNet1DModel):
|
||||
deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead."
|
||||
deprecate("UNet1DModel", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,616 +11,112 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .activations import get_activation
|
||||
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
self.final_conv1d_act = get_activation(act_fn)
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
"lanczos3": [
|
||||
0.003689131001010537,
|
||||
0.015056144446134567,
|
||||
-0.03399861603975296,
|
||||
-0.066637322306633,
|
||||
0.13550527393817902,
|
||||
0.44638532400131226,
|
||||
0.44638532400131226,
|
||||
0.13550527393817902,
|
||||
-0.066637322306633,
|
||||
-0.03399861603975296,
|
||||
0.015056144446134567,
|
||||
0.003689131001010537,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv1d(hidden_states, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
|
||||
super().__init__()
|
||||
self.channels = in_channels
|
||||
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
||||
self.num_heads = n_head
|
||||
|
||||
self.query = nn.Linear(self.channels, self.channels)
|
||||
self.key = nn.Linear(self.channels, self.channels)
|
||||
self.value = nn.Linear(self.channels, self.channels)
|
||||
|
||||
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
batch, channel_dim, seq = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
||||
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResConvBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
|
||||
super().__init__()
|
||||
self.is_last = is_last
|
||||
self.has_conv_skip = in_channels != out_channels
|
||||
|
||||
if self.has_conv_skip:
|
||||
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
||||
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
||||
self.gelu_1 = nn.GELU()
|
||||
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
||||
|
||||
if not self.is_last:
|
||||
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
||||
self.gelu_2 = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
||||
|
||||
hidden_states = self.conv_1(hidden_states)
|
||||
hidden_states = self.group_norm_1(hidden_states)
|
||||
hidden_states = self.gelu_1(hidden_states)
|
||||
hidden_states = self.conv_2(hidden_states)
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_1d_blocks import (
|
||||
AttnDownBlock1D,
|
||||
AttnUpBlock1D,
|
||||
DownBlock1D,
|
||||
DownBlock1DNoSkip,
|
||||
DownResnetBlock1D,
|
||||
Downsample1d,
|
||||
MidResTemporalBlock1D,
|
||||
OutConv1DBlock,
|
||||
OutValueFunctionBlock,
|
||||
ResConvBlock,
|
||||
SelfAttention1d,
|
||||
UNetMidBlock1D,
|
||||
UpBlock1D,
|
||||
UpBlock1DNoSkip,
|
||||
UpResnetBlock1D,
|
||||
Upsample1d,
|
||||
ValueFunctionMidBlock1D,
|
||||
)
|
||||
|
||||
if not self.is_last:
|
||||
hidden_states = self.group_norm_2(hidden_states)
|
||||
hidden_states = self.gelu_2(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
class DownResnetBlock1D(DownResnetBlock1D):
|
||||
deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead."
|
||||
deprecate("DownResnetBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
class UpResnetBlock1D(UpResnetBlock1D):
|
||||
deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead."
|
||||
deprecate("UpResnetBlock1D", "0.29", deprecation_message)
|
||||
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D):
|
||||
deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead."
|
||||
deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
class OutConv1DBlock(OutConv1DBlock):
|
||||
deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead."
|
||||
deprecate("OutConv1DBlock", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class OutValueFunctionBlock(OutValueFunctionBlock):
|
||||
deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead."
|
||||
deprecate("OutValueFunctionBlock", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class AttnDownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
class Downsample1d(Downsample1d):
|
||||
deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead."
|
||||
deprecate("Downsample1d", "0.29", deprecation_message)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
class Upsample1d(Upsample1d):
|
||||
deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead."
|
||||
deprecate("Upsample1d", "0.29", deprecation_message)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class AttnUpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
class SelfAttention1d(SelfAttention1d):
|
||||
deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead."
|
||||
deprecate("SelfAttention1d", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class ResConvBlock(ResConvBlock):
|
||||
deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead."
|
||||
deprecate("ResConvBlock", "0.29", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
class UNetMidBlock1D(UNetMidBlock1D):
|
||||
deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead."
|
||||
deprecate("UNetMidBlock1D", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class AttnDownBlock1D(AttnDownBlock1D):
|
||||
deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead."
|
||||
deprecate("AttnDownBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
class DownBlock1D(DownBlock1D):
|
||||
deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead."
|
||||
deprecate("DownBlock1D", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class DownBlock1DNoSkip(DownBlock1DNoSkip):
|
||||
deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead."
|
||||
deprecate("DownBlock1DNoSkip", "0.29", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
class AttnUpBlock1D(AttnUpBlock1D):
|
||||
deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead."
|
||||
deprecate("AttnUpBlock1D", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class UpBlock1D(UpBlock1D):
|
||||
deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead."
|
||||
deprecate("UpBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UpBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
class UpBlock1DNoSkip(UpBlock1DNoSkip):
|
||||
deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead."
|
||||
deprecate("UpBlock1DNoSkip", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
|
||||
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
|
||||
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
|
||||
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
|
||||
class MidResTemporalBlock1D(MidResTemporalBlock1D):
|
||||
deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead."
|
||||
deprecate("MidResTemporalBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
def get_down_block(
|
||||
@@ -630,42 +126,38 @@ def get_down_block(
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
add_downsample: bool,
|
||||
) -> DownBlockType:
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead."
|
||||
deprecate("get_down_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_down_block
|
||||
|
||||
return get_down_block(
|
||||
down_block_type=down_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
|
||||
) -> UpBlockType:
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead."
|
||||
deprecate("get_up_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_up_block
|
||||
|
||||
return get_up_block(
|
||||
up_block_type=up_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
|
||||
|
||||
def get_mid_block(
|
||||
@@ -676,27 +168,36 @@ def get_mid_block(
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
add_downsample: bool,
|
||||
) -> MidBlockType:
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead."
|
||||
deprecate("get_mid_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_mid_block
|
||||
|
||||
return get_mid_block(
|
||||
mid_block_type=mid_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
mid_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
|
||||
|
||||
def get_out_block(
|
||||
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
|
||||
) -> Optional[OutBlockType]:
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
):
|
||||
deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead."
|
||||
deprecate("get_out_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_out_block
|
||||
|
||||
return get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=embed_dim,
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=fc_dim,
|
||||
)
|
||||
|
||||
@@ -11,336 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_2d import UNet2DModel, UNet2DOutput
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
class UNet2DOutput(UNet2DOutput):
|
||||
deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead."
|
||||
deprecate("UNet2DOutput", "0.29", deprecation_message)
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
||||
1)`.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DOutput(sample=sample)
|
||||
class UNet2DModel(UNet2DModel):
|
||||
deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead."
|
||||
deprecate("UNet2DModel", "0.29", deprecation_message)
|
||||
|
||||
+170
-3461
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
from ...utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
@@ -0,0 +1,255 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet1DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet1DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
extra_in_channels (`int`, *optional*, defaults to 0):
|
||||
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
||||
input data has more channels than what the model was initially designed for.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
||||
Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
||||
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
||||
downsample_each_block (`int`, *optional*, defaults to `False`):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 65536,
|
||||
sample_rate: Optional[int] = None,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.out_block = None
|
||||
|
||||
# down
|
||||
output_channel = in_channels
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
for downsample_block in self.down_blocks:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet1DOutput(sample=sample)
|
||||
@@ -0,0 +1,702 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..activations import get_activation
|
||||
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
self.final_conv1d_act = get_activation(act_fn)
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
"lanczos3": [
|
||||
0.003689131001010537,
|
||||
0.015056144446134567,
|
||||
-0.03399861603975296,
|
||||
-0.066637322306633,
|
||||
0.13550527393817902,
|
||||
0.44638532400131226,
|
||||
0.44638532400131226,
|
||||
0.13550527393817902,
|
||||
-0.066637322306633,
|
||||
-0.03399861603975296,
|
||||
0.015056144446134567,
|
||||
0.003689131001010537,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv1d(hidden_states, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
|
||||
super().__init__()
|
||||
self.channels = in_channels
|
||||
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
||||
self.num_heads = n_head
|
||||
|
||||
self.query = nn.Linear(self.channels, self.channels)
|
||||
self.key = nn.Linear(self.channels, self.channels)
|
||||
self.value = nn.Linear(self.channels, self.channels)
|
||||
|
||||
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
batch, channel_dim, seq = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
||||
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResConvBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
|
||||
super().__init__()
|
||||
self.is_last = is_last
|
||||
self.has_conv_skip = in_channels != out_channels
|
||||
|
||||
if self.has_conv_skip:
|
||||
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
||||
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
||||
self.gelu_1 = nn.GELU()
|
||||
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
||||
|
||||
if not self.is_last:
|
||||
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
||||
self.gelu_2 = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
||||
|
||||
hidden_states = self.conv_1(hidden_states)
|
||||
hidden_states = self.group_norm_1(hidden_states)
|
||||
hidden_states = self.gelu_1(hidden_states)
|
||||
hidden_states = self.conv_2(hidden_states)
|
||||
|
||||
if not self.is_last:
|
||||
hidden_states = self.group_norm_2(hidden_states)
|
||||
hidden_states = self.gelu_2(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnDownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class AttnUpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
|
||||
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
|
||||
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
|
||||
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
|
||||
|
||||
|
||||
def get_down_block(
|
||||
down_block_type: str,
|
||||
num_layers: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
add_downsample: bool,
|
||||
) -> DownBlockType:
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
|
||||
) -> UpBlockType:
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(
|
||||
mid_block_type: str,
|
||||
num_layers: int,
|
||||
in_channels: int,
|
||||
mid_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
add_downsample: bool,
|
||||
) -> MidBlockType:
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_out_block(
|
||||
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
|
||||
) -> Optional[OutBlockType]:
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
@@ -0,0 +1,346 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
||||
1)`.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DOutput(sample=sample)
|
||||
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -15,8 +15,8 @@
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
from .attention_flax import FlaxTransformer2DModel
|
||||
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
from ..attention_flax import FlaxTransformer2DModel
|
||||
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
File diff suppressed because it is too large
Load Diff
+7
-7
@@ -19,10 +19,10 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from ...configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from ..modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxCrossAttnUpBlock2D,
|
||||
@@ -342,14 +342,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. time
|
||||
+7
-7
@@ -17,19 +17,19 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .attention import Attention
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .resnet import (
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..dual_transformer_2d import DualTransformer2DModel
|
||||
from ..resnet import (
|
||||
Downsample2D,
|
||||
ResnetBlock2D,
|
||||
SpatioTemporalResBlock,
|
||||
TemporalConvLayer,
|
||||
Upsample2D,
|
||||
)
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import (
|
||||
from ..transformer_2d import Transformer2DModel
|
||||
from ..transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
)
|
||||
+15
-15
@@ -20,20 +20,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, deprecate, logging
|
||||
from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
@@ -284,7 +284,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -308,7 +308,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -374,7 +374,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -449,7 +449,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -469,7 +469,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
@@ -503,7 +503,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
def unload_lora(self):
|
||||
"""Unloads LoRA weights."""
|
||||
deprecate(
|
||||
+5
-5
@@ -19,11 +19,11 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
+14
-14
@@ -17,19 +17,19 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import logging
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformer_temporal import TransformerTemporalModel
|
||||
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_blocks import (
|
||||
@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self) -> None:
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self) -> None:
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self) -> None:
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
+7
-7
@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -20,20 +20,20 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin
|
||||
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import GlobalResponseNorm, RMSNorm
|
||||
from .resnet import Downsample2D, Upsample2D
|
||||
from ..embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import GlobalResponseNorm, RMSNorm
|
||||
from ..resnet import Downsample2D, Upsample2D
|
||||
|
||||
|
||||
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return logits
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -26,7 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unet_motion_model import MotionAdapter
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
@@ -67,10 +67,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -79,6 +76,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -805,11 +811,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
@@ -36,8 +36,8 @@ from ...models.embeddings import (
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ...models.transformer_2d import Transformer2DModel
|
||||
from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
|
||||
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import BaseOutput, is_torch_version, logging
|
||||
|
||||
|
||||
@@ -513,7 +513,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -537,7 +537,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -572,7 +572,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -588,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -654,7 +654,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
@@ -687,7 +687,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
@@ -700,8 +700,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
|
||||
@@ -33,7 +33,7 @@ from ....models.embeddings import (
|
||||
)
|
||||
from ....models.resnet import ResnetBlockCondNorm2D
|
||||
from ....models.transformer_2d import Transformer2DModel
|
||||
from ....models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ....utils.torch_utils import apply_freeu
|
||||
|
||||
@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
return objs
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
|
||||
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
||||
@@ -1095,7 +1096,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
@@ -1111,8 +1112,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
@@ -1785,7 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
class UpBlockFlat(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1896,7 +1897,7 @@ class UpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
class CrossAttnUpBlockFlat(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2070,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlat(nn.Module):
|
||||
"""
|
||||
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
|
||||
@@ -2226,7 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2373,7 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -351,7 +351,7 @@ def get_class_obj_and_candidates(
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj,
|
||||
config,
|
||||
config=None,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
@@ -389,7 +389,12 @@ def _get_pipeline_class(
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
class_name = config["_class_name"]
|
||||
class_name = class_name or config["_class_name"]
|
||||
if not class_name:
|
||||
raise ValueError(
|
||||
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
||||
)
|
||||
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
pipeline_cls = getattr(diffusers_module, class_name)
|
||||
|
||||
@@ -40,10 +40,8 @@ def _append_dims(x, target_dims):
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
return np.stack(outputs)
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet3DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
@@ -58,22 +59,26 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
||||
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
# reshape to ncfhw
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
# unnormalize back to [0,1]
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
# prepare the final outputs
|
||||
i, c, f, h, w = video.shape
|
||||
images = video.permute(2, 3, 0, 4, 1).reshape(
|
||||
f, h, i * w, c
|
||||
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
||||
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
||||
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
||||
return images
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
@@ -122,6 +127,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -717,11 +723,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
return TextToVideoSDPipelineOutput(frames=latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
+26
-21
@@ -20,6 +20,7 @@ import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet3DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
@@ -93,22 +94,26 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
||||
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
# reshape to ncfhw
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
# unnormalize back to [0,1]
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
# prepare the final outputs
|
||||
i, c, f, h, w = video.shape
|
||||
images = video.permute(2, 3, 0, 4, 1).reshape(
|
||||
f, h, i * w, c
|
||||
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
||||
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
||||
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
||||
return images
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def preprocess_video(video):
|
||||
@@ -198,6 +203,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -812,12 +818,11 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
if output_type == "latent":
|
||||
return TextToVideoSDPipelineOutput(frames=latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor)
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
@@ -752,7 +752,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_kwargs (*optional*):
|
||||
Keyword arguments to supply to the cross attention layers, if used.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
|
||||
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
|
||||
|
||||
@@ -66,7 +66,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
self.set_default_attn_processor()
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -90,7 +90,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -125,7 +125,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -61,6 +61,7 @@ else:
|
||||
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
|
||||
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
|
||||
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
|
||||
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
|
||||
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
|
||||
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
|
||||
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
|
||||
@@ -152,6 +153,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_lcm import LCMScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_repaint import RePaintScheduler
|
||||
from .scheduling_sasolver import SASolverScheduler
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_unclip import UnCLIPScheduler
|
||||
from .scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -28,6 +28,7 @@ from .constants import (
|
||||
MIN_PEFT_VERSION,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
USE_PEFT_BACKEND,
|
||||
WEIGHTS_NAME,
|
||||
|
||||
@@ -31,6 +31,7 @@ WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
|
||||
@@ -990,6 +990,21 @@ class RePaintScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SASolverScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SchedulerMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -244,15 +244,15 @@ def _get_model_file(
|
||||
pretrained_model_name_or_path: Union[str, Path],
|
||||
*,
|
||||
weights_name: str,
|
||||
subfolder: Optional[str],
|
||||
cache_dir: Optional[str],
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict],
|
||||
resume_download: bool,
|
||||
local_files_only: bool,
|
||||
token: Optional[str],
|
||||
user_agent: Union[Dict, str, None],
|
||||
revision: Optional[str],
|
||||
subfolder: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
force_download: bool = False,
|
||||
proxies: Optional[Dict] = None,
|
||||
resume_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[str] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
commit_hash: Optional[str] = None,
|
||||
):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
@@ -1662,6 +1663,11 @@ class UNet3DConditionLoRAModelTests(unittest.TestCase):
|
||||
@deprecate_after_peft_backend
|
||||
@require_torch_gpu
|
||||
class LoraIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_dreambooth_old_format(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import gc
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
"latent_channels": 4,
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_integration_move_lora_cpu(self):
|
||||
@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
"sample_size": 128,
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
import gc
|
||||
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def test_dreambooth_old_format(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
import gc
|
||||
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def test_sdxl_0_9_lora_one(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from diffusers.models.unet_2d_blocks import * # noqa F403
|
||||
from diffusers.models.unets.unet_2d_blocks import * # noqa F403
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
from .test_unet_blocks_common import UNetBlockTesterMixin
|
||||
|
||||
@@ -262,7 +262,7 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results"
|
||||
sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
|
||||
@@ -37,6 +37,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
@@ -1022,39 +1023,49 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
|
||||
pipe_sf = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
images.append(output.images[0])
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output_sf = pipe_sf(
|
||||
prompt,
|
||||
image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -421,46 +422,53 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output_sf = pipe_sf(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -569,6 +569,7 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
@@ -605,4 +606,5 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -28,10 +28,17 @@ from diffusers import (
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.unet_2d_blocks import UNetMidBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
|
||||
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
@@ -819,6 +826,41 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
|
||||
expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
|
||||
single_file_url = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
pipe_single_file = StableDiffusionXLControlNetPipeline.from_single_file(
|
||||
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe_single_file.unet.set_default_attn_processor()
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
pipe_single_file.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "Stormtrooper's lecture"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
|
||||
)
|
||||
single_file_images = pipe_single_file(
|
||||
prompt, image=image, generator=generator, output_type="np", num_inference_steps=2
|
||||
).images
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=2).images
|
||||
|
||||
assert images[0].shape == (512, 512, 3)
|
||||
assert single_file_images[0].shape == (512, 512, 3)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
|
||||
assert max_diff < 5e-2
|
||||
|
||||
|
||||
class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
|
||||
def test_controlnet_sdxl_guess(self):
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = CycleDiffusionPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
|
||||
"negative_prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt_embeds",
|
||||
}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "An astronaut riding an elephant",
|
||||
"source_prompt": "An astronaut riding a horse",
|
||||
"image": image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"eta": 0.1,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 3,
|
||||
"source_guidance_scale": 1,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_cycle(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = CycleDiffusionPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = pipe(**inputs)
|
||||
images = output.images
|
||||
|
||||
image_slice = images[0, -3:, -3:, -1]
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4459, 0.4943, 0.4544, 0.6643, 0.5474, 0.4327, 0.5701, 0.5959, 0.5179])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_cycle_fp16(self):
|
||||
components = self.get_dummy_components()
|
||||
for name, module in components.items():
|
||||
if hasattr(module, "half"):
|
||||
components[name] = module.half()
|
||||
pipe = CycleDiffusionPipeline(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)
|
||||
images = output.images
|
||||
|
||||
image_slice = images[0, -3:, -3:, -1]
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.3506, 0.4543, 0.446, 0.4575, 0.5195, 0.4155, 0.5273, 0.518, 0.4116])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@skip_mps
|
||||
def test_save_load_local(self):
|
||||
return super().test_save_load_local()
|
||||
|
||||
@unittest.skip("non-deterministic pipeline")
|
||||
def test_inference_batch_single_identical(self):
|
||||
return super().test_inference_batch_single_identical()
|
||||
|
||||
@skip_mps
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
return super().test_dict_tuple_outputs_equivalent()
|
||||
|
||||
@skip_mps
|
||||
def test_save_load_optional_components(self):
|
||||
return super().test_save_load_optional_components()
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return super().test_attention_slicing_forward_pass()
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_cycle_diffusion_pipeline_fp16(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/cycle-diffusion/black_colored_car.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car_fp16.npy"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(
|
||||
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
|
||||
)
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
source_prompt = "A black colored car"
|
||||
prompt = "A blue colored car"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.85,
|
||||
guidance_scale=3,
|
||||
source_guidance_scale=1,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images
|
||||
|
||||
# the values aren't exactly equal, but the images look the same visually
|
||||
assert np.abs(image - expected_image).max() < 5e-1
|
||||
|
||||
def test_cycle_diffusion_pipeline(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/cycle-diffusion/black_colored_car.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car.npy"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
source_prompt = "A black colored car"
|
||||
prompt = "A blue colored car"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.85,
|
||||
guidance_scale=3,
|
||||
source_guidance_scale=1,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images
|
||||
|
||||
assert np.abs(image - expected_image).max() < 2e-2
|
||||
@@ -836,7 +836,10 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
def test_stable_diffusion_dpm(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
sd_pipe.scheduler.config,
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -1243,9 +1246,12 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
assert image_out.shape == (512, 512, 3)
|
||||
|
||||
def test_download_local(self):
|
||||
filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt")
|
||||
ckpt_filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.ckpt")
|
||||
config_filename = hf_hub_download("runwayml/stable-diffusion-v1-5", filename="v1-inference.yaml")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionPipeline.from_single_file(
|
||||
ckpt_filename, config_files={"v1": config_filename}, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to("cuda")
|
||||
|
||||
@@ -1256,13 +1262,13 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
|
||||
|
||||
pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.to("cuda")
|
||||
sf_pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
|
||||
sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
|
||||
sf_pipe.unet.set_attn_processor(AttnProcessor())
|
||||
sf_pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
@@ -1272,7 +1278,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ from diffusers.utils.testing_utils import (
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
@@ -771,7 +772,9 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
inputs["num_inference_steps"] = 5
|
||||
image = pipe(**inputs).images[0]
|
||||
|
||||
assert np.max(np.abs(image - image_ckpt)) < 5e-4
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
|
||||
assert max_diff < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -1,630 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
VQModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
preprocess_image,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
|
||||
return image
|
||||
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_cond_unet_inpaint(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vq_model(self):
|
||||
torch.manual_seed(0)
|
||||
model = VQModel(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=3,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
@property
|
||||
def dummy_extractor(self):
|
||||
def extract(*args, **kwargs):
|
||||
class Out:
|
||||
def __init__(self):
|
||||
self.pixel_values = torch.ones([0])
|
||||
|
||||
def to(self, device):
|
||||
self.pixel_values.to(device)
|
||||
return self
|
||||
|
||||
return Out()
|
||||
|
||||
return extract
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_batched(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
init_images_tens = preprocess_image(init_image, batch_size=2)
|
||||
init_masks_tens = init_images_tens + 4
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
images = sd_pipe(
|
||||
[prompt] * 2,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_images_tens,
|
||||
mask_image=init_masks_tens,
|
||||
).images
|
||||
|
||||
assert images.shape == (2, 32, 32, 3)
|
||||
|
||||
image_slice_0 = images[0, -3:, -3:, -1].flatten()
|
||||
image_slice_1 = images[1, -3:, -3:, -1].flatten()
|
||||
|
||||
expected_slice_0 = np.array([0.4697, 0.3770, 0.4096, 0.4653, 0.4497, 0.4183, 0.3950, 0.4668, 0.4672])
|
||||
expected_slice_1 = np.array([0.4105, 0.4987, 0.5771, 0.4921, 0.4237, 0.5684, 0.5496, 0.4645, 0.5272])
|
||||
|
||||
assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-2
|
||||
assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
negative_prompt = "french fries"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
images = sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
[prompt] * batch_size,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
images = sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
[prompt] * batch_size,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, generator_device="cpu", seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "A red cat sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_pndm(self):
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.5665, 0.6117, 0.6430, 0.4057, 0.4594, 0.5658, 0.1596, 0.3106, 0.4305])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_batched(self):
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
inputs["prompt"] = [inputs["prompt"]] * 2
|
||||
inputs["image"] = preprocess_image(inputs["image"], batch_size=2)
|
||||
|
||||
mask = inputs["mask_image"].convert("L")
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(1 - mask)
|
||||
masks = torch.vstack([mask[None][None]] * 2)
|
||||
inputs["mask_image"] = masks
|
||||
|
||||
image = pipe(**inputs).images
|
||||
assert image.shape == (2, 512, 512, 3)
|
||||
|
||||
image_slice_0 = image[0, 253:256, 253:256, -1].flatten()
|
||||
image_slice_1 = image[1, 253:256, 253:256, -1].flatten()
|
||||
|
||||
expected_slice_0 = np.array(
|
||||
[0.52093095, 0.4176447, 0.32752383, 0.6175223, 0.50563973, 0.36470804, 0.65460044, 0.5775188, 0.44332123]
|
||||
)
|
||||
expected_slice_1 = np.array(
|
||||
[0.3592432, 0.4233033, 0.3914635, 0.31014425, 0.3702293, 0.39412856, 0.17526966, 0.2642669, 0.37480092]
|
||||
)
|
||||
|
||||
assert np.abs(expected_slice_0 - image_slice_0).max() < 3e-3
|
||||
assert np.abs(expected_slice_1 - image_slice_1).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_k_lms(self):
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None
|
||||
)
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.4534, 0.4467, 0.4329, 0.4329, 0.4339, 0.4220, 0.4244, 0.4332, 0.4426])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 1:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.5977, 1.5449, 1.0586, -0.3250, 0.7383, -0.0862, 0.4631, -0.2571, -1.1289])
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
elif step == 2:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.5190, 1.1621, 0.6885, 0.2424, 0.3337, -0.1617, 0.6914, -0.1957, -0.5474])
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
pipe(**inputs, callback=callback_fn, callback_steps=1)
|
||||
assert callback_fn.has_been_called
|
||||
assert number_of_steps == 2
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintLegacyPipelineNightlyTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "A red cat sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 50,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inpaint_pndm(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_pndm.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_inpaint_ddim(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_ddim.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_inpaint_lms(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_lms.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_inpaint_dpm(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 30
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_dpm_multi.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
@@ -1,255 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class StableDiffusionModelEditingPipelineFastTests(
|
||||
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = StableDiffusionModelEditingPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = DDIMScheduler()
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
generator = torch.manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A field of roses",
|
||||
"generator": generator,
|
||||
# Setting height and width to None to prevent OOMs on CPU.
|
||||
"height": None,
|
||||
"width": None,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_model_editing_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionModelEditingPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.4755, 0.5132, 0.4976, 0.3904, 0.3554, 0.4765, 0.5139, 0.5158, 0.4889])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_model_editing_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionModelEditingPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
negative_prompt = "french fries"
|
||||
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.4992, 0.5101, 0.5004, 0.3949, 0.3604, 0.4735, 0.5216, 0.5204, 0.4913])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_model_editing_euler(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = EulerAncestralDiscreteScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
sd_pipe = StableDiffusionModelEditingPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.4747, 0.5372, 0.4779, 0.4982, 0.5543, 0.4816, 0.5238, 0.4904, 0.5027])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_model_editing_pndm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = PNDMScheduler()
|
||||
sd_pipe = StableDiffusionModelEditingPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# the pipeline does not expect pndm so test if it raises error.
|
||||
with self.assertRaises(ValueError):
|
||||
_ = sd_pipe(**inputs).images
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=5e-3)
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
super().test_attention_slicing_forward_pass(expected_max_diff=5e-3)
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class StableDiffusionModelEditingSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, seed=0):
|
||||
generator = torch.manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A field of roses",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_model_editing_default(self):
|
||||
model_ckpt = "CompVis/stable-diffusion-v1-4"
|
||||
pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.6749496, 0.6386453, 0.51443267, 0.66094905, 0.61921215, 0.5491332, 0.5744417, 0.58075106, 0.5174658]
|
||||
)
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 1e-2
|
||||
|
||||
# make sure image changes after editing
|
||||
pipe.edit_model("A pack of roses", "A pack of blue roses")
|
||||
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() > 1e-1
|
||||
|
||||
def test_stable_diffusion_model_editing_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
model_ckpt = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
|
||||
pipe = StableDiffusionModelEditingPipeline.from_pretrained(
|
||||
model_ckpt, scheduler=scheduler, safety_checker=None
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
_ = pipe(**inputs)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 4.4 GB is allocated
|
||||
assert mem_bytes < 4.4 * 10**9
|
||||
@@ -1,228 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMParallelScheduler,
|
||||
DDPMParallelScheduler,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionParadigmsPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionParadigmsPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
use_linear_projection=True,
|
||||
)
|
||||
scheduler = DDIMParallelScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=512,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "a photograph of an astronaut riding a horse",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 10,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"parallel": 3,
|
||||
"debug": True,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_paradigms_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionParadigmsPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.4773, 0.5417, 0.4723, 0.4925, 0.5631, 0.4752, 0.5240, 0.4935, 0.5023])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_paradigms_default_case_ddpm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
torch.manual_seed(0)
|
||||
components["scheduler"] = DDPMParallelScheduler()
|
||||
torch.manual_seed(0)
|
||||
sd_pipe = StableDiffusionParadigmsPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.3573, 0.4420, 0.4960, 0.4799, 0.3796, 0.3879, 0.4819, 0.4365, 0.4468])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
# override to speed the overall test timing up.
|
||||
def test_inference_batch_consistent(self):
|
||||
super().test_inference_batch_consistent(batch_sizes=[1, 2])
|
||||
|
||||
# override to speed the overall test timing up.
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3)
|
||||
|
||||
def test_stable_diffusion_paradigms_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionParadigmsPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
negative_prompt = "french fries"
|
||||
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.4771, 0.5420, 0.4683, 0.4918, 0.5636, 0.4725, 0.5230, 0.4923, 0.5015])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class StableDiffusionParadigmsPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, seed=0):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "a photograph of an astronaut riding a horse",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 10,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
"parallel": 3,
|
||||
"debug": True,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_paradigms_default(self):
|
||||
model_ckpt = "stabilityai/stable-diffusion-2-base"
|
||||
scheduler = DDIMParallelScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
|
||||
pipe = StableDiffusionParadigmsPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
expected_slice = np.array([0.9622, 0.9602, 0.9748, 0.9591, 0.9630, 0.9691, 0.9661, 0.9631, 0.9741])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 1e-2
|
||||
@@ -1,590 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMInverseScheduler,
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
load_pt,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
assert_mean_pixel_difference,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPix2PixZeroPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.source_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
|
||||
)
|
||||
|
||||
cls.target_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = DDIMScheduler()
|
||||
inverse_scheduler = DDIMInverseScheduler()
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"inverse_scheduler": inverse_scheduler,
|
||||
"caption_generator": None,
|
||||
"caption_processor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"cross_attention_guidance_amount": 0.15,
|
||||
"source_embeds": self.source_embeds,
|
||||
"target_embeds": self.target_embeds,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def get_dummy_inversion_inputs(self, device, seed=0):
|
||||
dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
|
||||
dummy_image = dummy_image / 2 + 0.5
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": [
|
||||
"A painting of a squirrel eating a burger",
|
||||
"A painting of a burger eating a squirrel",
|
||||
],
|
||||
"image": dummy_image.cpu(),
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"generator": generator,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
|
||||
inputs = self.get_dummy_inversion_inputs(device, seed)
|
||||
|
||||
if input_image_type == "pt":
|
||||
image = inputs["image"]
|
||||
elif input_image_type == "np":
|
||||
image = VaeImageProcessor.pt_to_numpy(inputs["image"])
|
||||
elif input_image_type == "pil":
|
||||
image = VaeImageProcessor.pt_to_numpy(inputs["image"])
|
||||
image = VaeImageProcessor.numpy_to_pil(image)
|
||||
else:
|
||||
raise ValueError(f"unsupported input_image_type {input_image_type}")
|
||||
|
||||
inputs["image"] = image
|
||||
inputs["output_type"] = output_type
|
||||
|
||||
return inputs
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
if not hasattr(self.pipeline_class, "_optional_components"):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# set all optional components to None and update pipeline config accordingly
|
||||
for optional_component in pipe._optional_components:
|
||||
setattr(pipe, optional_component, None)
|
||||
pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for optional_component in pipe._optional_components:
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output - output_loaded).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_inversion(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inversion_inputs(device)
|
||||
inputs["image"] = inputs["image"][:1]
|
||||
inputs["prompt"] = inputs["prompt"][:1]
|
||||
image = sd_pipe.invert(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4732, 0.4630, 0.5722, 0.5103, 0.5140, 0.5622, 0.5104, 0.5390, 0.5020])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_inversion_batch(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inversion_inputs(device)
|
||||
image = sd_pipe.invert(**inputs).images
|
||||
image_slice = image[1, -3:, -3:, -1]
|
||||
assert image.shape == (2, 32, 32, 3)
|
||||
expected_slice = np.array([0.6046, 0.5400, 0.4902, 0.4448, 0.4694, 0.5498, 0.4857, 0.5073, 0.5089])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4863, 0.5053, 0.5033, 0.4007, 0.3571, 0.4768, 0.5176, 0.5277, 0.4940])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
negative_prompt = "french fries"
|
||||
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5177, 0.5097, 0.5047, 0.4076, 0.3667, 0.4767, 0.5238, 0.5307, 0.4958])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_euler(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = EulerAncestralDiscreteScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5421, 0.5525, 0.6085, 0.5279, 0.4658, 0.5317, 0.4418, 0.4815, 0.5132])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_ddpm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = DDPMScheduler()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4861, 0.5053, 0.5038, 0.3994, 0.3562, 0.4768, 0.5172, 0.5280, 0.4938])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self):
|
||||
device = torch_device
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pt")).images
|
||||
output_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="np")).images
|
||||
output_pil = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pil")).images
|
||||
|
||||
max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
|
||||
self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")
|
||||
|
||||
max_diff = np.abs(np.array(output_pil[0]) - (output_np[0] * 255).round()).max()
|
||||
self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`")
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self):
|
||||
device = torch_device
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
out_input_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pt")).images
|
||||
out_input_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="np")).images
|
||||
out_input_pil = sd_pipe.invert(
|
||||
**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pil")
|
||||
).images
|
||||
|
||||
max_diff = np.abs(out_input_pt - out_input_np).max()
|
||||
self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
|
||||
|
||||
assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1)
|
||||
|
||||
# Non-determinism caused by the scheduler optimizing the latent inputs during inference
|
||||
@unittest.skip("non-deterministic pipeline")
|
||||
def test_inference_batch_single_identical(self):
|
||||
return super().test_inference_batch_single_identical()
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class StableDiffusionPix2PixZeroPipelineNightlyTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.source_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
|
||||
)
|
||||
|
||||
cls.target_embeds = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
|
||||
)
|
||||
|
||||
def get_inputs(self, seed=0):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "turn him into a cyborg",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"cross_attention_guidance_amount": 0.15,
|
||||
"source_embeds": self.source_embeds,
|
||||
"target_embeds": self.target_embeds,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_default(self):
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.5742, 0.5757, 0.5747, 0.5781, 0.5688, 0.5713, 0.5742, 0.5664, 0.5747])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 5e-2
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_k_lms(self):
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.6367, 0.5459, 0.5146, 0.5479, 0.4905, 0.4753, 0.4961, 0.4629, 0.4624])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 5e-2
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 1:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.1345, 0.268, 0.1539, 0.0726, 0.0959, 0.2261, -0.2673, 0.0277, -0.2062])
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||
elif step == 2:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.1393, 0.2637, 0.1617, 0.0724, 0.0987, 0.2271, -0.2666, 0.0299, -0.2104])
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
|
||||
|
||||
callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
pipe(**inputs, callback=callback_fn, callback_steps=1)
|
||||
assert callback_fn.has_been_called
|
||||
assert number_of_steps == 3
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
_ = pipe(**inputs)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 8.2 GB is allocated
|
||||
assert mem_bytes < 8.2 * 10**9
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class InversionPipelineNightlyTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
raw_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
|
||||
)
|
||||
|
||||
raw_image = raw_image.convert("RGB").resize((512, 512))
|
||||
|
||||
cls.raw_image = raw_image
|
||||
|
||||
def test_stable_diffusion_pix2pix_inversion(self):
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
caption = "a photography of a cat with flowers"
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
|
||||
inv_latents = output[0]
|
||||
|
||||
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert inv_latents.shape == (1, 4, 64, 64)
|
||||
expected_slice = np.array([0.8447, -0.0730, 0.7588, -1.2070, -0.4678, 0.1511, -0.8555, 1.1816, -0.7666])
|
||||
|
||||
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2
|
||||
|
||||
def test_stable_diffusion_2_pix2pix_inversion(self):
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
caption = "a photography of a cat with flowers"
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
|
||||
inv_latents = output[0]
|
||||
|
||||
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert inv_latents.shape == (1, 4, 64, 64)
|
||||
expected_slice = np.array([0.8970, -0.1611, 0.4766, -1.1162, -0.5923, 0.1050, -0.9678, 1.0537, -0.6050])
|
||||
|
||||
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2
|
||||
|
||||
def test_stable_diffusion_2_pix2pix_full(self):
|
||||
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog_2.png
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog_2.npy"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
caption = "a photography of a cat with flowers"
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe.invert(caption, image=self.raw_image, generator=generator)
|
||||
inv_latents = output[0]
|
||||
|
||||
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
|
||||
target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
|
||||
|
||||
source_embeds = pipe.get_embeds(source_prompts)
|
||||
target_embeds = pipe.get_embeds(target_prompts)
|
||||
|
||||
image = pipe(
|
||||
caption,
|
||||
source_embeds=source_embeds,
|
||||
target_embeds=target_embeds,
|
||||
num_inference_steps=125,
|
||||
cross_attention_guidance_amount=0.015,
|
||||
generator=generator,
|
||||
latents=inv_latents,
|
||||
negative_prompt=caption,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
mean_diff = np.abs(expected_image - image).mean()
|
||||
assert mean_diff < 0.25
|
||||
@@ -627,7 +627,9 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
|
||||
|
||||
def test_stable_diffusion_dpm(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
sd_pipe.scheduler.config, final_sigmas_type="sigma_min"
|
||||
)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
|
||||
@@ -323,7 +323,9 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||
TODO: update this test after making DPM compatible with V-prediction!
|
||||
"""
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="scheduler"
|
||||
"stabilityai/stable-diffusion-2",
|
||||
subfolder="scheduler",
|
||||
final_sigmas_type="sigma_min",
|
||||
)
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user