Compare commits

...

13 Commits

Author SHA1 Message Date
patil-suraj a44e825b18 add code 2024-01-04 16:22:47 +05:30
patil-suraj 9eceb07f8d add second loop 2024-01-04 14:53:33 +05:30
gzguevara 9f283b01d2 changed w&b report link (#6387) 2023-12-29 19:49:11 +05:30
Sayak Paul 203724e9d9 [Docs] add note on fp16 in fast diffusion (#6380)
add note on fp16
2023-12-29 09:38:50 +05:30
gzguevara e7044a4221 multi-subject-dreambooth-inpainting with 🤗 datasets (#6378)
* files added

* fixing code quality

* fixing code quality

* fixing code quality

* fixing code quality

* sorted import block

* seperated import wandb

* ruff on script

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-29 09:33:49 +05:30
Sayak Paul 034b39b8cb [docs] add details concerning diffusers-specific bits. (#6375)
add details concerning diffusers-specific bits.
2023-12-28 23:12:49 +05:30
Sayak Paul 2db73f4a50 remove delete documentation trigger workflows. (#6373) 2023-12-28 18:26:14 +05:30
Adrian Punga 84d7faebe4 Fix support for MPS in KDPM2AncestralDiscreteScheduler (#6365)
Fix support for MPS

MPS doesn't support float64
2023-12-28 10:22:02 +01:00
YiYi Xu 4c483deb90 [refactor embeddings] gligen + ip-adapter (#6244)
* refactor ip-adapter-imageproj, gligen

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-12-27 18:48:42 -10:00
Sayak Paul 1ac07d8a8d [Training examples] Follow up of #6306 (#6346)
* add to dreambooth lora.

* add: t2i lora.

* add: sdxl t2i lora.

* style

* lcm lora sdxl.

* unwrap

* fix: enable_adapters().
2023-12-28 07:37:50 +05:30
apolinário 1fff527702 Fix keys for lora format on advanced training scripts (#6361)
fix keys for lora format on advanced training scripts
2023-12-27 11:38:03 -06:00
apolinário 645a62bf3b Add PEFT to advanced training script (#6294)
* Fix ProdigyOPT in SDXL Dreambooth script

* style

* style

* Add PEFT to Advanced Training Script

* style

* style

*  style 

* change order for logic operation

* add lora alpha

* style

* Align PEFT to new format

* Update train_dreambooth_lora_sdxl_advanced.py

Apply #6355 fix

---------

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
2023-12-27 10:00:32 -03:00
Dhruv Nair 6414d4e4f9 Fix chunking in SVD (#6350)
fix
2023-12-27 13:07:41 +01:00
20 changed files with 1034 additions and 164 deletions
-14
View File
@@ -1,14 +0,0 @@
name: Delete doc comment
on:
workflow_run:
workflows: ["Delete doc comment trigger"]
types:
- completed
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
secrets:
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
@@ -1,12 +0,0 @@
name: Delete doc comment trigger
on:
pull_request:
types: [ closed ]
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
with:
pr_number: ${{ github.event.number }}
+25 -1
View File
@@ -96,6 +96,8 @@ bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:
</div>
_(We later ran the experiments in float16 and found out that the recent versions of torchao do not incur numerical problems from float16.)_
**Why bfloat16?**
* Using a reduced numerical precision (such as float16, bfloat16) to run inference doesnt affect the generation quality but significantly improves latency.
@@ -315,4 +317,26 @@ Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 sec
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
</div>
</div>
## Misc
### No graph breaks during torch.compile
Ensuring that the underlying model/method can be fully compiled is crucial for performance (torch.compile with fullgraph=True). This means having no graph breaks. We did this for the UNet and VAE by changing how we access the returning variables. Consider the following example:
```diff
- latents = unet(
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
-).sample
+ latents = unet(
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
+)[0]
```
### Getting rid of GPU syncs after compilation
During the iterative reverse diffusion process, we [call](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) `step()` on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside `step()`, the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). If the `sigmas` array is placed on the GPU, indexing causes a communication sync between the CPU and GPU. This causes a latency, and it becomes more evident when the denoiser has already been compiled.
But if the `sigmas` array always stays on the CPU (refer to [this line](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240)), this sync doesnt take place, hence improved latency. In general, any CPU <-> GPU communication sync should be none or be kept to a bare minimum as it can impact inference latency.
@@ -37,6 +37,8 @@ from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
@@ -54,10 +56,9 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -67,39 +68,6 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -161,8 +129,6 @@ tags:
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""
@@ -1264,54 +1230,25 @@ def main(args):
text_encoder_two.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
text_encoder_one, dtype=torch.float32, rank=args.rank
)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
text_encoder_two, dtype=torch.float32, rank=args.rank
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# if we use textual inversion, we freeze all parameters except for the token embeddings
# in text encoder
@@ -1335,6 +1272,17 @@ def main(args):
else:
param.requires_grad = False
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1346,11 +1294,15 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_lora_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1407,6 +1359,12 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
@@ -1997,13 +1955,17 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = unet_lora_state_dict(unet)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
text_encoder_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
)
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
@@ -51,7 +51,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
if unet is None:
raise ValueError("Must provide a `unet` when doing intermediate validation.")
unet = accelerator.unwrap_model(unet)
state_dict = get_peft_model_state_dict(unet)
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
to_load = state_dict
else:
to_load = args.output_dir
@@ -819,7 +819,7 @@ def main(args):
unet_ = accelerator.unwrap_model(unet)
# also save the checkpoints in native `diffusers` format so that it can be easily
# be independently loaded via `load_lora_weights()`.
state_dict = get_peft_model_state_dict(unet_)
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
for _, model in enumerate(models):
@@ -1184,7 +1184,7 @@ def main(args):
# solver timestep.
# With the adapters disabled, the `unet` is the regular teacher model.
unet.disable_adapters()
accelerator.unwrap_model(unet).disable_adapters()
with torch.no_grad():
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = unet(
@@ -1248,7 +1248,7 @@ def main(args):
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
# re-enable unet adapters to turn the `unet` into a student unet.
unet.enable_adapters()
accelerator.unwrap_model(unet).enable_adapters()
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
@@ -1332,7 +1332,7 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
if args.push_to_hub:
+7 -5
View File
@@ -54,7 +54,7 @@ from diffusers import (
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -853,9 +853,11 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1285,11 +1287,11 @@ def main(args):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
else:
text_encoder_state_dict = None
@@ -0,0 +1,93 @@
# Multi Subject Dreambooth for Inpainting Models
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requieres prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
**The second part**, the `train_multi_subject_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it for stable diffusion for inpainting.
## 1. Data Collection: Make Prompt-Image-Mask Pairs
Earlier training scripts have provided approaches like random masking for the training images. This project provides a notebook for more precise mask setting.
The notebook can be found here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)
The `multi_inpaint_dataset.ipynb` notebook, takes training & validation images, on which the user draws masks and provides prompts to make a prompt-image-mask pairs. This ensures that during training, the loss is computed on the area masking the object of interest, rather than on random areas. Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:
![train_val_pairs](https://drive.google.com/uc?id=1PzwH8E3icl_ubVmA19G0HZGLImFX3x5I)
You can build multiple datasets for every subject and upload them to the 🤗 hub. Later, when launching the training script you can indicate the paths of the datasets, on which you would like to finetune Stable Diffusion for inpaining.
## 2. Train Multi Subject Dreambooth for Inpainting
### 2.1. Setting The Training Configuration
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export OUTPUT_DIR="path-to-save-model"
export DATASET_1="gzguevara/mr_potato_head_masked"
export DATASET_2="gzguevara/cat_toy_masked"
... # Further paths to 🤗 datasets
```
### 2.2. Launching The Training Script
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir $DATASET_1 $DATASET_2 \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=3e-6 \
--max_train_steps=500 \
--report_to_wandb
```
### 2.3. Fine-tune text encoder with the UNet.
The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir $DATASET_1 $DATASET_2 \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=2e-6 \
--max_train_steps=500 \
--report_to_wandb \
--train_text_encoder
```
## 3. Results
A [![Weights & Biases](https://img.shields.io/badge/Weights%20&%20Biases-Report-blue)](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress by every 50 steps. Note, the reported weights & baises run was performed on a A100 GPU with the following stetting:
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir $DATASET_1 $DATASET_2 \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=10 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-6 \
--max_train_steps=500 \
--report_to_wandb \
--train_text_encoder
```
Here you can see the target objects on my desk and next to my plant:
![Results](https://drive.google.com/uc?id=1kQisOiiF5cj4rOYjdq8SCZenNsUP2aK0)
@@ -0,0 +1,8 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets>=2.16.0
wandb>=0.16.1
ftfy
tensorboard
Jinja2
@@ -0,0 +1,661 @@
import argparse
import copy
import itertools
import logging
import math
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import concatenate_datasets, load_dataset
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--instance_data_dir", nargs="+", help="Instance data directories")
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder"
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=1000,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
" using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpointing_from",
type=int,
default=1000,
help=("Start to checkpoint from step"),
)
parser.add_argument(
"--validation_steps",
type=int,
default=50,
help=(
"Run validation every X steps. Validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`"
" and logging the images."
),
)
parser.add_argument(
"--validation_from",
type=int,
default=0,
help=("Start to validate from step"),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--validation_project_name",
type=str,
default=None,
help="The w&b name.",
)
parser.add_argument(
"--report_to_wandb", default=False, action="store_true", help="Whether to report to weights and biases"
)
args = parser.parse_args()
return args
def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
class DreamBoothDataset(Dataset):
def __init__(
self,
tokenizer,
datasets_paths,
):
self.tokenizer = tokenizer
self.datasets_paths = (datasets_paths,)
self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths[0]]
self.train_data = concatenate_datasets([dataset["train"] for dataset in self.datasets])
self.test_data = concatenate_datasets([dataset["test"] for dataset in self.datasets])
self.image_normalize = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def set_image(self, img, switch):
if img.mode not in ["RGB", "L"]:
img = img.convert("RGB")
if switch:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = img.resize((512, 512), Image.BILINEAR)
return img
def __len__(self):
return len(self.train_data)
def __getitem__(self, index):
# Lettings
example = {}
img_idx = index % len(self.train_data)
switch = random.choice([True, False])
# Load image
image = self.set_image(self.train_data[img_idx]["image"], switch)
# Normalize image
image_norm = self.image_normalize(image)
# Tokenise prompt
tokenized_prompt = self.tokenizer(
self.train_data[img_idx]["prompt"],
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
# Load masks for image
masks = [
self.set_image(self.train_data[img_idx][key], switch) for key in self.train_data[img_idx] if "mask" in key
]
# Build example
example["PIL_image"] = image
example["instance_image"] = image_norm
example["instance_prompt_id"] = tokenized_prompt
example["instance_masks"] = masks
return example
def weighted_mask(masks):
# Convert each mask to a NumPy array and ensure it's binary
mask_arrays = [np.array(mask) / 255 for mask in masks] # Normalizing to 0-1 range
# Generate random weights and apply them to each mask
weights = [random.random() for _ in masks]
weights = [weight / sum(weights) for weight in weights]
weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]
# Sum the weighted masks
summed_mask = np.sum(weighted_masks, axis=0)
# Apply a threshold to create the final mask
threshold = 0.5 # This threshold can be adjusted
result_mask = summed_mask >= threshold
# Convert the result back to a PIL image
return Image.fromarray(result_mask.astype(np.uint8) * 255)
def collate_fn(examples, tokenizer):
input_ids = [example["instance_prompt_id"] for example in examples]
pixel_values = [example["instance_image"] for example in examples]
masks, masked_images = [], []
for example in examples:
# generate a random mask
mask = weighted_mask(example["instance_masks"])
# prepare mask and masked image
mask, masked_image = prepare_mask_and_masked_image(example["PIL_image"], mask)
masks.append(mask)
masked_images.append(masked_image)
pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
masks = torch.stack(masks)
masked_images = torch.stack(masked_images)
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
return batch
def log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):
# update pipeline (note: unet and vae are loaded again in float32)
pipeline.text_encoder = accelerator.unwrap_model(text_encoder)
pipeline.unet = accelerator.unwrap_model(unet)
with torch.autocast("cuda"):
val_results = [{"data_or_path": pipeline(**pair).images[0], "caption": pair["prompt"]} for pair in val_pairs]
torch.cuda.empty_cache()
wandb.log({"validation": [wandb.Image(**val_result) for val_result in val_results]})
def checkpoint(args, global_step, accelerator):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
def main():
args = parse_args()
project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit,
project_dir=args.output_dir,
logging_dir=Path(args.output_dir, args.logging_dir),
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
project_config=project_config,
log_with="wandb" if args.report_to_wandb else None,
)
if args.report_to_wandb and not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
if args.seed is not None:
set_seed(args.seed)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
# Load the tokenizer & models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder"
).requires_grad_(args.train_text_encoder)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae").requires_grad_(False)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
optimizer = torch.optim.AdamW(
params=itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
tokenizer=tokenizer,
datasets_paths=args.instance_data_dir,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, tokenizer),
)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
accelerator.register_for_checkpointing(lr_scheduler)
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
else:
weight_dtype = torch.float32
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Afterwards we calculate our number of training epochs
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = vars(copy.deepcopy(args))
accelerator.init_trackers(args.validation_project_name, config=tracker_config)
# create validation pipeline (note: unet and vae are loaded again in float32)
val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=unet,
vae=vae,
torch_dtype=weight_dtype,
safety_checker=None,
)
val_pipeline.set_progress_bar_config(disable=True)
# prepare validation dataset
val_pairs = [
{
"image": example["image"],
"mask_image": mask,
"prompt": example["prompt"],
}
for example in train_dataset.test_data
for mask in [example[key] for key in example if "mask" in key]
]
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
accelerator.register_save_state_pre_hook(save_model_hook)
print()
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Convert masked images to latent space
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
masks = batch["masks"]
# resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# concatenate the noised latents with the mask and the masked latents
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if (
global_step % args.validation_steps == 0
and global_step >= args.validation_from
and args.report_to_wandb
):
log_validation(
val_pipeline,
text_encoder,
unet,
val_pairs,
accelerator,
)
if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:
checkpoint(
args,
global_step,
accelerator,
)
# Step logging
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
# Terminate training
accelerator.end_training()
if __name__ == "__main__":
main()
@@ -44,7 +44,7 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -809,7 +809,9 @@ def main():
accelerator.save_state(save_path)
unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
@@ -876,7 +878,7 @@ def main():
unet = unet.to(torch.float32)
unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
@@ -52,7 +52,7 @@ from diffusers import (
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -651,11 +651,15 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1160,14 +1164,14 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
+4 -4
View File
@@ -24,7 +24,7 @@ import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
image_projection = MLPProjection(
image_projection = IPAdapterFullImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
image_projection = IPAdapterPlusImageProjection(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `Resampler` also has `attn_processors`.
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
+1 -1
View File
@@ -498,7 +498,7 @@ class TemporalBasicTransformerBlock(nn.Module):
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
+21 -16
View File
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
return image_embeds
class MLPProjection(nn.Module):
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward
@@ -621,29 +621,34 @@ class AttentionPooling(nn.Module):
return a[:, 0, :] # cls_token
class FourierEmbedder(nn.Module):
def __init__(self, num_freqs=64, temperature=100):
super().__init__()
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:
embed_dim: int
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
Returns:
[B x N x embed_dim] tensor of positional embeddings
"""
self.num_freqs = num_freqs
self.temperature = temperature
batch_size, num_boxes = box.shape[:2]
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
freq_bands = freq_bands[None, None, None]
self.register_buffer("freq_bands", freq_bands, persistent=False)
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
emb = emb * box.unsqueeze(-1)
def __call__(self, x):
x = self.freq_bands * x.unsqueeze(-1)
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
return emb
class PositionNet(nn.Module):
class GLIGENTextBoundingboxProjection(nn.Module):
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.fourier_embedder_dim = fourier_freqs
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
if isinstance(out_dim, tuple):
@@ -692,7 +697,7 @@ class PositionNet(nn.Module):
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
# learnable null embedding
xyxy_null = self.null_position_feature.view(1, 1, -1)
@@ -787,7 +792,7 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class Resampler(nn.Module):
class IPAdapterPlusImageProjection(nn.Module):
"""Resampler of IP-Adapter Plus.
Args:
+2 -2
View File
@@ -32,10 +32,10 @@ from .attention_processor import (
)
from .embeddings import (
GaussianFourierProjection,
GLIGENTextBoundingboxProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
positive_len = cross_attention_dim[0]
feature_type = "text-only" if attention_type == "gated" else "text-image"
self.position_net = PositionNet(
self.position_net = GLIGENTextBoundingboxProjection(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
class PositionNet(nn.Module):
class GLIGENTextBoundingboxProjection(nn.Module):
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
positive_len = cross_attention_dim[0]
feature_type = "text-only" if attention_type == "gated" else "text-image"
self.position_net = PositionNet(
self.position_net = GLIGENTextBoundingboxProjection(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
)
gligen_phrases = gligen_phrases[:max_objs]
gligen_boxes = gligen_boxes[:max_objs]
# prepare batched input to the PositionNet (boxes, phrases, mask)
# prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
# Get tokens for phrases from pre-trained CLIPTokenizer
tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
# For the token, we use the same pre-trained text encoder
@@ -309,6 +309,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
num_inference_steps: int = 25,
min_guidance_scale: float = 1.0,
max_guidance_scale: float = 3.0,
rec_guidance_scale: float = 1.0,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: int = 0.02,
@@ -534,6 +535,136 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
else:
frames = latents
# 3. Encode input image
next_image = frames[0][-1]
next_image_embeddings = self._encode_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
# 4. Encode input image using VAE
next_image = self.image_processor.preprocess(next_image, height=height, width=width)
noise = randn_tensor(next_image.shape, generator=generator, device=next_image.device, dtype=next_image.dtype)
next_image = next_image + noise_aug_strength * noise
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.vae.to(dtype=torch.float32)
next_image_latents = self._encode_vae_image(next_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
next_image_latents = next_image_latents.to(next_image_embeddings.dtype)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# Repeat the image latents for each frame so we can concatenate them with the noise
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
next_image_latents = next_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
next_latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_frames,
num_channels_latents,
height,
width,
next_image_embeddings.dtype,
device,
generator,
None,
)
image_embeddings = image_embeddings.chunk(2)[1]
image_latents = image_latents.chunk(2)[1]
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
next_latent_model_input = torch.cat([next_latents] * 2) if self.do_classifier_free_guidance else next_latents
next_latent_model_input = self.scheduler.scale_model_input(next_latent_model_input, t)
# Concatenate image_latents over channels dimention
next_latent_model_input = torch.cat([next_latent_model_input, next_image_latents], dim=2)
# predict the noise residual
noise_pred = self.unet(
next_latent_model_input,
t,
encoder_hidden_states=next_image_embeddings,
added_time_ids=added_time_ids,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
with torch.enable_grad():
self.unet.train()
self.unet.enable_gradient_checkpointing()
self.unet.requires_grad_(True)
latents.requires_grad_(True)
# Add noise to the latents
noise = torch.rand_like(latents.flatten(0, 1))
timestep = torch.ones(noise.shape[0]).to(noise.device) * t
prev_noised_latents = self.scheduler.add_noise(latents.flatten(0, 1), noise, timestep)
# [batch*frames, channels, height, width] -> [batch, frames, channels, height, width]
prev_noised_latents = prev_noised_latents.reshape(-1, num_frames, *prev_noised_latents.shape[1:])
scaled_prev_noised_latents = self.scheduler.scale_model_input(prev_noised_latents, t)
scaled_prev_noised_latents = torch.cat([scaled_prev_noised_latents, image_latents], dim=2)
rec_noise_pred = self.unet(
scaled_prev_noised_latents,
t,
encoder_hidden_states=image_embeddings,
added_time_ids=added_time_ids.chunk(2)[1],
return_dict=False,
)[0]
sigma = self.scheduler.sigmas[self.scheduler.step_index]
rec_prev_latents = rec_noise_pred * (-sigma / (sigma**2 + 1) ** 0.5) + (prev_noised_latents / (sigma**2 + 1))
loss = torch.nn.functional.mse_loss(rec_prev_latents, latents)
# compute grads
grads = torch.autograd.grad(loss, latents)[0]
self.unet.eval()
self.unet.requires_grad_(False)
latents.requires_grad_(False)
noise_pred = noise_pred - grads * rec_guidance_scale
# compute the previous noisy sample x_t -> x_t-1
next_latents = self.scheduler.step(noise_pred, t, next_latents).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
next_latents = callback_outputs.pop("latents", next_latents)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
next_frames = self.decode_latents(next_latents, num_frames, decode_chunk_size)
next_frames = tensor2vid(next_frames, self.image_processor, output_type=output_type)
frames = frames + next_frames
self.maybe_free_model_hooks()
@@ -277,7 +277,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
timesteps = torch.from_numpy(timesteps).to(device)
if str(device).startswith("mps"):
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
@@ -26,7 +26,7 @@ from pytest import mark
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
# "image_proj" (ImageProjection layer weights)
cross_attention_dim = model.config["cross_attention_dim"]
image_projection = Resampler(
image_projection = IPAdapterPlusImageProjection(
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
)