Compare commits

..

1 Commits

Author SHA1 Message Date
apolinário aa39fd7cb6 fix widget format for advanced training 2023-12-27 04:32:44 -06:00
23 changed files with 230 additions and 1119 deletions
+14
View File
@@ -0,0 +1,14 @@
name: Delete doc comment
on:
workflow_run:
workflows: ["Delete doc comment trigger"]
types:
- completed
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
secrets:
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
@@ -0,0 +1,12 @@
name: Delete doc comment trigger
on:
pull_request:
types: [ closed ]
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
with:
pr_number: ${{ github.event.number }}
+1 -25
View File
@@ -96,8 +96,6 @@ bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:
</div>
_(We later ran the experiments in float16 and found out that the recent versions of torchao do not incur numerical problems from float16.)_
**Why bfloat16?**
* Using a reduced numerical precision (such as float16, bfloat16) to run inference doesnt affect the generation quality but significantly improves latency.
@@ -317,26 +315,4 @@ Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 sec
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
</div>
## Misc
### No graph breaks during torch.compile
Ensuring that the underlying model/method can be fully compiled is crucial for performance (torch.compile with fullgraph=True). This means having no graph breaks. We did this for the UNet and VAE by changing how we access the returning variables. Consider the following example:
```diff
- latents = unet(
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
-).sample
+ latents = unet(
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
+)[0]
```
### Getting rid of GPU syncs after compilation
During the iterative reverse diffusion process, we [call](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) `step()` on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside `step()`, the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). If the `sigmas` array is placed on the GPU, indexing causes a communication sync between the CPU and GPU. This causes a latency, and it becomes more evident when the denoiser has already been compiled.
But if the `sigmas` array always stays on the CPU (refer to [this line](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240)), this sync doesnt take place, hence improved latency. In general, any CPU <-> GPU communication sync should be none or be kept to a bare minimum as it can impact inference latency.
</div>
@@ -20,7 +20,6 @@ import itertools
import logging
import math
import os
import re
import shutil
import warnings
from pathlib import Path
@@ -38,11 +37,9 @@ from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
@@ -57,15 +54,10 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import (
check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
is_wandb_available,
)
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -75,6 +67,39 @@ check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -100,17 +125,10 @@ def save_model_card(
img_str += f"""
- text: '{instance_prompt}'
"""
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
if instance_prompt_webui != embeddings_filename:
instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
else:
instance_prompt_sentence = ""
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
webui_example_pivotal = ""
if train_text_encoder_ti:
trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
@@ -119,16 +137,11 @@ def save_model_card(
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
"""
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
- Place it on it on your `embeddings` folder
- Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
(you need both the LoRA and the embeddings as they were trained together for this LoRA)
"""
if token_abstraction_dict:
for key, value in token_abstraction_dict.items():
tokens = "".join(value)
@@ -160,14 +173,9 @@ license: openrail++
### These are {repo_id} LoRA adaption weights for {base_model}.
## Download model
## Trigger words
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
- Place it on your `models/Lora` folder.
- On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
{webui_example_pivotal}
{trigger_str}
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
@@ -183,12 +191,16 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## Trigger words
## Download model
{trigger_str}
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
All [Files & versions](/{repo_id}/tree/main).
## Details
All [Files & versions](/{repo_id}/tree/main).
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
@@ -1250,25 +1262,54 @@ def main(args):
text_encoder_two.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)
# Set correct lora layers
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
text_encoder_one, dtype=torch.float32, rank=args.rank
)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
text_encoder_two, dtype=torch.float32, rank=args.rank
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
# if we use textual inversion, we freeze all parameters except for the token embeddings
# in text encoder
@@ -1292,17 +1333,6 @@ def main(args):
else:
param.requires_grad = False
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1314,15 +1344,11 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
unet_lora_layers_to_save = unet_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1379,12 +1405,6 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
@@ -1975,17 +1995,13 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
unet_lora_layers = unet_lora_state_dict(unet)
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
@@ -2055,15 +2071,8 @@ def main(args):
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
f"{args.output_dir}/embeddings.safetensors",
)
# Conver to WebUI format
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,
images=images,
@@ -51,7 +51,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
if unet is None:
raise ValueError("Must provide a `unet` when doing intermediate validation.")
unet = accelerator.unwrap_model(unet)
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
state_dict = get_peft_model_state_dict(unet)
to_load = state_dict
else:
to_load = args.output_dir
@@ -819,7 +819,7 @@ def main(args):
unet_ = accelerator.unwrap_model(unet)
# also save the checkpoints in native `diffusers` format so that it can be easily
# be independently loaded via `load_lora_weights()`.
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
state_dict = get_peft_model_state_dict(unet_)
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
for _, model in enumerate(models):
@@ -1184,7 +1184,7 @@ def main(args):
# solver timestep.
# With the adapters disabled, the `unet` is the regular teacher model.
accelerator.unwrap_model(unet).disable_adapters()
unet.disable_adapters()
with torch.no_grad():
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = unet(
@@ -1248,7 +1248,7 @@ def main(args):
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
# re-enable unet adapters to turn the `unet` into a student unet.
accelerator.unwrap_model(unet).enable_adapters()
unet.enable_adapters()
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
@@ -1332,7 +1332,7 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
unet_lora_state_dict = get_peft_model_state_dict(unet)
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
if args.push_to_hub:
+5 -7
View File
@@ -54,7 +54,7 @@ from diffusers import (
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -853,11 +853,9 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
unet_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1287,11 +1285,11 @@ def main(args):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
unet_lora_state_dict = get_peft_model_state_dict(unet)
if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
else:
text_encoder_state_dict = None
@@ -1,93 +0,0 @@
# Multi Subject Dreambooth for Inpainting Models
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requieres prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
**The second part**, the `train_multi_subject_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it for stable diffusion for inpainting.
## 1. Data Collection: Make Prompt-Image-Mask Pairs
Earlier training scripts have provided approaches like random masking for the training images. This project provides a notebook for more precise mask setting.
The notebook can be found here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)
The `multi_inpaint_dataset.ipynb` notebook, takes training & validation images, on which the user draws masks and provides prompts to make a prompt-image-mask pairs. This ensures that during training, the loss is computed on the area masking the object of interest, rather than on random areas. Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:
![train_val_pairs](https://drive.google.com/uc?id=1PzwH8E3icl_ubVmA19G0HZGLImFX3x5I)
You can build multiple datasets for every subject and upload them to the 🤗 hub. Later, when launching the training script you can indicate the paths of the datasets, on which you would like to finetune Stable Diffusion for inpaining.
## 2. Train Multi Subject Dreambooth for Inpainting
### 2.1. Setting The Training Configuration
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
```bash
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
export OUTPUT_DIR="path-to-save-model"
export DATASET_1="gzguevara/mr_potato_head_masked"
export DATASET_2="gzguevara/cat_toy_masked"
... # Further paths to 🤗 datasets
```
### 2.2. Launching The Training Script
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir $DATASET_1 $DATASET_2 \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=3e-6 \
--max_train_steps=500 \
--report_to_wandb
```
### 2.3. Fine-tune text encoder with the UNet.
The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir $DATASET_1 $DATASET_2 \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=2 \
--learning_rate=2e-6 \
--max_train_steps=500 \
--report_to_wandb \
--train_text_encoder
```
## 3. Results
A [![Weights & Biases](https://img.shields.io/badge/Weights%20&%20Biases-Report-blue)](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress by every 50 steps. Note, the reported weights & baises run was performed on a A100 GPU with the following stetting:
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir $DATASET_1 $DATASET_2 \
--output_dir=$OUTPUT_DIR \
--resolution=512 \
--train_batch_size=10 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-6 \
--max_train_steps=500 \
--report_to_wandb \
--train_text_encoder
```
Here you can see the target objects on my desk and next to my plant:
![Results](https://drive.google.com/uc?id=1kQisOiiF5cj4rOYjdq8SCZenNsUP2aK0)
@@ -1,8 +0,0 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets>=2.16.0
wandb>=0.16.1
ftfy
tensorboard
Jinja2
@@ -1,661 +0,0 @@
import argparse
import copy
import itertools
import logging
import math
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import concatenate_datasets, load_dataset
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.13.0.dev0")
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--instance_data_dir", nargs="+", help="Instance data directories")
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder"
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=1000,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
" using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpointing_from",
type=int,
default=1000,
help=("Start to checkpoint from step"),
)
parser.add_argument(
"--validation_steps",
type=int,
default=50,
help=(
"Run validation every X steps. Validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`"
" and logging the images."
),
)
parser.add_argument(
"--validation_from",
type=int,
default=0,
help=("Start to validate from step"),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more docs"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--validation_project_name",
type=str,
default=None,
help="The w&b name.",
)
parser.add_argument(
"--report_to_wandb", default=False, action="store_true", help="Whether to report to weights and biases"
)
args = parser.parse_args()
return args
def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
class DreamBoothDataset(Dataset):
def __init__(
self,
tokenizer,
datasets_paths,
):
self.tokenizer = tokenizer
self.datasets_paths = (datasets_paths,)
self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths[0]]
self.train_data = concatenate_datasets([dataset["train"] for dataset in self.datasets])
self.test_data = concatenate_datasets([dataset["test"] for dataset in self.datasets])
self.image_normalize = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def set_image(self, img, switch):
if img.mode not in ["RGB", "L"]:
img = img.convert("RGB")
if switch:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
img = img.resize((512, 512), Image.BILINEAR)
return img
def __len__(self):
return len(self.train_data)
def __getitem__(self, index):
# Lettings
example = {}
img_idx = index % len(self.train_data)
switch = random.choice([True, False])
# Load image
image = self.set_image(self.train_data[img_idx]["image"], switch)
# Normalize image
image_norm = self.image_normalize(image)
# Tokenise prompt
tokenized_prompt = self.tokenizer(
self.train_data[img_idx]["prompt"],
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
# Load masks for image
masks = [
self.set_image(self.train_data[img_idx][key], switch) for key in self.train_data[img_idx] if "mask" in key
]
# Build example
example["PIL_image"] = image
example["instance_image"] = image_norm
example["instance_prompt_id"] = tokenized_prompt
example["instance_masks"] = masks
return example
def weighted_mask(masks):
# Convert each mask to a NumPy array and ensure it's binary
mask_arrays = [np.array(mask) / 255 for mask in masks] # Normalizing to 0-1 range
# Generate random weights and apply them to each mask
weights = [random.random() for _ in masks]
weights = [weight / sum(weights) for weight in weights]
weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]
# Sum the weighted masks
summed_mask = np.sum(weighted_masks, axis=0)
# Apply a threshold to create the final mask
threshold = 0.5 # This threshold can be adjusted
result_mask = summed_mask >= threshold
# Convert the result back to a PIL image
return Image.fromarray(result_mask.astype(np.uint8) * 255)
def collate_fn(examples, tokenizer):
input_ids = [example["instance_prompt_id"] for example in examples]
pixel_values = [example["instance_image"] for example in examples]
masks, masked_images = [], []
for example in examples:
# generate a random mask
mask = weighted_mask(example["instance_masks"])
# prepare mask and masked image
mask, masked_image = prepare_mask_and_masked_image(example["PIL_image"], mask)
masks.append(mask)
masked_images.append(masked_image)
pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
masks = torch.stack(masks)
masked_images = torch.stack(masked_images)
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images}
return batch
def log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):
# update pipeline (note: unet and vae are loaded again in float32)
pipeline.text_encoder = accelerator.unwrap_model(text_encoder)
pipeline.unet = accelerator.unwrap_model(unet)
with torch.autocast("cuda"):
val_results = [{"data_or_path": pipeline(**pair).images[0], "caption": pair["prompt"]} for pair in val_pairs]
torch.cuda.empty_cache()
wandb.log({"validation": [wandb.Image(**val_result) for val_result in val_results]})
def checkpoint(args, global_step, accelerator):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
def main():
args = parse_args()
project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit,
project_dir=args.output_dir,
logging_dir=Path(args.output_dir, args.logging_dir),
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
project_config=project_config,
log_with="wandb" if args.report_to_wandb else None,
)
if args.report_to_wandb and not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
if args.seed is not None:
set_seed(args.seed)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
# Load the tokenizer & models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder"
).requires_grad_(args.train_text_encoder)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae").requires_grad_(False)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
optimizer = torch.optim.AdamW(
params=itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
tokenizer=tokenizer,
datasets_paths=args.instance_data_dir,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, tokenizer),
)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
accelerator.register_for_checkpointing(lr_scheduler)
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
else:
weight_dtype = torch.float32
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Afterwards we calculate our number of training epochs
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = vars(copy.deepcopy(args))
accelerator.init_trackers(args.validation_project_name, config=tracker_config)
# create validation pipeline (note: unet and vae are loaded again in float32)
val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=unet,
vae=vae,
torch_dtype=weight_dtype,
safety_checker=None,
)
val_pipeline.set_progress_bar_config(disable=True)
# prepare validation dataset
val_pairs = [
{
"image": example["image"],
"mask_image": mask,
"prompt": example["prompt"],
}
for example in train_dataset.test_data
for mask in [example[key] for key in example if "mask" in key]
]
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
accelerator.register_save_state_pre_hook(save_model_hook)
print()
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Convert masked images to latent space
masked_latents = vae.encode(
batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
).latent_dist.sample()
masked_latents = masked_latents * vae.config.scaling_factor
masks = batch["masks"]
# resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.stack(
[
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# concatenate the noised latents with the mask and the masked latents
latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(), text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if (
global_step % args.validation_steps == 0
and global_step >= args.validation_from
and args.report_to_wandb
):
log_validation(
val_pipeline,
text_encoder,
unet,
val_pairs,
accelerator,
)
if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:
checkpoint(
args,
global_step,
accelerator,
)
# Step logging
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
# Terminate training
accelerator.end_training()
if __name__ == "__main__":
main()
@@ -44,7 +44,7 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -809,9 +809,7 @@ def main():
accelerator.save_state(save_path)
unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
@@ -878,7 +876,7 @@ def main():
unet = unet.to(torch.float32)
unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
@@ -52,7 +52,7 @@ from diffusers import (
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -651,15 +651,11 @@ def main(args):
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
unet_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1164,14 +1160,14 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
unet_lora_state_dict = get_peft_model_state_dict(unet)
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
@@ -1,55 +0,0 @@
# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
# This means that you can input your diffusers-trained LoRAs and
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
# and run the script:
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
# now you can use corgy.safetensors in your WebUI of choice!
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
# Canonical diffusers training scripts:
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
import argparse
import os
from safetensors.torch import load_file, save_file
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
def convert_and_save(input_lora, output_lora=None):
if output_lora is None:
base_name = os.path.splitext(input_lora)[0]
output_lora = f"{base_name}_webui.safetensors"
diffusers_state_dict = load_file(input_lora)
peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
save_file(kohya_state_dict, output_lora)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
parser.add_argument(
"input_lora",
type=str,
help="Path to the input LoRA model file in the diffusers format.",
)
parser.add_argument(
"output_lora",
type=str,
nargs="?",
help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
)
args = parser.parse_args()
convert_and_save(args.input_lora, args.output_lora)
+4 -4
View File
@@ -24,7 +24,7 @@ import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
image_projection = IPAdapterFullImageProjection(
image_projection = MLPProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = IPAdapterPlusImageProjection(
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
# because `Resampler` also has `attn_processors`.
self.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
+1 -1
View File
@@ -498,7 +498,7 @@ class TemporalBasicTransformerBlock(nn.Module):
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
+16 -21
View File
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
return image_embeds
class IPAdapterFullImageProjection(nn.Module):
class MLPProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward
@@ -621,34 +621,29 @@ class AttentionPooling(nn.Module):
return a[:, 0, :] # cls_token
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:
embed_dim: int
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
Returns:
[B x N x embed_dim] tensor of positional embeddings
"""
class FourierEmbedder(nn.Module):
def __init__(self, num_freqs=64, temperature=100):
super().__init__()
batch_size, num_boxes = box.shape[:2]
self.num_freqs = num_freqs
self.temperature = temperature
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
emb = emb * box.unsqueeze(-1)
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
freq_bands = freq_bands[None, None, None]
self.register_buffer("freq_bands", freq_bands, persistent=False)
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
return emb
def __call__(self, x):
x = self.freq_bands * x.unsqueeze(-1)
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
class GLIGENTextBoundingboxProjection(nn.Module):
class PositionNet(nn.Module):
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.fourier_embedder_dim = fourier_freqs
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
if isinstance(out_dim, tuple):
@@ -697,7 +692,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
# learnable null embedding
xyxy_null = self.null_position_feature.view(1, 1, -1)
@@ -792,7 +787,7 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class IPAdapterPlusImageProjection(nn.Module):
class Resampler(nn.Module):
"""Resampler of IP-Adapter Plus.
Args:
+2 -2
View File
@@ -32,10 +32,10 @@ from .attention_processor import (
)
from .embeddings import (
GaussianFourierProjection,
GLIGENTextBoundingboxProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
positive_len = cross_attention_dim[0]
feature_type = "text-only" if attention_type == "gated" else "text-image"
self.position_net = GLIGENTextBoundingboxProjection(
self.position_net = PositionNet(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
class GLIGENTextBoundingboxProjection(nn.Module):
class PositionNet(nn.Module):
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
positive_len = cross_attention_dim[0]
feature_type = "text-only" if attention_type == "gated" else "text-image"
self.position_net = GLIGENTextBoundingboxProjection(
self.position_net = PositionNet(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
)
gligen_phrases = gligen_phrases[:max_objs]
gligen_boxes = gligen_boxes[:max_objs]
# prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
# prepare batched input to the PositionNet (boxes, phrases, mask)
# Get tokens for phrases from pre-trained CLIPTokenizer
tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
# For the token, we use the same pre-trained text encoder
@@ -277,11 +277,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
if str(device).startswith("mps"):
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)
timesteps = torch.from_numpy(timesteps).to(device)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
-2
View File
@@ -98,9 +98,7 @@ from .peft_utils import (
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
)
+1 -95
View File
@@ -16,11 +16,6 @@ State dict utilities: utility methods for converting state dicts easily
"""
import enum
from .logging import get_logger
logger = get_logger(__name__)
class StateDictType(enum.Enum):
"""
@@ -28,7 +23,7 @@ class StateDictType(enum.Enum):
"""
DIFFUSERS_OLD = "diffusers_old"
KOHYA_SS = "kohya_ss"
# KOHYA_SS = "kohya_ss" # TODO: implement this
PEFT = "peft"
DIFFUSERS = "diffusers"
@@ -105,14 +100,6 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}
PEFT_TO_KOHYA_SS = {
"lora_A": "lora_down",
"lora_B": "lora_up",
# This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
# adding prefixes and adding alpha values
# Check `convert_state_dict_to_kohya` for more
}
PEFT_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
@@ -123,8 +110,6 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
KEYS_TO_ALWAYS_REPLACE = {
".processor.": ".",
}
@@ -243,82 +228,3 @@ def convert_unet_state_dict_to_peft(state_dict):
"""
mapping = UNET_TO_DIFFUSERS
return convert_state_dict(state_dict, mapping)
def convert_all_state_dict_to_peft(state_dict):
r"""
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
"""
try:
peft_dict = convert_state_dict_to_peft(state_dict)
except Exception as e:
if str(e) == "Could not automatically infer state dict type":
peft_dict = convert_unet_state_dict_to_peft(state_dict)
else:
raise
if not any("lora_A" in key or "lora_B" in key for key in peft_dict.keys()):
raise ValueError("Your LoRA was not converted to PEFT")
return peft_dict
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
r"""
Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
The method only supports the conversion from PEFT to Kohya for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
but we add it here in case we don't want to rely on that method.
"""
try:
import torch
except ImportError:
logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
raise
peft_adapter_name = kwargs.pop("adapter_name", None)
if peft_adapter_name is not None:
peft_adapter_name = "." + peft_adapter_name
else:
peft_adapter_name = ""
if original_type is None:
if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
# Use the convert_state_dict function with the appropriate mapping
kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
kohya_ss_state_dict = {}
# Additional logic for replacing header, alpha parameters `.` with `_` in all keys
for kohya_key, weight in kohya_ss_partial_state_dict.items():
if "text_encoder_2." in kohya_key:
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
elif "text_encoder." in kohya_key:
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
elif "unet" in kohya_key:
kohya_key = kohya_key.replace("unet", "lora_unet")
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key:
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
+52 -22
View File
@@ -22,6 +22,7 @@ import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from packaging import version
@@ -40,6 +41,8 @@ from diffusers import (
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
from diffusers.utils.testing_utils import (
floats_tensor,
@@ -75,6 +78,28 @@ def state_dicts_almost_equal(sd1, sd2):
return models_are_equal
def create_unet_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
@require_peft_backend
class PeftLoraLoaderMixinTests:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -115,6 +140,8 @@ class PeftLoraLoaderMixinTests:
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
if self.has_two_text_encoders:
pipeline_components = {
"unet": unet,
@@ -138,8 +165,11 @@ class PeftLoraLoaderMixinTests:
"feature_extractor": None,
"image_encoder": None,
}
return pipeline_components, text_lora_config, unet_lora_config
lora_components = {
"unet_lora_layers": unet_lora_layers,
"unet_lora_attn_procs": unet_lora_attn_procs,
}
return pipeline_components, lora_components, text_lora_config, unet_lora_config
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
@@ -186,7 +216,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -201,7 +231,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -232,7 +262,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -279,7 +309,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -321,7 +351,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -364,7 +394,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA.
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -429,7 +459,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -480,7 +510,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -553,7 +583,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -607,7 +637,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected - with unet
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -653,7 +683,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -700,7 +730,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -750,7 +780,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set them
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -818,7 +848,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set/delete them
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -908,7 +938,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set them
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -980,7 +1010,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_fuse_nan(self):
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1018,7 +1048,7 @@ class PeftLoraLoaderMixinTests:
are the expected results
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1045,7 +1075,7 @@ class PeftLoraLoaderMixinTests:
are the expected results
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1083,7 +1113,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected - with unet and multi-adapter case
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1145,7 +1175,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -26,7 +26,7 @@ from pytest import mark
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
# "image_proj" (ImageProjection layer weights)
cross_attention_dim = model.config["cross_attention_dim"]
image_projection = IPAdapterPlusImageProjection(
image_projection = Resampler(
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
)