Compare commits

..

6 Commits

Author SHA1 Message Date
Dhruv Nair a30871a0c5 update 2024-02-02 05:12:27 +00:00
Dhruv Nair 9237ea5787 update 2024-02-02 05:07:12 +00:00
Dhruv Nair f915b558d4 Update src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-02-02 10:27:34 +05:30
Dhruv Nair e2827f819a update 2024-02-02 04:30:44 +00:00
Dhruv Nair 3cf7b068c3 update 2024-02-01 08:02:27 +00:00
Dhruv Nair c7652d3d60 update 2024-02-01 07:58:59 +00:00
61 changed files with 701 additions and 1418 deletions
@@ -1,244 +0,0 @@
# Advanced diffusion training examples
## Train Dreambooth LoRA with Stable Diffusion XL
> [!TIP]
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py`, with
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
> [!NOTE]
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
📚 Read more about the advanced features and best practices in this community derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/advanced_diffusion_training` folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell e.g. a notebook
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Pivotal Tuning
**Training with text encoder(s)**
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
available with `train_dreambooth_lora_sdxl_advanced.py`, in the advanced script **pivotal tuning** is also supported.
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
We then optimize the newly-inserted token embeddings to represent the new concept.
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
Please keep the following points in mind:
* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry.
### 3D icon example
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./3d_icon"
snapshot_download(
"LinoyTsaban/3d_icon",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
Let's review some of the advanced features we're going to be using for this example:
- **custom captions**:
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
```bash
pip install datasets
```
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
```
--dataset_name=./3d_icon
--caption_column=prompt
```
You can also load a dataset straight from by specifying it's name in `dataset_name`.
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
- **pivotal tuning**
- **min SNR gamma**
**Now, we can launch training:**
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="./3d_icon"
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_model_name_or_path=$VAE_PATH \
--dataset_name=$DATASET_NAME \
--instance_prompt="3d icon in the style of TOK" \
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
--output_dir=$OUTPUT_DIR \
--caption_column="prompt" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=3 \
--repeats=1 \
--report_to="wandb"\
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--learning_rate=1.0 \
--text_encoder_lr=1.0 \
--optimizer="prodigy"\
--train_text_encoder_ti\
--train_text_encoder_ti_frac=0.5\
--snr_gamma=5.0 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--rank=8 \
--max_train_steps=1000 \
--checkpointing_steps=2000 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
### Inference
Once training is done, we can perform inference like so:
1. starting with loading the unet lora weights
```python
import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL
from safetensors.torch import load_file
username = "linoyts"
repo_id = f"{username}/3d-icon-SDXL-LoRA"
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
```
2. now we load the pivotal tuning embeddings
```python
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
3. let's generate images
```python
instance_token = "<s0><s1>"
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
image.save("llama.png")
```
### Comfy UI / AUTOMATIC1111 Inference
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
**AUTOMATIC1111 / SD.Next** \
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
**ComfyUI** \
In ComfyUI we will load a LoRA and a textual embedding at the same time.
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
-
### Specifying a better VAE
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
### Tips and Tricks
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
## Running on Colab Notebook
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb).
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning)
@@ -1,7 +0,0 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft==0.7.0
@@ -119,9 +119,10 @@ def save_model_card(
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
"""
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
- Place it on it on your `embeddings` folder
@@ -388,7 +389,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--resolution",
type=int,
default=512,
default=1024,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
@@ -644,7 +645,6 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--rank",
type=int,
@@ -745,11 +745,10 @@ class TokenEmbeddingsHandler:
idx += 1
# copied from train_dreambooth_lora_sdxl_advanced.py
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
@@ -1635,11 +1634,6 @@ def main(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
)
bsz = model_input.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
@@ -1794,7 +1788,6 @@ def main(args):
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
tokenizer=tokenizer_one,
text_encoder=accelerator.unwrap_model(text_encoder_one),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
@@ -1867,11 +1860,6 @@ def main(args):
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
if args.train_text_encoder_ti:
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
embedding_handler.save_embeddings(embeddings_path)
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
@@ -1907,18 +1895,6 @@ def main(args):
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# load new tokens
if args.train_text_encoder_ti:
state_dict = load_file(embeddings_path)
all_new_tokens = []
for key, value in token_abstraction_dict.items():
all_new_tokens.extend(value)
pipeline.load_textual_inversion(
state_dict["clip_l"],
token=all_new_tokens,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
@@ -1941,6 +1917,11 @@ def main(args):
}
)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
)
# Conver to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
@@ -20,7 +20,6 @@ import itertools
import logging
import math
import os
import random
import re
import shutil
import warnings
@@ -46,7 +45,6 @@ from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
@@ -123,7 +121,7 @@ def save_model_card(
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
@@ -399,6 +397,18 @@ def parse_args(input_args=None):
" resolution"
),
)
parser.add_argument(
"--crops_coords_top_left_h",
type=int,
default=0,
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
)
parser.add_argument(
"--crops_coords_top_left_w",
type=int,
default=0,
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
)
parser.add_argument(
"--center_crop",
default=False,
@@ -408,11 +418,6 @@ def parse_args(input_args=None):
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
@@ -654,7 +659,6 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--rank",
type=int,
@@ -897,41 +901,6 @@ class DreamBoothDataset(Dataset):
self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
# image processing to prepare for using SD-XL micro-conditioning
self.original_sizes = []
self.crop_top_lefts = []
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
for image in self.instance_images:
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
self.original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
crop_top_left = (y1, x1)
self.crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
self.pixel_values.append(image)
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -961,12 +930,12 @@ class DreamBoothDataset(Dataset):
def __getitem__(self, index):
example = {}
instance_image = self.pixel_values[index % self.num_instance_images]
original_size = self.original_sizes[index % self.num_instance_images]
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
example["instance_images"] = instance_image
example["original_size"] = original_size
example["crop_top_left"] = crop_top_left
instance_image = self.instance_images[index % self.num_instance_images]
instance_image = exif_transpose(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -997,8 +966,6 @@ class DreamBoothDataset(Dataset):
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
original_sizes = [example["original_size"] for example in examples]
crop_top_lefts = [example["crop_top_left"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
@@ -1009,12 +976,7 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {
"pixel_values": pixel_values,
"prompts": prompts,
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
}
batch = {"pixel_values": pixel_values, "prompts": prompts}
return batch
@@ -1236,9 +1198,7 @@ def main(args):
args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
if args.with_prior_preservation:
args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
if args.validation_prompt:
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)
# initialize the new tokens for textual inversion
embedding_handler = TokenEmbeddingsHandler(
[text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
@@ -1579,11 +1539,11 @@ def main(args):
# pooled text embeddings
# time ids
def compute_time_ids(crops_coords_top_left, original_size=None):
def compute_time_ids():
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
if original_size is None:
original_size = (args.resolution, args.resolution)
original_size = (args.resolution, args.resolution)
target_size = (args.resolution, args.resolution)
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1600,6 +1560,9 @@ def main(args):
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
# Handle instance prompt.
instance_time_ids = compute_time_ids()
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
@@ -1610,6 +1573,7 @@ def main(args):
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
class_time_ids = compute_time_ids()
if freeze_text_encoder:
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
args.class_prompt, text_encoders, tokenizers
@@ -1624,6 +1588,9 @@ def main(args):
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
add_time_ids = instance_time_ids
if args.with_prior_preservation:
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
add_special_tokens = True if args.train_text_encoder_ti else False
@@ -1646,6 +1613,12 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
if args.train_text_encoder_ti and args.validation_prompt:
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
@@ -1805,12 +1778,6 @@ def main(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
)
bsz = model_input.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
@@ -1822,26 +1789,19 @@ def main(args):
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# time ids
add_time_ids = torch.cat(
[
compute_time_ids(original_size=s, crops_coords_top_left=c)
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
]
)
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
if not train_dataset.custom_instance_prompts:
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
else:
elems_to_repeat_text_embeds = 1
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
# Predict the noise residual
if freeze_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids,
# "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -1852,7 +1812,7 @@ def main(args):
added_cond_kwargs=unet_added_conditions,
).sample
else:
unet_added_conditions = {"time_ids": add_time_ids}
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
@@ -1994,8 +1954,6 @@ def main(args):
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
tokenizer=tokenizer_one,
tokenizer_2=tokenizer_two,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
@@ -2075,11 +2033,6 @@ def main(args):
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
if args.train_text_encoder_ti:
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
embedding_handler.save_embeddings(embeddings_path)
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
@@ -2115,25 +2068,6 @@ def main(args):
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# load new tokens
if args.train_text_encoder_ti:
state_dict = load_file(embeddings_path)
all_new_tokens = []
for key, value in token_abstraction_dict.items():
all_new_tokens.extend(value)
pipeline.load_textual_inversion(
state_dict["clip_l"],
token=all_new_tokens,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
state_dict["clip_g"],
token=all_new_tokens,
text_encoder=pipeline.text_encoder_2,
tokenizer=pipeline.tokenizer_2,
)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
@@ -2156,6 +2090,11 @@ def main(args):
}
)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
)
# Conver to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+44 -37
View File
@@ -104,22 +104,6 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -141,8 +125,15 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
@@ -242,22 +233,6 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -284,8 +259,15 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
@@ -969,6 +951,30 @@ class IPAdapterFaceIDStableDiffusionPipeline(
return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -1296,6 +1302,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
image_embeds (`torch.FloatTensor`, *optional*):
Pre-generated image embeddings.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1404,7 +1411,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if image_embeds is not None:
image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
device=device, dtype=prompt_embeds.dtype
)
negative_image_embeds = torch.zeros_like(image_embeds)
@@ -538,7 +538,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
return hidden_states, output_states
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
eps = 1e-6
output_states = ()
@@ -634,9 +634,7 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
return hidden_states
def hacked_UpBlock2D_forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
):
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
# pop res hidden states
@@ -507,7 +507,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return hidden_states, output_states
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
eps = 1e-6
output_states = ()
@@ -603,9 +603,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
return hidden_states
def hacked_UpBlock2D_forward(
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
):
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
# pop res hidden states
@@ -19,7 +19,6 @@ import itertools
import logging
import math
import os
import random
import shutil
import warnings
from pathlib import Path
@@ -41,7 +40,6 @@ from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
@@ -306,6 +304,18 @@ def parse_args(input_args=None):
" resolution"
),
)
parser.add_argument(
"--crops_coords_top_left_h",
type=int,
default=0,
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
)
parser.add_argument(
"--crops_coords_top_left_w",
type=int,
default=0,
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
)
parser.add_argument(
"--center_crop",
default=False,
@@ -315,11 +325,6 @@ def parse_args(input_args=None):
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
@@ -664,41 +669,6 @@ class DreamBoothDataset(Dataset):
self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
# image processing to prepare for using SD-XL micro-conditioning
self.original_sizes = []
self.crop_top_lefts = []
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
for image in self.instance_images:
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
self.original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
crop_top_left = (y1, x1)
self.crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
self.pixel_values.append(image)
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -728,12 +698,12 @@ class DreamBoothDataset(Dataset):
def __getitem__(self, index):
example = {}
instance_image = self.pixel_values[index % self.num_instance_images]
original_size = self.original_sizes[index % self.num_instance_images]
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
example["instance_images"] = instance_image
example["original_size"] = original_size
example["crop_top_left"] = crop_top_left
instance_image = self.instance_images[index % self.num_instance_images]
instance_image = exif_transpose(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -760,8 +730,6 @@ class DreamBoothDataset(Dataset):
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
original_sizes = [example["original_size"] for example in examples]
crop_top_lefts = [example["crop_top_left"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
@@ -772,12 +740,7 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {
"pixel_values": pixel_values,
"prompts": prompts,
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
}
batch = {"pixel_values": pixel_values, "prompts": prompts}
return batch
@@ -1270,9 +1233,11 @@ def main(args):
# pooled text embeddings
# time ids
def compute_time_ids(original_size, crops_coords_top_left):
def compute_time_ids():
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
original_size = (args.resolution, args.resolution)
target_size = (args.resolution, args.resolution)
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1289,6 +1254,9 @@ def main(args):
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
# Handle instance prompt.
instance_time_ids = compute_time_ids()
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
@@ -1299,6 +1267,7 @@ def main(args):
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
class_time_ids = compute_time_ids()
if not args.train_text_encoder:
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
args.class_prompt, text_encoders, tokenizers
@@ -1313,6 +1282,9 @@ def main(args):
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
add_time_ids = instance_time_ids
if args.with_prior_preservation:
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
if not train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
@@ -1427,8 +1399,8 @@ def main(args):
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
text_encoder_one.text_model.embeddings.requires_grad_(True)
text_encoder_two.text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
@@ -1464,24 +1436,18 @@ def main(args):
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# time ids
add_time_ids = torch.cat(
[
compute_time_ids(original_size=s, crops_coords_top_left=c)
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
]
)
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
if not train_dataset.custom_instance_prompts:
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
else:
elems_to_repeat_text_embeds = 1
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
# Predict the noise residual
if not args.train_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids,
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -1493,7 +1459,7 @@ def main(args):
return_dict=False,
)[0]
else:
unet_added_conditions = {"time_ids": add_time_ids}
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
-6
View File
@@ -158,12 +158,6 @@ class BasicTransformerBlock(nn.Module):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+21 -10
View File
@@ -1031,10 +1031,16 @@ class DownBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, scale
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states.requires_grad_(),
temb,
num_frames,
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
output_states = output_states + (hidden_states,)
@@ -1215,10 +1221,10 @@ class CrossAttnDownBlockMotion(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
@@ -1419,10 +1425,10 @@ class CrossAttnUpBlockMotion(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1557,10 +1563,15 @@ class UpBlockMotion(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
)
else:
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -792,7 +792,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
@@ -800,9 +799,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
encoder_hidden_states = (encoder_hidden_states, image_embeds)
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
# 2. pre-process
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
@@ -789,8 +789,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -705,8 +705,6 @@ class StableDiffusionControlNetInpaintPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -871,8 +871,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -566,8 +566,6 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -536,8 +536,6 @@ class StableDiffusionInpaintPipelineLegacy(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -634,8 +634,6 @@ class LatentConsistencyModelImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -906,8 +906,6 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -467,8 +467,6 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -659,8 +659,6 @@ class StableDiffusionImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -859,8 +859,6 @@ class StableDiffusionInpaintPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -754,8 +754,6 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -554,8 +554,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -98,9 +98,15 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.custom_timesteps = False
self.is_scale_input_called = False
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
return indices.item()
@property
def step_index(self):
"""
@@ -108,24 +114,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -243,7 +231,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument
@@ -293,29 +280,23 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
return c_skip, c_out
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
@@ -431,11 +412,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -187,7 +187,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -197,24 +196,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -274,7 +255,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -640,12 +620,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError("only support log-rho multistep deis now")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
@@ -658,20 +637,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
self._step_index = step_index
def step(
self,
@@ -770,11 +736,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -227,7 +227,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -237,23 +236,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -329,7 +311,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -811,11 +792,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
@@ -828,19 +809,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
step_index = index_candidates[0].item()
return step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
self._step_index = step_index
def step(
self,
@@ -951,11 +920,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -767,6 +767,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -878,6 +879,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -197,10 +198,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -211,18 +211,31 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
@property
def init_noise_sigma(self):
@@ -239,24 +252,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -353,10 +348,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.mid_point_sigma = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
def _second_order_timesteps(self, sigmas, log_sigmas):
def sigma_fn(_t):
return np.exp(-_t)
@@ -446,6 +444,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None:
self._init_step_index(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
# Create a noise sampler if it hasn't been created yet
if self.noise_sampler is None:
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
@@ -525,7 +527,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
@@ -542,11 +544,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -210,7 +210,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def get_order_list(self, num_inference_steps: int) -> List[int]:
@@ -254,24 +253,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -334,7 +315,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -833,12 +813,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else:
raise ValueError(f"Order must be 1, 2, 3, got {order}")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
@@ -851,20 +830,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
self._step_index = step_index
def step(
self,
@@ -959,11 +925,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -216,7 +216,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -234,24 +233,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -319,32 +300,25 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
@@ -466,11 +440,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -237,7 +237,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -256,24 +255,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -361,7 +342,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def _sigma_to_t(self, sigma, log_sigmas):
@@ -413,27 +393,22 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
@@ -563,11 +538,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -147,10 +148,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -161,7 +160,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@@ -180,24 +183,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -285,9 +270,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.dt = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
@@ -344,12 +333,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
@@ -380,6 +378,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None:
self._init_step_index(timestep)
# (YiYi notes: keep this for now since we are keeping the add_noise method)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
@@ -450,7 +453,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
@@ -467,11 +469,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
+10 -36
View File
@@ -56,7 +56,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
self.ets = []
self._step_index = None
self._begin_index = None
@property
def step_index(self):
@@ -65,24 +64,6 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -109,31 +90,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self.ets = []
self._step_index = None
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -139,9 +140,27 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
@@ -157,24 +176,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -294,8 +295,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
@@ -352,29 +356,23 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
@@ -408,6 +406,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None:
self._init_step_index(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma_interpol = self.sigmas_interpol[self.step_index]
@@ -476,7 +478,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
@@ -493,11 +495,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -139,9 +140,27 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0:
pos = 1 if len(indices) > 1 else 0
else:
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item()
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
@@ -157,24 +176,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -279,37 +280,34 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
@@ -390,6 +388,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None:
self._init_step_index(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma_interpol = self.sigmas_interpol[self.step_index + 1]
@@ -451,7 +453,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
@@ -468,11 +470,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
+10 -36
View File
@@ -250,54 +250,29 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self.custom_timesteps = False
self._step_index = None
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
@property
def step_index(self):
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -487,7 +462,6 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
self._step_index = None
self._begin_index = None
def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
@@ -168,7 +168,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -186,24 +185,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -299,34 +280,27 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.derivatives = []
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
indices = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
self._step_index = self._begin_index
step_index = index_candidates[0]
self._step_index = step_index.item()
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
@@ -460,11 +434,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
@@ -212,7 +212,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
self.lower_order_nums = 0
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -222,24 +221,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -302,7 +283,6 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -945,12 +925,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
x_t = x_t.to(x.dtype)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
@@ -963,20 +942,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
self._step_index = step_index
def step(
self,
@@ -198,7 +198,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
@@ -208,24 +207,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -288,7 +269,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -718,12 +698,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t = x_t.to(x.dtype)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (schedule_timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
@@ -736,20 +715,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
self._step_index = step_index
def step(
self,
@@ -864,11 +830,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
-2
View File
@@ -854,8 +854,6 @@ def _is_torch_fp64_available(device):
import torch
device = torch.device(device)
try:
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
_ = torch.mul(x, x)
@@ -25,7 +25,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
@@ -48,7 +48,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
@@ -23,7 +23,7 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
@@ -28,7 +28,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
@@ -46,7 +46,7 @@ from diffusers.utils.testing_utils import (
)
from diffusers.utils.torch_utils import randn_tensor
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
@@ -4,7 +4,7 @@ from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax
from ..test_modeling_common_flax import FlaxModelTesterMixin
from .test_modeling_common_flax import FlaxModelTesterMixin
if is_flax_available():
@@ -25,7 +25,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
View File
+238 -12
View File
@@ -14,16 +14,22 @@
# limitations under the License.
import gc
import random
import unittest
import torch
from diffusers import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
IFInpaintingPipeline,
IFInpaintingSuperResolutionPipeline,
IFPipeline,
IFSuperResolutionPipeline,
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -91,18 +97,77 @@ class IFPipelineSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
def test_if_text_to_image(self):
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
pipe.enable_model_cpu_offload()
def test_all(self):
# if
torch.cuda.reset_max_memory_allocated()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe_2 = IFSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16, text_encoder=None, tokenizer=None
)
# pre compute text embeddings and remove T5 to save memory
pipe_1.text_encoder.to("cuda")
prompt_embeds, negative_prompt_embeds = pipe_1.encode_prompt("anime turtle", device="cuda")
del pipe_1.tokenizer
del pipe_1.text_encoder
gc.collect()
pipe_1.tokenizer = None
pipe_1.text_encoder = None
pipe_1.enable_model_cpu_offload()
pipe_2.enable_model_cpu_offload()
pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())
self._test_if(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)
pipe_1.remove_all_hooks()
pipe_2.remove_all_hooks()
# img2img
pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)
pipe_1.enable_model_cpu_offload()
pipe_2.enable_model_cpu_offload()
pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())
self._test_if_img2img(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)
pipe_1.remove_all_hooks()
pipe_2.remove_all_hooks()
# inpainting
pipe_1 = IFInpaintingPipeline(**pipe_1.components)
pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)
pipe_1.enable_model_cpu_offload()
pipe_2.enable_model_cpu_offload()
pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())
self._test_if_inpainting(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)
def _test_if(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
# pipeline 1
_start_torch_memory_measurement()
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt="anime turtle",
output = pipe_1(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_inference_steps=2,
generator=generator,
output_type="np",
@@ -110,11 +175,172 @@ class IFPipelineSlowTests(unittest.TestCase):
image = output.images[0]
assert image.shape == (64, 64, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 12 * 10**9
assert mem_bytes < 13 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if.npy"
)
assert_mean_pixel_difference(image, expected_image)
pipe.remove_all_hooks()
# pipeline 2
_start_torch_memory_measurement()
generator = torch.Generator(device="cpu").manual_seed(0)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
output = pipe_2(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
image=image,
generator=generator,
num_inference_steps=2,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 4 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_superresolution_stage_II.npy"
)
assert_mean_pixel_difference(image, expected_image)
def _test_if_img2img(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
# pipeline 1
_start_torch_memory_measurement()
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe_1(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
image=image,
num_inference_steps=2,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (64, 64, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 10 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img.npy"
)
assert_mean_pixel_difference(image, expected_image)
# pipeline 2
_start_torch_memory_measurement()
generator = torch.Generator(device="cpu").manual_seed(0)
original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
output = pipe_2(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
image=image,
original_image=original_image,
generator=generator,
num_inference_steps=2,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 4 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img_superresolution_stage_II.npy"
)
assert_mean_pixel_difference(image, expected_image)
def _test_if_inpainting(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
# pipeline 1
_start_torch_memory_measurement()
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe_1(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
image=image,
mask_image=mask_image,
num_inference_steps=2,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (64, 64, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 10 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting.npy"
)
assert_mean_pixel_difference(image, expected_image)
# pipeline 2
_start_torch_memory_measurement()
generator = torch.Generator(device="cpu").manual_seed(0)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
mask_image = floats_tensor((1, 3, 256, 256), rng=random.Random(1)).to(torch_device)
output = pipe_2(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
image=image,
mask_image=mask_image,
original_image=original_image,
generator=generator,
num_inference_steps=2,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 4 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting_superresolution_stage_II.npy"
)
assert_mean_pixel_difference(image, expected_image)
def _start_torch_memory_measurement():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
@@ -13,22 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import torch
from diffusers import IFImg2ImgPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin
@@ -89,43 +87,3 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
self._test_inference_batch_single_identical(
expected_max_diff=1e-2,
)
@slow
@require_torch_gpu
class IFImg2ImgPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_if_img2img(self):
pipe = IFImg2ImgPipeline.from_pretrained(
"DeepFloyd/IF-I-L-v1.0",
variant="fp16",
torch_dtype=torch.float16,
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
pipe.enable_model_cpu_offload()
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt="anime turtle",
image=image,
num_inference_steps=2,
generator=generator,
output_type="np",
)
image = output.images[0]
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img.npy"
)
assert_mean_pixel_difference(image, expected_image)
pipe.remove_all_hooks()
@@ -13,22 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import torch
from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin
@@ -87,50 +82,3 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
self._test_inference_batch_single_identical(
expected_max_diff=1e-2,
)
@slow
@require_torch_gpu
class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_if_img2img_superresolution(self):
pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0",
variant="fp16",
torch_dtype=torch.float16,
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
output = pipe(
prompt="anime turtle",
image=image,
original_image=original_image,
generator=generator,
num_inference_steps=2,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img_superresolution_stage_II.npy"
)
assert_mean_pixel_difference(image, expected_image)
pipe.remove_all_hooks()
@@ -13,22 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import torch
from diffusers import IFInpaintingPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin
@@ -87,48 +85,3 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
self._test_inference_batch_single_identical(
expected_max_diff=1e-2,
)
@slow
@require_torch_gpu
class IFInpaintingPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_if_inpainting(self):
pipe = IFInpaintingPipeline.from_pretrained(
"DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
pipe.enable_model_cpu_offload()
# Super resolution test
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt="anime prompts",
image=image,
mask_image=mask_image,
num_inference_steps=2,
generator=generator,
output_type="np",
)
image = output.images[0]
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting.npy"
)
assert_mean_pixel_difference(image, expected_image)
pipe.remove_all_hooks()
@@ -13,22 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import torch
from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin
@@ -89,55 +87,3 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
self._test_inference_batch_single_identical(
expected_max_diff=1e-2,
)
@slow
@require_torch_gpu
class IFInpaintingSuperResolutionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_if_inpainting_superresolution(self):
pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
pipe.enable_model_cpu_offload()
# Super resolution test
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
generator = torch.Generator(device="cpu").manual_seed(0)
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
mask_image = floats_tensor((1, 3, 256, 256), rng=random.Random(1)).to(torch_device)
output = pipe(
prompt="anime turtle",
image=image,
original_image=original_image,
mask_image=mask_image,
generator=generator,
num_inference_steps=2,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting_superresolution_stage_II.npy"
)
assert_mean_pixel_difference(image, expected_image)
pipe.remove_all_hooks()
@@ -13,19 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import torch
from diffusers import IFSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from ..test_pipelines_common import PipelineTesterMixin
from . import IFPipelineTesterMixin
@@ -82,49 +80,3 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
self._test_inference_batch_single_identical(
expected_max_diff=1e-2,
)
@slow
@require_torch_gpu
class IFSuperResolutionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_if_superresolution(self):
pipe = IFSuperResolutionPipeline.from_pretrained(
"DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.unet.set_attn_processor(AttnAddedKVProcessor())
pipe.enable_model_cpu_offload()
# Super resolution test
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt="anime turtle",
image=image,
generator=generator,
num_inference_steps=2,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_superresolution_stage_II.npy"
)
assert_mean_pixel_difference(image, expected_image)
pipe.remove_all_hooks()
+1 -1
View File
@@ -38,7 +38,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..models.autoencoders.test_models_vae import (
from ..models.test_models_vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,