Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b69fd990ad | |||
| 26b694bc6a | |||
| 84bc0e48b8 | |||
| 5584e1cb8d | |||
| d7634cca87 | |||
| 1ab57b68c2 | |||
| cfa7c0a93d | |||
| 4974b84564 | |||
| 83062fb872 | |||
| b6d7e31d10 | |||
| 53e9aacc10 | |||
| 41424466e3 | |||
| 95de1981c9 | |||
| 0b45b58867 |
@@ -65,6 +65,7 @@ jobs:
|
|||||||
python -m uv pip install -e [quality,test]
|
python -m uv pip install -e [quality,test]
|
||||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers
|
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers
|
||||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||||
|
python -m uv pip install pytest-reportlog
|
||||||
|
|
||||||
- name: Environment
|
- name: Environment
|
||||||
run: |
|
run: |
|
||||||
@@ -150,6 +151,7 @@ jobs:
|
|||||||
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||||
|
${CONDA_RUN} python -m uv pip install pytest-reportlog
|
||||||
|
|
||||||
- name: Environment
|
- name: Environment
|
||||||
shell: arch -arch arm64 bash {0}
|
shell: arch -arch arm64 bash {0}
|
||||||
|
|||||||
@@ -404,6 +404,10 @@
|
|||||||
title: EulerAncestralDiscreteScheduler
|
title: EulerAncestralDiscreteScheduler
|
||||||
- local: api/schedulers/euler
|
- local: api/schedulers/euler
|
||||||
title: EulerDiscreteScheduler
|
title: EulerDiscreteScheduler
|
||||||
|
- local: api/schedulers/edm_euler
|
||||||
|
title: EDMEulerScheduler
|
||||||
|
- local: api/schedulers/edm_multistep_dpm_solver
|
||||||
|
title: EDMDPMSolverMultistepScheduler
|
||||||
- local: api/schedulers/heun
|
- local: api/schedulers/heun
|
||||||
title: HeunDiscreteScheduler
|
title: HeunDiscreteScheduler
|
||||||
- local: api/schedulers/ipndm
|
- local: api/schedulers/ipndm
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# EDMEulerScheduler
|
||||||
|
|
||||||
|
The Karras formulation of the Euler scheduler (Algorithm 2) from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).
|
||||||
|
|
||||||
|
|
||||||
|
## EDMEulerScheduler
|
||||||
|
[[autodoc]] EDMEulerScheduler
|
||||||
|
|
||||||
|
## EDMEulerSchedulerOutput
|
||||||
|
[[autodoc]] schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
-->
|
||||||
|
|
||||||
|
# EDMDPMSolverMultistepScheduler
|
||||||
|
|
||||||
|
`EDMDPMSolverMultistepScheduler` is a [Karras formulation](https://huggingface.co/papers/2206.00364) of `DPMSolverMultistep`, a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.
|
||||||
|
|
||||||
|
DPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality
|
||||||
|
samples, and it can generate quite good samples even in 10 steps.
|
||||||
|
|
||||||
|
## EDMDPMSolverMultistepScheduler
|
||||||
|
[[autodoc]] EDMDPMSolverMultistepScheduler
|
||||||
|
|
||||||
|
## SchedulerOutput
|
||||||
|
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
|
||||||
@@ -259,6 +259,50 @@ pip install git+https://github.com/huggingface/peft.git
|
|||||||
**Inference**
|
**Inference**
|
||||||
The inference is the same as if you train a regular LoRA 🤗
|
The inference is the same as if you train a regular LoRA 🤗
|
||||||
|
|
||||||
|
## Conducting EDM-style training
|
||||||
|
|
||||||
|
It's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364).
|
||||||
|
|
||||||
|
simply set:
|
||||||
|
|
||||||
|
```diff
|
||||||
|
+ --do_edm_style_training \
|
||||||
|
```
|
||||||
|
|
||||||
|
Other SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
|
||||||
|
--pretrained_model_name_or_path="playgroundai/playground-v2.5-1024px-aesthetic" \
|
||||||
|
--dataset_name="linoyts/3d_icon" \
|
||||||
|
--instance_prompt="3d icon in the style of TOK" \
|
||||||
|
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||||
|
--output_dir="3d-icon-SDXL-LoRA" \
|
||||||
|
--do_edm_style_training \
|
||||||
|
--caption_column="prompt" \
|
||||||
|
--mixed_precision="bf16" \
|
||||||
|
--resolution=1024 \
|
||||||
|
--train_batch_size=3 \
|
||||||
|
--repeats=1 \
|
||||||
|
--report_to="wandb"\
|
||||||
|
--gradient_accumulation_steps=1 \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--learning_rate=1.0 \
|
||||||
|
--text_encoder_lr=1.0 \
|
||||||
|
--optimizer="prodigy"\
|
||||||
|
--train_text_encoder_ti\
|
||||||
|
--train_text_encoder_ti_frac=0.5\
|
||||||
|
--lr_scheduler="constant" \
|
||||||
|
--lr_warmup_steps=0 \
|
||||||
|
--rank=8 \
|
||||||
|
--max_train_steps=1000 \
|
||||||
|
--checkpointing_steps=2000 \
|
||||||
|
--seed="0" \
|
||||||
|
--push_to_hub
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!CAUTION]
|
||||||
|
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
|
||||||
|
|
||||||
### Tips and Tricks
|
### Tips and Tricks
|
||||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
import hashlib
|
import hashlib
|
||||||
import itertools
|
import itertools
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -37,7 +39,7 @@ import transformers
|
|||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||||
from huggingface_hub import create_repo, upload_folder
|
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from peft import LoraConfig, set_peft_model_state_dict
|
from peft import LoraConfig, set_peft_model_state_dict
|
||||||
from peft.utils import get_peft_model_state_dict
|
from peft.utils import get_peft_model_state_dict
|
||||||
@@ -55,6 +57,8 @@ from diffusers import (
|
|||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
|
EDMEulerScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
@@ -74,11 +78,25 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def determine_scheduler_type(pretrained_model_name_or_path, revision):
|
||||||
|
model_index_filename = "model_index.json"
|
||||||
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
|
model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
|
||||||
|
else:
|
||||||
|
model_index = hf_hub_download(
|
||||||
|
repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(model_index, "r") as f:
|
||||||
|
scheduler_type = json.load(f)["scheduler"][1]
|
||||||
|
return scheduler_type
|
||||||
|
|
||||||
|
|
||||||
def save_model_card(
|
def save_model_card(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
use_dora: bool,
|
use_dora: bool,
|
||||||
@@ -370,6 +388,11 @@ def parse_args(input_args=None):
|
|||||||
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--do_edm_style_training",
|
||||||
|
action="store_true",
|
||||||
|
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--with_prior_preservation",
|
"--with_prior_preservation",
|
||||||
default=False,
|
default=False,
|
||||||
@@ -1117,6 +1140,8 @@ def main(args):
|
|||||||
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
||||||
" Please use `huggingface-cli login` to authenticate with the Hub."
|
" Please use `huggingface-cli login` to authenticate with the Hub."
|
||||||
)
|
)
|
||||||
|
if args.do_edm_style_training and args.snr_gamma is not None:
|
||||||
|
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
|
||||||
|
|
||||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||||
|
|
||||||
@@ -1234,7 +1259,19 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load scheduler and models
|
# Load scheduler and models
|
||||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
|
||||||
|
if "EDM" in scheduler_type:
|
||||||
|
args.do_edm_style_training = True
|
||||||
|
noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||||
|
logger.info("Performing EDM-style training!")
|
||||||
|
elif args.do_edm_style_training:
|
||||||
|
noise_scheduler = EulerDiscreteScheduler.from_pretrained(
|
||||||
|
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||||
|
)
|
||||||
|
logger.info("Performing EDM-style training!")
|
||||||
|
else:
|
||||||
|
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||||
|
|
||||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||||
)
|
)
|
||||||
@@ -1252,7 +1289,12 @@ def main(args):
|
|||||||
revision=args.revision,
|
revision=args.revision,
|
||||||
variant=args.variant,
|
variant=args.variant,
|
||||||
)
|
)
|
||||||
vae_scaling_factor = vae.config.scaling_factor
|
latents_mean = latents_std = None
|
||||||
|
if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
|
||||||
|
latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
|
||||||
|
if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
|
||||||
|
latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
|
||||||
|
|
||||||
unet = UNet2DConditionModel.from_pretrained(
|
unet = UNet2DConditionModel.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
|
||||||
)
|
)
|
||||||
@@ -1790,6 +1832,19 @@ def main(args):
|
|||||||
disable=not accelerator.is_local_main_process,
|
disable=not accelerator.is_local_main_process,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
||||||
|
# TODO: revisit other sampling algorithms
|
||||||
|
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
|
||||||
|
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
|
||||||
|
timesteps = timesteps.to(accelerator.device)
|
||||||
|
|
||||||
|
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||||
|
|
||||||
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
while len(sigma.shape) < n_dim:
|
||||||
|
sigma = sigma.unsqueeze(-1)
|
||||||
|
return sigma
|
||||||
|
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
|
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
|
||||||
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
|
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
|
||||||
@@ -1841,9 +1896,15 @@ def main(args):
|
|||||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||||
|
|
||||||
model_input = model_input * vae_scaling_factor
|
if latents_mean is None and latents_std is None:
|
||||||
if args.pretrained_vae_model_name_or_path is None:
|
model_input = model_input * vae.config.scaling_factor
|
||||||
model_input = model_input.to(weight_dtype)
|
if args.pretrained_vae_model_name_or_path is None:
|
||||||
|
model_input = model_input.to(weight_dtype)
|
||||||
|
else:
|
||||||
|
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
|
||||||
|
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
|
||||||
|
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
|
||||||
|
model_input = model_input.to(dtype=weight_dtype)
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(model_input)
|
noise = torch.randn_like(model_input)
|
||||||
@@ -1854,15 +1915,32 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
bsz = model_input.shape[0]
|
bsz = model_input.shape[0]
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
timesteps = torch.randint(
|
if not args.do_edm_style_training:
|
||||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
timesteps = torch.randint(
|
||||||
)
|
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
||||||
timesteps = timesteps.long()
|
)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
else:
|
||||||
|
# in EDM formulation, the model is conditioned on the pre-conditioned noise levels
|
||||||
|
# instead of discrete timesteps, so here we sample indices to get the noise levels
|
||||||
|
# from `scheduler.timesteps`
|
||||||
|
indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
|
||||||
|
timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
|
||||||
|
|
||||||
# Add noise to the model input according to the noise magnitude at each timestep
|
# Add noise to the model input according to the noise magnitude at each timestep
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||||
|
# For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
|
||||||
|
# We then precondition the final model inputs based on these sigmas instead of the timesteps.
|
||||||
|
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||||
|
if args.do_edm_style_training:
|
||||||
|
sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
|
||||||
|
if "EDM" in scheduler_type:
|
||||||
|
inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
|
||||||
|
else:
|
||||||
|
inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
|
||||||
|
|
||||||
# time ids
|
# time ids
|
||||||
add_time_ids = torch.cat(
|
add_time_ids = torch.cat(
|
||||||
@@ -1888,7 +1966,7 @@ def main(args):
|
|||||||
}
|
}
|
||||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||||
model_pred = unet(
|
model_pred = unet(
|
||||||
noisy_model_input,
|
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
|
||||||
timesteps,
|
timesteps,
|
||||||
prompt_embeds_input,
|
prompt_embeds_input,
|
||||||
added_cond_kwargs=unet_added_conditions,
|
added_cond_kwargs=unet_added_conditions,
|
||||||
@@ -1906,14 +1984,42 @@ def main(args):
|
|||||||
)
|
)
|
||||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||||
model_pred = unet(
|
model_pred = unet(
|
||||||
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
|
inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
|
||||||
|
timesteps,
|
||||||
|
prompt_embeds_input,
|
||||||
|
added_cond_kwargs=unet_added_conditions,
|
||||||
).sample
|
).sample
|
||||||
|
|
||||||
|
weighting = None
|
||||||
|
if args.do_edm_style_training:
|
||||||
|
# Similar to the input preconditioning, the model predictions are also preconditioned
|
||||||
|
# on noised model inputs (before preconditioning) and the sigmas.
|
||||||
|
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||||
|
if "EDM" in scheduler_type:
|
||||||
|
model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
|
||||||
|
else:
|
||||||
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||||
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
|
model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
|
||||||
|
noisy_model_input / (sigmas**2 + 1)
|
||||||
|
)
|
||||||
|
# We are not doing weighting here because it tends result in numerical problems.
|
||||||
|
# See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||||
|
# There might be other alternatives for weighting as well:
|
||||||
|
# https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
|
||||||
|
if "EDM" not in scheduler_type:
|
||||||
|
weighting = (sigmas**-2.0).float()
|
||||||
|
|
||||||
# Get the target for loss depending on the prediction type
|
# Get the target for loss depending on the prediction type
|
||||||
if noise_scheduler.config.prediction_type == "epsilon":
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
target = noise
|
target = model_input if args.do_edm_style_training else noise
|
||||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
|
target = (
|
||||||
|
model_input
|
||||||
|
if args.do_edm_style_training
|
||||||
|
else noise_scheduler.get_velocity(model_input, noise, timesteps)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||||
|
|
||||||
@@ -1923,10 +2029,28 @@ def main(args):
|
|||||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||||
|
|
||||||
# Compute prior loss
|
# Compute prior loss
|
||||||
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
if weighting is not None:
|
||||||
|
prior_loss = torch.mean(
|
||||||
|
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||||
|
target_prior.shape[0], -1
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
prior_loss = prior_loss.mean()
|
||||||
|
else:
|
||||||
|
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
||||||
|
|
||||||
if args.snr_gamma is None:
|
if args.snr_gamma is None:
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
if weighting is not None:
|
||||||
|
loss = torch.mean(
|
||||||
|
(weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(
|
||||||
|
target.shape[0], -1
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
loss = loss.mean()
|
||||||
|
else:
|
||||||
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
else:
|
else:
|
||||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||||
@@ -2049,17 +2173,18 @@ def main(args):
|
|||||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||||
scheduler_args = {}
|
scheduler_args = {}
|
||||||
|
|
||||||
if "variance_type" in pipeline.scheduler.config:
|
if not args.do_edm_style_training:
|
||||||
variance_type = pipeline.scheduler.config.variance_type
|
if "variance_type" in pipeline.scheduler.config:
|
||||||
|
variance_type = pipeline.scheduler.config.variance_type
|
||||||
|
|
||||||
if variance_type in ["learned", "learned_range"]:
|
if variance_type in ["learned", "learned_range"]:
|
||||||
variance_type = "fixed_small"
|
variance_type = "fixed_small"
|
||||||
|
|
||||||
scheduler_args["variance_type"] = variance_type
|
scheduler_args["variance_type"] = variance_type
|
||||||
|
|
||||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||||
pipeline.scheduler.config, **scheduler_args
|
pipeline.scheduler.config, **scheduler_args
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = pipeline.to(accelerator.device)
|
pipeline = pipeline.to(accelerator.device)
|
||||||
pipeline.set_progress_bar_config(disable=True)
|
pipeline.set_progress_bar_config(disable=True)
|
||||||
@@ -2067,8 +2192,13 @@ def main(args):
|
|||||||
# run inference
|
# run inference
|
||||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||||
pipeline_args = {"prompt": args.validation_prompt}
|
pipeline_args = {"prompt": args.validation_prompt}
|
||||||
|
inference_ctx = (
|
||||||
|
contextlib.nullcontext()
|
||||||
|
if "playground" in args.pretrained_model_name_or_path
|
||||||
|
else torch.cuda.amp.autocast()
|
||||||
|
)
|
||||||
|
|
||||||
with torch.cuda.amp.autocast():
|
with inference_ctx:
|
||||||
images = [
|
images = [
|
||||||
pipeline(**pipeline_args, generator=generator).images[0]
|
pipeline(**pipeline_args, generator=generator).images[0]
|
||||||
for _ in range(args.num_validation_images)
|
for _ in range(args.num_validation_images)
|
||||||
@@ -2144,15 +2274,18 @@ def main(args):
|
|||||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||||
scheduler_args = {}
|
scheduler_args = {}
|
||||||
|
|
||||||
if "variance_type" in pipeline.scheduler.config:
|
if not args.do_edm_style_training:
|
||||||
variance_type = pipeline.scheduler.config.variance_type
|
if "variance_type" in pipeline.scheduler.config:
|
||||||
|
variance_type = pipeline.scheduler.config.variance_type
|
||||||
|
|
||||||
if variance_type in ["learned", "learned_range"]:
|
if variance_type in ["learned", "learned_range"]:
|
||||||
variance_type = "fixed_small"
|
variance_type = "fixed_small"
|
||||||
|
|
||||||
scheduler_args["variance_type"] = variance_type
|
scheduler_args["variance_type"] = variance_type
|
||||||
|
|
||||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||||
|
pipeline.scheduler.config, **scheduler_args
|
||||||
|
)
|
||||||
|
|
||||||
# load attention processors
|
# load attention processors
|
||||||
pipeline.load_lora_weights(args.output_dir)
|
pipeline.load_lora_weights(args.output_dir)
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
|
|
||||||
class MarigoldDepthOutput(BaseOutput):
|
class MarigoldDepthOutput(BaseOutput):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
# Cache compiled models across invocations of this script.
|
# Cache compiled models across invocations of this script.
|
||||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
|||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -425,6 +425,11 @@ def parse_args(input_args=None):
|
|||||||
default=4,
|
default=4,
|
||||||
help=("The dimension of the LoRA update matrices."),
|
help=("The dimension of the LoRA update matrices."),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug_loss",
|
||||||
|
action="store_true",
|
||||||
|
help="debug loss for each image, if filenames are awailable in the dataset",
|
||||||
|
)
|
||||||
|
|
||||||
if input_args is not None:
|
if input_args is not None:
|
||||||
args = parser.parse_args(input_args)
|
args = parser.parse_args(input_args)
|
||||||
@@ -603,6 +608,7 @@ def main(args):
|
|||||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||||
# The VAE is in float32 to avoid NaN losses.
|
# The VAE is in float32 to avoid NaN losses.
|
||||||
unet.to(accelerator.device, dtype=weight_dtype)
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
if args.pretrained_vae_model_name_or_path is None:
|
if args.pretrained_vae_model_name_or_path is None:
|
||||||
vae.to(accelerator.device, dtype=torch.float32)
|
vae.to(accelerator.device, dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
@@ -890,13 +896,17 @@ def main(args):
|
|||||||
tokens_one, tokens_two = tokenize_captions(examples)
|
tokens_one, tokens_two = tokenize_captions(examples)
|
||||||
examples["input_ids_one"] = tokens_one
|
examples["input_ids_one"] = tokens_one
|
||||||
examples["input_ids_two"] = tokens_two
|
examples["input_ids_two"] = tokens_two
|
||||||
|
if args.debug_loss:
|
||||||
|
fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
|
||||||
|
if fnames:
|
||||||
|
examples["filenames"] = fnames
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
with accelerator.main_process_first():
|
with accelerator.main_process_first():
|
||||||
if args.max_train_samples is not None:
|
if args.max_train_samples is not None:
|
||||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||||
# Set the training transforms
|
# Set the training transforms
|
||||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||||
@@ -905,7 +915,7 @@ def main(args):
|
|||||||
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
||||||
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
||||||
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
||||||
return {
|
result = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"input_ids_one": input_ids_one,
|
"input_ids_one": input_ids_one,
|
||||||
"input_ids_two": input_ids_two,
|
"input_ids_two": input_ids_two,
|
||||||
@@ -913,6 +923,11 @@ def main(args):
|
|||||||
"crop_top_lefts": crop_top_lefts,
|
"crop_top_lefts": crop_top_lefts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filenames = [example["filenames"] for example in examples if "filenames" in example]
|
||||||
|
if filenames:
|
||||||
|
result["filenames"] = filenames
|
||||||
|
return result
|
||||||
|
|
||||||
# DataLoaders creation:
|
# DataLoaders creation:
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@@ -1105,7 +1120,9 @@ def main(args):
|
|||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
if args.debug_loss and "filenames" in batch:
|
||||||
|
for fname in batch["filenames"]:
|
||||||
|
accelerator.log({"loss_for_" + fname: loss}, step=global_step)
|
||||||
# Gather the losses across all processes for logging (if we use distributed training).
|
# Gather the losses across all processes for logging (if we use distributed training).
|
||||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||||
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ else:
|
|||||||
# ------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
|||||||
|
|
||||||
|
|
||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.27.0.dev0")
|
check_min_version("0.27.0")
|
||||||
|
|
||||||
logger = get_logger(__name__, log_level="INFO")
|
logger = get_logger(__name__, log_level="INFO")
|
||||||
|
|
||||||
|
|||||||
@@ -249,7 +249,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.27.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.27.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
__version__ = "0.27.0.dev0"
|
__version__ = "0.27.2"
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|||||||
@@ -1081,6 +1081,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
|||||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||||
A tensor that if specified is added to the residual of the middle unet block.
|
A tensor that if specified is added to the residual of the middle unet block.
|
||||||
|
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||||
encoder_attention_mask (`torch.Tensor`):
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
@@ -1088,18 +1090,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
|||||||
return_dict (`bool`, *optional*, defaults to `True`):
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||||
tuple.
|
tuple.
|
||||||
cross_attention_kwargs (`dict`, *optional*):
|
|
||||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
|
||||||
added_cond_kwargs: (`dict`, *optional*):
|
|
||||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
|
||||||
are passed along to the UNet blocks.
|
|
||||||
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
|
||||||
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
|
||||||
example from ControlNet side model(s)
|
|
||||||
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
|
||||||
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
|
||||||
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
|
||||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||||
@@ -1185,7 +1175,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
|||||||
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||||
|
|
||||||
# 3. down
|
# 3. down
|
||||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
||||||
|
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
||||||
|
if cross_attention_kwargs is not None:
|
||||||
|
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||||
|
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
||||||
|
else:
|
||||||
|
lora_scale = 1.0
|
||||||
|
|
||||||
if USE_PEFT_BACKEND:
|
if USE_PEFT_BACKEND:
|
||||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||||
scale_lora_layers(self, lora_scale)
|
scale_lora_layers(self, lora_scale)
|
||||||
|
|||||||
-3
@@ -528,15 +528,12 @@ class StableDiffusionInpaintPipelineLegacy(
|
|||||||
f" {negative_prompt_embeds.shape}."
|
f" {negative_prompt_embeds.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
|
||||||
def get_timesteps(self, num_inference_steps, strength, device):
|
def get_timesteps(self, num_inference_steps, strength, device):
|
||||||
# get the original timestep using init_timestep
|
# get the original timestep using init_timestep
|
||||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||||
|
|
||||||
t_start = max(num_inference_steps - init_timestep, 0)
|
t_start = max(num_inference_steps - init_timestep, 0)
|
||||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||||
if hasattr(self.scheduler, "set_begin_index"):
|
|
||||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
|
||||||
|
|
||||||
return timesteps, num_inference_steps - t_start
|
return timesteps, num_inference_steps - t_start
|
||||||
|
|
||||||
|
|||||||
@@ -100,8 +100,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
||||||
|
|
||||||
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
|
def prepare_latents(
|
||||||
batch_size, channels, height, width = image_embeddings.shape
|
self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
|
||||||
|
):
|
||||||
|
_, channels, height, width = image_embeddings.shape
|
||||||
latents_shape = (
|
latents_shape = (
|
||||||
batch_size * num_images_per_prompt,
|
batch_size * num_images_per_prompt,
|
||||||
4,
|
4,
|
||||||
@@ -383,7 +385,19 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
if isinstance(image_embeddings, list):
|
if isinstance(image_embeddings, list):
|
||||||
image_embeddings = torch.cat(image_embeddings, dim=0)
|
image_embeddings = torch.cat(image_embeddings, dim=0)
|
||||||
batch_size = image_embeddings.shape[0]
|
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
# Compute the effective number of images per prompt
|
||||||
|
# We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
|
||||||
|
# This results in a case where a single prompt is associated with multiple image embeddings
|
||||||
|
# Divide the number of image embeddings by the batch size to determine if this is the case.
|
||||||
|
num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
|
||||||
|
|
||||||
# 2. Encode caption
|
# 2. Encode caption
|
||||||
if prompt_embeds is None and negative_prompt_embeds is None:
|
if prompt_embeds is None and negative_prompt_embeds is None:
|
||||||
@@ -417,7 +431,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
# 5. Prepare latents
|
# 5. Prepare latents
|
||||||
latents = self.prepare_latents(
|
latents = self.prepare_latents(
|
||||||
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. Run denoising loop
|
# 6. Run denoising loop
|
||||||
|
|||||||
-3
@@ -716,15 +716,12 @@ class StableDiffusionDiffEditPipeline(
|
|||||||
f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
|
f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
|
||||||
def get_timesteps(self, num_inference_steps, strength, device):
|
def get_timesteps(self, num_inference_steps, strength, device):
|
||||||
# get the original timestep using init_timestep
|
# get the original timestep using init_timestep
|
||||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||||
|
|
||||||
t_start = max(num_inference_steps - init_timestep, 0)
|
t_start = max(num_inference_steps - init_timestep, 0)
|
||||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||||
if hasattr(self.scheduler, "set_begin_index"):
|
|
||||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
|
||||||
|
|
||||||
return timesteps, num_inference_steps - t_start
|
return timesteps, num_inference_steps - t_start
|
||||||
|
|
||||||
|
|||||||
@@ -434,7 +434,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -768,10 +768,14 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
# begin_index is None when the scheduler is used for training
|
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -1011,10 +1011,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
# begin_index is None when the scheduler is used for training
|
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -543,7 +543,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -223,6 +223,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
"""
|
"""
|
||||||
steps = num_inference_steps
|
steps = num_inference_steps
|
||||||
order = self.config.solver_order
|
order = self.config.solver_order
|
||||||
|
if order > 3:
|
||||||
|
raise ValueError("Order > 3 is not supported by this scheduler")
|
||||||
if self.config.lower_order_final:
|
if self.config.lower_order_final:
|
||||||
if order == 3:
|
if order == 3:
|
||||||
if steps % 3 == 0:
|
if steps % 3 == 0:
|
||||||
@@ -959,10 +961,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
# begin_index is None when the scheduler is used for training
|
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -669,7 +669,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -367,7 +367,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -467,7 +467,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -562,7 +562,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -468,7 +468,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -494,7 +494,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -469,7 +469,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -461,7 +461,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -862,10 +862,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
# begin_index is None when the scheduler is used for training
|
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
||||||
if self.begin_index is None:
|
if self.begin_index is None:
|
||||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
||||||
|
elif self.step_index is not None:
|
||||||
|
# add_noise is called after first denoising step (for inpainting)
|
||||||
|
step_indices = [self.step_index] * timesteps.shape[0]
|
||||||
else:
|
else:
|
||||||
|
# add noise is called bevore first denoising step to create inital latent(img2img)
|
||||||
step_indices = [self.begin_index] * timesteps.shape[0]
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
|
|
||||||
pipeline_inputs = {
|
pipeline_inputs = {
|
||||||
"prompt": "A painting of a squirrel eating a burger",
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
"num_inference_steps": 2,
|
"num_inference_steps": 5,
|
||||||
"guidance_scale": 6.0,
|
"guidance_scale": 6.0,
|
||||||
"output_type": "np",
|
"output_type": "np",
|
||||||
}
|
}
|
||||||
@@ -589,7 +589,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
|
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
|
||||||
).images
|
).images
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
|
not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
|
||||||
"Lora + scale should change the output",
|
"Lora + scale should change the output",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1300,6 +1300,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
|||||||
pipe.load_lora_weights(lora_id)
|
pipe.load_lora_weights(lora_id)
|
||||||
pipe = pipe.to("cuda")
|
pipe = pipe.to("cuda")
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
self.check_if_lora_correctly_set(pipe.unet),
|
||||||
|
"Lora not correctly set in UNet",
|
||||||
|
)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
self.check_if_lora_correctly_set(pipe.text_encoder),
|
self.check_if_lora_correctly_set(pipe.text_encoder),
|
||||||
"Lora not correctly set in text encoder 2",
|
"Lora not correctly set in text encoder 2",
|
||||||
|
|||||||
@@ -829,7 +829,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
|||||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
|
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert vae_default.config.scaling_factor == 0.18125
|
assert vae_default.config.scaling_factor == 0.18215
|
||||||
assert vae_default.config.sample_size == 512
|
assert vae_default.config.sample_size == 512
|
||||||
assert vae_default.dtype == torch.float32
|
assert vae_default.dtype == torch.float32
|
||||||
|
|
||||||
|
|||||||
@@ -50,9 +50,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
unet = StableCascadeUNet.from_pretrained(
|
unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior", variant="bf16")
|
||||||
"stabilityai/stable-cascade-prior", subfolder="prior", revision="refs/pr/2", variant="bf16"
|
|
||||||
)
|
|
||||||
unet_config = unet.config
|
unet_config = unet.config
|
||||||
del unet
|
del unet
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -74,9 +72,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
unet = StableCascadeUNet.from_pretrained(
|
unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder", variant="bf16")
|
||||||
"stabilityai/stable-cascade", subfolder="decoder", revision="refs/pr/44", variant="bf16"
|
|
||||||
)
|
|
||||||
unet_config = unet.config
|
unet_config = unet.config
|
||||||
del unet
|
del unet
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
@@ -21,18 +21,19 @@ import torch
|
|||||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
|
||||||
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
|
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
|
||||||
from diffusers.models import StableCascadeUNet
|
from diffusers.models import StableCascadeUNet
|
||||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||||
from diffusers.utils.testing_utils import (
|
from diffusers.utils.testing_utils import (
|
||||||
enable_full_determinism,
|
enable_full_determinism,
|
||||||
load_image,
|
load_numpy,
|
||||||
load_pt,
|
load_pt,
|
||||||
|
numpy_cosine_similarity_distance,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
skip_mps,
|
skip_mps,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
|
||||||
from ..test_pipelines_common import PipelineTesterMixin
|
from ..test_pipelines_common import PipelineTesterMixin
|
||||||
|
|
||||||
@@ -246,6 +247,66 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
|||||||
|
|
||||||
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
|
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
|
||||||
|
|
||||||
|
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
|
||||||
|
device = "cpu"
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
pipe = StableCascadeDecoderPipeline(**components)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prior_num_images_per_prompt = 2
|
||||||
|
decoder_num_images_per_prompt = 2
|
||||||
|
prompt = ["a cat"]
|
||||||
|
batch_size = len(prompt)
|
||||||
|
|
||||||
|
generator = torch.Generator(device)
|
||||||
|
image_embeddings = randn_tensor(
|
||||||
|
(batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
|
||||||
|
)
|
||||||
|
decoder_output = pipe(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
prompt=prompt,
|
||||||
|
num_inference_steps=1,
|
||||||
|
output_type="np",
|
||||||
|
guidance_scale=0.0,
|
||||||
|
generator=generator.manual_seed(0),
|
||||||
|
num_images_per_prompt=decoder_num_images_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert decoder_output.images.shape[0] == (
|
||||||
|
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
|
||||||
|
device = "cpu"
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
pipe = StableCascadeDecoderPipeline(**components)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prior_num_images_per_prompt = 2
|
||||||
|
decoder_num_images_per_prompt = 2
|
||||||
|
prompt = ["a cat"]
|
||||||
|
batch_size = len(prompt)
|
||||||
|
|
||||||
|
generator = torch.Generator(device)
|
||||||
|
image_embeddings = randn_tensor(
|
||||||
|
(batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
|
||||||
|
)
|
||||||
|
decoder_output = pipe(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
prompt=prompt,
|
||||||
|
num_inference_steps=1,
|
||||||
|
output_type="np",
|
||||||
|
guidance_scale=2.0,
|
||||||
|
generator=generator.manual_seed(0),
|
||||||
|
num_images_per_prompt=decoder_num_images_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert decoder_output.images.shape[0] == (
|
||||||
|
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@@ -258,7 +319,7 @@ class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
def test_stable_cascade_decoder(self):
|
def test_stable_cascade_decoder(self):
|
||||||
pipe = StableCascadeDecoderPipeline.from_pretrained(
|
pipe = StableCascadeDecoderPipeline.from_pretrained(
|
||||||
"diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
|
"stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
|
||||||
)
|
)
|
||||||
pipe.enable_model_cpu_offload()
|
pipe.enable_model_cpu_offload()
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
@@ -271,18 +332,16 @@ class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
image = pipe(
|
image = pipe(
|
||||||
prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator
|
prompt=prompt,
|
||||||
|
image_embeddings=image_embedding,
|
||||||
|
output_type="np",
|
||||||
|
num_inference_steps=2,
|
||||||
|
generator=generator,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
assert image.size == (1024, 1024)
|
assert image.shape == (1024, 1024, 3)
|
||||||
|
expected_image = load_numpy(
|
||||||
expected_image = load_image(
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/stable_cascade_decoder_image.npy"
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png"
|
|
||||||
)
|
)
|
||||||
|
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
|
||||||
image_processor = VaeImageProcessor()
|
assert max_diff < 1e-4
|
||||||
|
|
||||||
image_np = image_processor.pil_to_numpy(image)
|
|
||||||
expected_image_np = image_processor.pil_to_numpy(expected_image)
|
|
||||||
|
|
||||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2))
|
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProc
|
|||||||
from diffusers.utils.import_utils import is_peft_available
|
from diffusers.utils.import_utils import is_peft_available
|
||||||
from diffusers.utils.testing_utils import (
|
from diffusers.utils.testing_utils import (
|
||||||
enable_full_determinism,
|
enable_full_determinism,
|
||||||
load_pt,
|
load_numpy,
|
||||||
|
numpy_cosine_similarity_distance,
|
||||||
require_peft_backend,
|
require_peft_backend,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
skip_mps,
|
skip_mps,
|
||||||
@@ -319,7 +320,9 @@ class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def test_stable_cascade_prior(self):
|
def test_stable_cascade_prior(self):
|
||||||
pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
|
pipe = StableCascadePriorPipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
pipe.enable_model_cpu_offload()
|
pipe.enable_model_cpu_offload()
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
@@ -327,17 +330,12 @@ class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
|
|
||||||
output = pipe(prompt, num_inference_steps=10, generator=generator)
|
output = pipe(prompt, num_inference_steps=2, output_type="np", generator=generator)
|
||||||
image_embedding = output.image_embeddings
|
image_embedding = output.image_embeddings
|
||||||
|
expected_image_embedding = load_numpy(
|
||||||
expected_image_embedding = load_pt(
|
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/stable_cascade_prior_image_embeddings.npy"
|
||||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert image_embedding.shape == (1, 16, 24, 24)
|
assert image_embedding.shape == (1, 16, 24, 24)
|
||||||
|
|
||||||
self.assertTrue(
|
max_diff = numpy_cosine_similarity_distance(image_embedding.flatten(), expected_image_embedding.flatten())
|
||||||
np.allclose(
|
assert max_diff < 1e-4
|
||||||
image_embedding.cpu().float().numpy(), expected_image_embedding.cpu().float().numpy(), atol=5e-2
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from diffusers import (
|
|||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
LCMScheduler,
|
LCMScheduler,
|
||||||
LMSDiscreteScheduler,
|
LMSDiscreteScheduler,
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
@@ -557,6 +558,29 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
|
|||||||
image_slice2 = images[1, -3:, -3:, -1]
|
image_slice2 = images[1, -3:, -3:, -1]
|
||||||
assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2
|
assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2
|
||||||
|
|
||||||
|
def test_stable_diffusion_inpaint_euler(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
components = self.get_dummy_components(time_cond_proj_dim=256)
|
||||||
|
sd_pipe = StableDiffusionInpaintPipeline(**components)
|
||||||
|
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||||
|
sd_pipe = sd_pipe.to(device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device, output_pil=False)
|
||||||
|
half_dim = inputs["image"].shape[2] // 2
|
||||||
|
inputs["mask_image"][0, 0, :half_dim, :half_dim] = 0
|
||||||
|
|
||||||
|
inputs["num_inference_steps"] = 4
|
||||||
|
image = sd_pipe(**inputs).images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (1, 64, 64, 3)
|
||||||
|
|
||||||
|
expected_slice = np.array(
|
||||||
|
[[0.6387283, 0.5564158, 0.58631873, 0.5539942, 0.5494673, 0.6461868, 0.5251618, 0.5497595, 0.5508756]]
|
||||||
|
)
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user