Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aa39fd7cb6 |
@@ -0,0 +1,14 @@
|
||||
name: Delete doc comment
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Delete doc comment trigger"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
|
||||
secrets:
|
||||
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
||||
@@ -0,0 +1,12 @@
|
||||
name: Delete doc comment trigger
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [ closed ]
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
|
||||
with:
|
||||
pr_number: ${{ github.event.number }}
|
||||
@@ -96,8 +96,6 @@ bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:
|
||||
|
||||
</div>
|
||||
|
||||
_(We later ran the experiments in float16 and found out that the recent versions of torchao do not incur numerical problems from float16.)_
|
||||
|
||||
**Why bfloat16?**
|
||||
|
||||
* Using a reduced numerical precision (such as float16, bfloat16) to run inference doesn’t affect the generation quality but significantly improves latency.
|
||||
@@ -317,26 +315,4 @@ Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 sec
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
## Misc
|
||||
|
||||
### No graph breaks during torch.compile
|
||||
|
||||
Ensuring that the underlying model/method can be fully compiled is crucial for performance (torch.compile with fullgraph=True). This means having no graph breaks. We did this for the UNet and VAE by changing how we access the returning variables. Consider the following example:
|
||||
|
||||
```diff
|
||||
- latents = unet(
|
||||
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
|
||||
-).sample
|
||||
|
||||
+ latents = unet(
|
||||
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
|
||||
+)[0]
|
||||
```
|
||||
|
||||
### Getting rid of GPU syncs after compilation
|
||||
|
||||
During the iterative reverse diffusion process, we [call](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) `step()` on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside `step()`, the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). If the `sigmas` array is placed on the GPU, indexing causes a communication sync between the CPU and GPU. This causes a latency, and it becomes more evident when the denoiser has already been compiled.
|
||||
|
||||
But if the `sigmas` array always stays on the CPU (refer to [this line](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240)), this sync doesn’t take place, hence improved latency. In general, any CPU <-> GPU communication sync should be none or be kept to a bare minimum as it can impact inference latency.
|
||||
</div>
|
||||
@@ -20,7 +20,6 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@@ -38,11 +37,9 @@ from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
@@ -57,15 +54,10 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.training_utils import compute_snr, unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -75,6 +67,39 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
    """Gather the LoRA weights attached to a CLIP text encoder's attention projections.

    For every layer under ``text_encoder.text_model.encoder.layers`` the
    ``lora_linear_layer`` state dict of the ``q_proj``, ``k_proj``, ``v_proj``
    and ``out_proj`` projections of its ``self_attn`` module is copied into a
    single flat dict. Keys have the form
    ``text_model.encoder.layers.{i}.self_attn.{proj}.lora_linear_layer.{param}``.

    An encoder that is neither a ``CLIPTextModel`` nor a
    ``CLIPTextModelWithProjection`` produces an empty dict.
    """
    # Imported lazily, exactly as the original helper did.
    from transformers import CLIPTextModel, CLIPTextModelWithProjection

    lora_weights = {}
    if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        return lora_weights

    # Projection order matters: it fixes the insertion order of the result.
    projection_names = ("q_proj", "k_proj", "v_proj", "out_proj")

    for index, block in enumerate(text_encoder.text_model.encoder.layers):
        prefix = f"text_model.encoder.layers.{index}.self_attn"
        attention = block.self_attn
        for proj_name in projection_names:
            lora_layer = getattr(attention, proj_name).lora_linear_layer
            for param_name, tensor in lora_layer.state_dict().items():
                lora_weights[f"{prefix}.{proj_name}.lora_linear_layer.{param_name}"] = tensor

    return lora_weights
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -100,17 +125,10 @@ def save_model_card(
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
embeddings_filename = f"{repo_folder}_emb"
|
||||
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
|
||||
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
|
||||
if instance_prompt_webui != embeddings_filename:
|
||||
instance_prompt_sentence = f"For example, `{instance_prompt_webui}`"
|
||||
else:
|
||||
instance_prompt_sentence = ""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
diffusers_example_pivotal = ""
|
||||
webui_example_pivotal = ""
|
||||
if train_text_encoder_ti:
|
||||
trigger_str = (
|
||||
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
|
||||
@@ -119,16 +137,11 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
"""
|
||||
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
|
||||
- Place it on it on your `embeddings` folder
|
||||
- Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence}
|
||||
(you need both the LoRA and the embeddings as they were trained together for this LoRA)
|
||||
"""
|
||||
if token_abstraction_dict:
|
||||
for key, value in token_abstraction_dict.items():
|
||||
tokens = "".join(value)
|
||||
@@ -160,14 +173,9 @@ license: openrail++
|
||||
|
||||
### These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
|
||||
## Download model
|
||||
## Trigger words
|
||||
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.
|
||||
- Place it on your `models/Lora` folder.
|
||||
- On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
|
||||
{webui_example_pivotal}
|
||||
{trigger_str}
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
|
||||
@@ -183,12 +191,16 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}'
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
|
||||
## Trigger words
|
||||
## Download model
|
||||
|
||||
{trigger_str}
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder.
|
||||
- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder.
|
||||
|
||||
All [Files & versions](/{repo_id}/tree/main).
|
||||
|
||||
## Details
|
||||
All [Files & versions](/{repo_id}/tree/main).
|
||||
|
||||
The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).
|
||||
|
||||
@@ -1250,25 +1262,54 @@ def main(args):
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
# Set correct lora layers
|
||||
unet_lora_parameters = []
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_one, dtype=torch.float32, rank=args.rank
|
||||
)
|
||||
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
|
||||
text_encoder_two, dtype=torch.float32, rank=args.rank
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
||||
# in text encoder
|
||||
@@ -1292,17 +1333,6 @@ def main(args):
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1314,15 +1344,11 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1379,12 +1405,6 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
||||
|
||||
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
|
||||
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
|
||||
|
||||
@@ -1975,17 +1995,13 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
)
|
||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
)
|
||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
@@ -2055,15 +2071,8 @@ def main(args):
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
f"{args.output_dir}/embeddings.safetensors",
|
||||
)
|
||||
|
||||
# Convert to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
images=images,
|
||||
|
||||
@@ -51,7 +51,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
if unet is None:
|
||||
raise ValueError("Must provide a `unet` when doing intermediate validation.")
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
state_dict = get_peft_model_state_dict(unet)
|
||||
to_load = state_dict
|
||||
else:
|
||||
to_load = args.output_dir
|
||||
@@ -819,7 +819,7 @@ def main(args):
|
||||
unet_ = accelerator.unwrap_model(unet)
|
||||
# also save the checkpoints in native `diffusers` format so that it can be easily
|
||||
# be independently loaded via `load_lora_weights()`.
|
||||
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
|
||||
state_dict = get_peft_model_state_dict(unet_)
|
||||
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)
|
||||
|
||||
for _, model in enumerate(models):
|
||||
@@ -1184,7 +1184,7 @@ def main(args):
|
||||
# solver timestep.
|
||||
|
||||
# With the adapters disabled, the `unet` is the regular teacher model.
|
||||
accelerator.unwrap_model(unet).disable_adapters()
|
||||
unet.disable_adapters()
|
||||
with torch.no_grad():
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = unet(
|
||||
@@ -1248,7 +1248,7 @@ def main(args):
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)
|
||||
|
||||
# re-enable unet adapters to turn the `unet` into a student unet.
|
||||
accelerator.unwrap_model(unet).enable_adapters()
|
||||
unet.enable_adapters()
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
@@ -1332,7 +1332,7 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -853,11 +853,9 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1287,11 +1285,11 @@ def main(args):
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
|
||||
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
|
||||
else:
|
||||
text_encoder_state_dict = None
|
||||
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
# Multi Subject Dreambooth for Inpainting Models
|
||||
|
||||
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requires prompt-image-mask pairs. The UNet of inpainting models has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
|
||||
|
||||
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how to make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
|
||||
|
||||
**The second part**, the `train_multi_subject_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it for stable diffusion for inpainting.
|
||||
|
||||
## 1. Data Collection: Make Prompt-Image-Mask Pairs
|
||||
|
||||
Earlier training scripts have provided approaches like random masking for the training images. This project provides a notebook for more precise mask setting.
|
||||
|
||||
The notebook can be found here: [Open in Colab](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)
|
||||
|
||||
The `multi_inpaint_dataset.ipynb` notebook takes training & validation images, on which the user draws masks and provides prompts to make prompt-image-mask pairs. This ensures that during training, the loss is computed on the area masking the object of interest, rather than on random areas. Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:
|
||||
|
||||

|
||||
|
||||
You can build multiple datasets for every subject and upload them to the 🤗 hub. Later, when launching the training script, you can indicate the paths of the datasets on which you would like to finetune Stable Diffusion for inpainting.
|
||||
|
||||
## 2. Train Multi Subject Dreambooth for Inpainting
|
||||
|
||||
### 2.1. Setting The Training Configuration
|
||||
|
||||
Before launching the training script, make sure to select the target inpainting model, the output directory and the 🤗 datasets.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-inpainting"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
export DATASET_1="gzguevara/mr_potato_head_masked"
|
||||
export DATASET_2="gzguevara/cat_toy_masked"
|
||||
... # Further paths to 🤗 datasets
|
||||
```
|
||||
|
||||
### 2.2. Launching The Training Script
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--learning_rate=3e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb
|
||||
```
|
||||
|
||||
### 2.3. Fine-tune text encoder with the UNet.
|
||||
|
||||
The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results, especially on faces.
|
||||
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
|
||||
|
||||
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--learning_rate=2e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb \
|
||||
--train_text_encoder
|
||||
```
|
||||
|
||||
## 3. Results
|
||||
|
||||
A [Weights & Biases report](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress every 50 steps. Note that the reported Weights & Biases run was performed on an A100 GPU with the following settings:
|
||||
|
||||
```bash
|
||||
accelerate launch train_multi_subject_dreambooth_inpaint.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir $DATASET_1 $DATASET_2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--resolution=512 \
|
||||
--train_batch_size=10 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--learning_rate=1e-6 \
|
||||
--max_train_steps=500 \
|
||||
--report_to_wandb \
|
||||
--train_text_encoder
|
||||
```
|
||||
Here you can see the target objects on my desk and next to my plant:
|
||||
|
||||

|
||||
@@ -1,8 +0,0 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
datasets>=2.16.0
|
||||
wandb>=0.16.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
-661
@@ -1,661 +0,0 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionInpaintPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.13.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the multi-subject DreamBooth inpainting trainer.

    Options are read from ``sys.argv``. ``--pretrained_model_name_or_path`` and
    ``--max_train_steps`` are mandatory: the training loop derives both the
    learning-rate schedule length and the number of epochs from
    ``max_train_steps``, so a missing value is rejected here with a usage error.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    # Model, data and output locations
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--instance_data_dir", nargs="+", help="Instance data directories")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )

    # Training schedule
    parser.add_argument(
        "--train_text_encoder", default=False, action="store_true", help="Whether to train the text encoder"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        # Required: main() multiplies and divides by this value.
        required=True,
        help="Total number of training steps to perform (required).",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )

    # Optimizer and learning-rate schedule
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    # Logging and precision
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )

    # Checkpointing and validation cadence
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=1000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
            " using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpointing_from",
        type=int,
        default=1000,
        help=("Start to checkpoint from step"),
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=50,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--validation_from",
        type=int,
        default=0,
        help=("Start to validate from step"),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    # Weights & Biases reporting
    parser.add_argument(
        "--validation_project_name",
        type=str,
        default=None,
        help="The w&b name.",
    )
    parser.add_argument(
        "--report_to_wandb", default=False, action="store_true", help="Whether to report to weights and biases"
    )

    args = parser.parse_args()

    return args
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask):
    """Turn an image and its mask into the tensors used for inpainting training.

    Args:
        image: PIL-style image; converted to RGB and rescaled from 0..255 to -1..1.
        mask: PIL-style image; converted to single-channel ("L") and binarized at 0.5.

    Returns:
        tuple: ``(mask, masked_image)`` where ``mask`` is a float32 tensor of
        shape (1, 1, H, W) holding only 0 and 1, and ``masked_image`` is the
        (1, 3, H, W) image tensor with every masked pixel zeroed.
    """
    # HWC uint8 -> 1CHW float in [-1, 1]
    rgb = np.array(image.convert("RGB"))
    pixels = torch.from_numpy(rgb[None].transpose(0, 3, 1, 2)).to(dtype=torch.float32)
    pixels = pixels / 127.5 - 1.0

    # 0..255 grayscale -> hard 0/1 mask with leading batch and channel axes
    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    binary = np.where(gray < 0.5, 0.0, 1.0).astype(np.float32)[None, None]
    mask_tensor = torch.from_numpy(binary)

    # Keep only the pixels that lie outside the mask
    keep = mask_tensor < 0.5
    return mask_tensor, pixels * keep
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
    """Training dataset that merges several datasets of images, prompts and masks.

    Every entry of ``datasets_paths`` is loaded with ``load_dataset``; the
    ``train`` splits are concatenated into ``train_data`` (the rows this
    dataset serves) and the ``test`` splits into ``test_data``.

    Each row is read through its ``image`` and ``prompt`` columns plus every
    column whose name contains ``mask``.

    Args:
        tokenizer: tokenizer used to encode the row prompt.
        datasets_paths: list of dataset paths/identifiers accepted by ``load_dataset``.
        resolution: side length images and masks are resized to (default 512).
    """

    def __init__(
        self,
        tokenizer,
        datasets_paths,
        resolution=512,
    ):
        self.tokenizer = tokenizer
        self.resolution = resolution
        # Stored as a 1-tuple around the list of paths (existing attribute layout).
        self.datasets_paths = (datasets_paths,)
        self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths[0]]
        self.train_data = concatenate_datasets([dataset["train"] for dataset in self.datasets])
        self.test_data = concatenate_datasets([dataset["test"] for dataset in self.datasets])

        # ToTensor followed by Normalize(mean=0.5, std=0.5)
        self.image_normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def set_image(self, img, switch):
        """Return ``img`` in RGB/L mode, mirrored when ``switch`` is true, resized to the training resolution."""
        if img.mode not in ["RGB", "L"]:
            img = img.convert("RGB")

        if switch:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        img = img.resize((self.resolution, self.resolution), Image.BILINEAR)

        return img

    def __len__(self):
        """Number of rows in the concatenated training split."""
        return len(self.train_data)

    def __getitem__(self, index):
        """Build one training example.

        Returns:
            dict with ``PIL_image`` (resized image), ``instance_image``
            (normalized tensor), ``instance_prompt_id`` (unpadded token ids)
            and ``instance_masks`` (list of resized masks).
        """
        # Settings
        example = {}
        # Look the row up once instead of once per column access.
        row = self.train_data[index % len(self.train_data)]
        # The same flip decision is applied to the image and to all of its masks.
        switch = random.choice([True, False])

        # Load image
        image = self.set_image(row["image"], switch)

        # Normalize image
        image_norm = self.image_normalize(image)

        # Tokenise prompt (padding is deferred to the collate function)
        tokenized_prompt = self.tokenizer(
            row["prompt"],
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        # Load masks for image
        masks = [self.set_image(row[key], switch) for key in row if "mask" in key]

        # Build example
        example["PIL_image"] = image
        example["instance_image"] = image_norm
        example["instance_prompt_id"] = tokenized_prompt
        example["instance_masks"] = masks

        return example
|
||||
|
||||
|
||||
def weighted_mask(masks, threshold=0.5):
    """Blend several object masks into one random binary mask.

    Each mask receives a random weight and the weights are normalized to sum
    to 1. A pixel is set in the result when the weighted sum of the masks at
    that pixel reaches ``threshold``.

    Args:
        masks: non-empty list of PIL mask images (assumed to share one size
            and to hold 0-255 values, since they are summed after dividing by 255).
        threshold: cut-off applied to the weighted sum; defaults to 0.5.

    Returns:
        PIL image whose pixels are either 0 or 255.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if len(masks) == 0:
        raise ValueError("weighted_mask requires at least one mask, got an empty list.")

    # Convert each mask to a NumPy array normalized to the 0-1 range
    mask_arrays = [np.array(mask) / 255 for mask in masks]

    # Generate random weights, normalize them once, and apply them to each mask
    weights = [random.random() for _ in masks]
    total = sum(weights)
    weights = [weight / total for weight in weights]
    weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]

    # Sum the weighted masks
    summed_mask = np.sum(weighted_masks, axis=0)

    # Apply the threshold to create the final mask
    result_mask = summed_mask >= threshold

    # Convert the result back to a PIL image
    return Image.fromarray(result_mask.astype(np.uint8) * 255)
|
||||
|
||||
|
||||
def collate_fn(examples, tokenizer):
    """Assemble dataset examples into one training batch.

    For every example a random blend of its masks is drawn with
    ``weighted_mask`` and converted, together with the example's PIL image,
    by ``prepare_mask_and_masked_image``.

    Args:
        examples: list of dicts carrying ``instance_prompt_id``,
            ``instance_image``, ``instance_masks`` and ``PIL_image``.
        tokenizer: tokenizer whose ``pad`` method pads the prompt ids.

    Returns:
        dict with ``input_ids``, ``pixel_values``, ``masks`` and ``masked_images``.
    """
    prompt_ids, images, mask_tensors, masked_tensors = [], [], [], []

    for item in examples:
        prompt_ids.append(item["instance_prompt_id"])
        images.append(item["instance_image"])

        # Random mask for this example, then tensors for mask and masked image
        blended = weighted_mask(item["instance_masks"])
        mask_tensor, masked_tensor = prepare_mask_and_masked_image(item["PIL_image"], blended)
        mask_tensors.append(mask_tensor)
        masked_tensors.append(masked_tensor)

    # Pad the variable-length prompts to a common length
    padded = tokenizer.pad({"input_ids": prompt_ids}, padding=True, return_tensors="pt")

    return {
        "input_ids": padded.input_ids,
        "pixel_values": torch.stack(images).to(memory_format=torch.contiguous_format).float(),
        "masks": torch.stack(mask_tensors),
        "masked_images": torch.stack(masked_tensors),
    }
|
||||
|
||||
|
||||
def log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):
    """Run the pipeline on every validation pair and log the first image of each to wandb.

    Args:
        pipeline: inpainting pipeline; its ``text_encoder`` and ``unet`` are replaced in place.
        text_encoder: current text encoder (unwrapped via ``accelerator``).
        unet: current UNet (unwrapped via ``accelerator``).
        val_pairs: list of keyword-argument dicts for ``pipeline``; each has a ``prompt`` entry.
        accelerator: provides ``unwrap_model``.
    """
    # Point the pipeline at the current training weights
    pipeline.text_encoder = accelerator.unwrap_model(text_encoder)
    pipeline.unet = accelerator.unwrap_model(unet)

    generated = []
    with torch.autocast("cuda"):
        for pair in val_pairs:
            generated.append((pipeline(**pair).images[0], pair["prompt"]))

    torch.cuda.empty_cache()

    # Each logged image is captioned with the prompt that produced it
    logged = [wandb.Image(data_or_path=picture, caption=caption) for picture, caption in generated]
    wandb.log({"validation": logged})
|
||||
|
||||
|
||||
def checkpoint(args, global_step, accelerator):
    """Save the accelerator state under ``<output_dir>/checkpoint-<global_step>`` and log the path."""
    state_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(state_dir)
    logger.info(f"Saved state to {state_dir}")
|
||||
|
||||
|
||||
def main():
    """Fine-tune a Stable Diffusion inpainting UNet (optionally with its text encoder) on several subjects.

    Steps: parse the options, build the ``Accelerator``, load tokenizer / text
    encoder / VAE / UNet, create the optimizer, dataset, dataloader and LR
    scheduler, optionally resume from a checkpoint, then run the training loop
    with periodic wandb validation and ``accelerator.save_state`` checkpoints.

    Raises:
        ValueError: ``--max_train_steps`` is missing, or the noise scheduler
            reports an unknown prediction type.
        ImportError: ``--report_to_wandb`` was requested but wandb is not installed.
    """
    args = parse_args()

    # The LR schedule length and the epoch count are both derived from
    # max_train_steps; fail clearly instead of with a TypeError further down.
    if args.max_train_steps is None:
        raise ValueError("`--max_train_steps` must be provided.")

    project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit,
        project_dir=args.output_dir,
        logging_dir=Path(args.output_dir, args.logging_dir),
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        project_config=project_config,
        log_with="wandb" if args.report_to_wandb else None,
    )

    if args.report_to_wandb and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    if args.seed is not None:
        set_seed(args.seed)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # Load the tokenizer & models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder"
    ).requires_grad_(args.train_text_encoder)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae").requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # The text encoder's parameters are optimized only when it is being trained.
    optimizer = torch.optim.AdamW(
        params=itertools.chain(unet.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    train_dataset = DreamBoothDataset(
        tokenizer=tokenizer,
        datasets_paths=args.instance_data_dir,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples, tokenizer),
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    accelerator.register_for_checkpointing(lr_scheduler)

    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    # Move text_encoder and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # The dataloader length is only final after `accelerator.prepare`, so the
    # per-epoch step count is computed here.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    # Afterwards we calculate our number of training epochs
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        accelerator.init_trackers(args.validation_project_name, config=tracker_config)

    # Create the validation pipeline around the models that are being trained
    val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        vae=vae,
        torch_dtype=weight_dtype,
        safety_checker=None,
    )
    val_pipeline.set_progress_bar_config(disable=True)

    # Prepare the validation set: one (image, mask, prompt) entry per mask column of every test row
    val_pairs = [
        {
            "image": example["image"],
            "mask_image": mask,
            "prompt": example["prompt"],
        }
        for example in train_dataset.test_data
        for mask in [example[key] for key in example if "mask" in key]
    ]

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
                model.save_pretrained(os.path.join(output_dir, sub_dir))

                # make sure to pop weight so that corresponding model is not saved again;
                # the list may already be empty, so guard the pop
                if weights:
                    weights.pop()

    accelerator.register_save_state_pre_hook(save_model_hook)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    global_step = 0
    first_epoch = 0

    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint (none if the output directory does not exist yet)
            dirs = os.listdir(args.output_dir) if os.path.isdir(args.output_dir) else []
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named "checkpoint-<global_step>"
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            # The text encoder is optimized too, so put it in training mode as well
            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Convert masked images to latent space
                masked_latents = vae.encode(
                    batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                ).latent_dist.sample()
                masked_latents = masked_latents * vae.config.scaling_factor

                masks = batch["masks"]
                # resize the mask to latents shape as we concatenate the mask to the latents
                mask = torch.stack(
                    [
                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
                        for mask in masks
                    ]
                )
                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # concatenate the noised latents with the mask and the masked latents
                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if (
                        global_step % args.validation_steps == 0
                        and global_step >= args.validation_from
                        and args.report_to_wandb
                    ):
                        log_validation(
                            val_pipeline,
                            text_encoder,
                            unet,
                            val_pairs,
                            accelerator,
                        )

                    if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:
                        checkpoint(
                            args,
                            global_step,
                            accelerator,
                        )

            # Step logging
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Terminate training
    # NOTE(review): only `accelerator.save_state` checkpoints are written by this
    # function; there is no final pipeline export after the last step.
    accelerator.end_training()
|
||||
|
||||
|
||||
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -44,7 +44,7 @@ import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -809,9 +809,7 @@ def main():
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(unwrapped_unet)
|
||||
)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
@@ -878,7 +876,7 @@ def main():
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
|
||||
@@ -52,7 +52,7 @@ from diffusers import (
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -651,15 +651,11 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1164,14 +1160,14 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
# Script for converting Hugging Face Diffusers-trained SDXL LoRAs to Kohya format.
|
||||
# This means that you can input your diffusers-trained LoRAs and
|
||||
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
|
||||
|
||||
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
|
||||
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
|
||||
# and run the script:
|
||||
# python convert_diffusers_sdxl_lora_to_webui.py pytorch_lora_weights.safetensors corgy.safetensors
|
||||
# now you can use corgy.safetensors in your WebUI of choice!
|
||||
|
||||
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
|
||||
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
|
||||
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
|
||||
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
|
||||
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
|
||||
# Canonical diffusers training scripts:
|
||||
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
|
||||
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
|
||||
|
||||
|
||||
def convert_and_save(input_lora, output_lora=None):
    """Convert a diffusers-format LoRA file to Kohya format and write it to disk.

    Args:
        input_lora: path of the safetensors file to read.
        output_lora: destination path; when ``None``, the input path with its
            extension replaced by ``_webui.safetensors`` is used.
    """
    if output_lora is None:
        stem, _ = os.path.splitext(input_lora)
        output_lora = f"{stem}_webui.safetensors"

    # diffusers -> PEFT -> Kohya
    as_peft = convert_all_state_dict_to_peft(load_file(input_lora))
    save_file(convert_state_dict_to_kohya(as_peft), output_lora)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: the input path is mandatory, the output path optional.
    parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
    parser.add_argument(
        "input_lora",
        help="Path to the input LoRA model file in the diffusers format.",
        type=str,
    )
    parser.add_argument(
        "output_lora",
        nargs="?",
        help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
        type=str,
    )

    args = parser.parse_args()

    # The namespace holds exactly `input_lora` and `output_lora`, matching the function's parameters.
    convert_and_save(**vars(args))
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
|
||||
from ..models.embeddings import ImageProjection, MLPProjection, Resampler
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -712,7 +712,7 @@ class UNet2DConditionLoadersMixin:
|
||||
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||
|
||||
image_projection = IPAdapterFullImageProjection(
|
||||
image_projection = MLPProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
|
||||
)
|
||||
|
||||
@@ -730,7 +730,7 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_dims = state_dict["latents"].shape[2]
|
||||
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
image_projection = Resampler(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
@@ -780,7 +780,7 @@ class UNet2DConditionLoadersMixin:
|
||||
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
|
||||
|
||||
# Set encoder_hid_proj after loading ip_adapter weights,
|
||||
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
|
||||
# because `Resampler` also has `attn_processors`.
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
|
||||
@@ -498,7 +498,7 @@ class TemporalBasicTransformerBlock(nn.Module):
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
||||
hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
hidden_states = self.ff_in(hidden_states)
|
||||
|
||||
|
||||
@@ -462,7 +462,7 @@ class ImageProjection(nn.Module):
|
||||
return image_embeds
|
||||
|
||||
|
||||
class IPAdapterFullImageProjection(nn.Module):
|
||||
class MLPProjection(nn.Module):
|
||||
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
@@ -621,34 +621,29 @@ class AttentionPooling(nn.Module):
|
||||
return a[:, 0, :] # cls_token
|
||||
|
||||
|
||||
def get_fourier_embeds_from_boundingbox(embed_dim, box):
|
||||
"""
|
||||
Args:
|
||||
embed_dim: int
|
||||
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
|
||||
Returns:
|
||||
[B x N x embed_dim] tensor of positional embeddings
|
||||
"""
|
||||
class FourierEmbedder(nn.Module):
|
||||
def __init__(self, num_freqs=64, temperature=100):
|
||||
super().__init__()
|
||||
|
||||
batch_size, num_boxes = box.shape[:2]
|
||||
self.num_freqs = num_freqs
|
||||
self.temperature = temperature
|
||||
|
||||
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
|
||||
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
|
||||
emb = emb * box.unsqueeze(-1)
|
||||
freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
|
||||
freq_bands = freq_bands[None, None, None]
|
||||
self.register_buffer("freq_bands", freq_bands, persistent=False)
|
||||
|
||||
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
|
||||
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
|
||||
|
||||
return emb
|
||||
def __call__(self, x):
|
||||
x = self.freq_bands * x.unsqueeze(-1)
|
||||
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
||||
|
||||
|
||||
class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
class PositionNet(nn.Module):
|
||||
def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.positive_len = positive_len
|
||||
self.out_dim = out_dim
|
||||
|
||||
self.fourier_embedder_dim = fourier_freqs
|
||||
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
|
||||
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
|
||||
|
||||
if isinstance(out_dim, tuple):
|
||||
@@ -697,7 +692,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
masks = masks.unsqueeze(-1)
|
||||
|
||||
# embedding position (it may include padding as a placeholder)
|
||||
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
|
||||
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
|
||||
|
||||
# learnable null embedding
|
||||
xyxy_null = self.null_position_feature.view(1, 1, -1)
|
||||
@@ -792,7 +787,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterPlusImageProjection(nn.Module):
|
||||
class Resampler(nn.Module):
|
||||
"""Resampler of IP-Adapter Plus.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -32,10 +32,10 @@ from .attention_processor import (
|
||||
)
|
||||
from .embeddings import (
|
||||
GaussianFourierProjection,
|
||||
GLIGENTextBoundingboxProjection,
|
||||
ImageHintTimeEmbedding,
|
||||
ImageProjection,
|
||||
ImageTimeEmbedding,
|
||||
PositionNet,
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
@@ -615,7 +615,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
self.position_net = GLIGENTextBoundingboxProjection(
|
||||
self.position_net = PositionNet(
|
||||
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class FourierEmbedder(nn.Module):
|
||||
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
|
||||
|
||||
|
||||
class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
class PositionNet(nn.Module):
|
||||
def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
|
||||
super().__init__()
|
||||
self.positive_len = positive_len
|
||||
@@ -820,7 +820,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
positive_len = cross_attention_dim[0]
|
||||
|
||||
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
||||
self.position_net = GLIGENTextBoundingboxProjection(
|
||||
self.position_net = PositionNet(
|
||||
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
||||
)
|
||||
|
||||
|
||||
@@ -730,7 +730,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
)
|
||||
gligen_phrases = gligen_phrases[:max_objs]
|
||||
gligen_boxes = gligen_boxes[:max_objs]
|
||||
# prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
|
||||
# prepare batched input to the PositionNet (boxes, phrases, mask)
|
||||
# Get tokens for phrases from pre-trained CLIPTokenizer
|
||||
tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
|
||||
# For the token, we use the same pre-trained text encoder
|
||||
|
||||
@@ -277,11 +277,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
|
||||
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
||||
else:
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
|
||||
@@ -98,9 +98,7 @@ from .peft_utils import (
|
||||
)
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .state_dict_utils import (
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
)
|
||||
|
||||
@@ -16,11 +16,6 @@ State dict utilities: utility methods for converting state dicts easily
|
||||
"""
|
||||
import enum
|
||||
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StateDictType(enum.Enum):
|
||||
"""
|
||||
@@ -28,7 +23,7 @@ class StateDictType(enum.Enum):
|
||||
"""
|
||||
|
||||
DIFFUSERS_OLD = "diffusers_old"
|
||||
KOHYA_SS = "kohya_ss"
|
||||
# KOHYA_SS = "kohya_ss" # TODO: implement this
|
||||
PEFT = "peft"
|
||||
DIFFUSERS = "diffusers"
|
||||
|
||||
@@ -105,14 +100,6 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
|
||||
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
PEFT_TO_KOHYA_SS = {
|
||||
"lora_A": "lora_down",
|
||||
"lora_B": "lora_up",
|
||||
# This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
|
||||
# adding prefixes and adding alpha values
|
||||
# Check `convert_state_dict_to_kohya` for more
|
||||
}
|
||||
|
||||
PEFT_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
|
||||
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
|
||||
@@ -123,8 +110,6 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
|
||||
}
|
||||
|
||||
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
|
||||
|
||||
KEYS_TO_ALWAYS_REPLACE = {
|
||||
".processor.": ".",
|
||||
}
|
||||
@@ -243,82 +228,3 @@ def convert_unet_state_dict_to_peft(state_dict):
|
||||
"""
|
||||
mapping = UNET_TO_DIFFUSERS
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
|
||||
|
||||
def convert_all_state_dict_to_peft(state_dict):
    r"""
    Convert a LoRA state dict to the PEFT format, falling back to a UNet-only conversion.

    `convert_state_dict_to_peft` is tried first. If it fails with the message
    "Could not automatically infer state dict type", the conversion is retried with
    `convert_unet_state_dict_to_peft`. Any other exception is re-raised unchanged.

    Args:
        state_dict (`dict`):
            The state dict to convert.

    Returns:
        The converted state dict. At least one of its keys contains `lora_A` or `lora_B`.

    Raises:
        ValueError: If no key of the converted state dict contains `lora_A` or `lora_B`.
    """
    try:
        converted = convert_state_dict_to_peft(state_dict)
    except Exception as error:
        # Only the "type could not be inferred" failure triggers the UNet fallback.
        if str(error) != "Could not automatically infer state dict type":
            raise
        converted = convert_unet_state_dict_to_peft(state_dict)

    # A successful conversion must have produced at least one PEFT-style LoRA key.
    for name in converted.keys():
        if "lora_A" in name or "lora_B" in name:
            return converted

    raise ValueError("Your LoRA was not converted to PEFT")
|
||||
|
||||
|
||||
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Each key is rewritten as follows: the `text_encoder_2.` / `text_encoder.` / `unet` prefix is replaced by
    `lora_te2.` / `lora_te1.` / `lora_unet`, the optional adapter-name segment is dropped, and every `.` except the
    last two is turned into `_`. For every `lora_down` entry an additional `<module>.alpha` tensor is emitted.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method.

    Returns:
        `dict[str, torch.Tensor]`: A new state dict with Kohya-style keys, including the generated `.alpha` entries.

    Raises:
        ImportError: If `torch` is not installed.
        ValueError: If `original_type` is not given, cannot be inferred as PEFT, or is not a supported source type.
    """
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    peft_adapter_name = kwargs.pop("adapter_name", None)
    # Stored with its leading dot so it can be matched and removed as a whole key segment.
    peft_adapter_name = "" if peft_adapter_name is None else "." + peft_adapter_name

    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the mapping registered for the validated source type
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[original_type])
    kohya_ss_state_dict = {}

    # Additional logic for replacing header, alpha parameters `.` with `_` in all keys
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")

        # Kohya doesn't take names. The adapter segment must be removed BEFORE the dots are collapsed:
        # otherwise it counts as one of the two preserved dots and
        # `x.lora_down.<adapter>.weight` ends up as `x_lora_down.weight` instead of `x.lora_down.weight`.
        if peft_adapter_name:
            kohya_key = kohya_key.replace(peft_adapter_name, "")

        # Keep the last two dots (`<module>.<lora_down|lora_up>.weight`) and flatten the rest.
        # Clamp at 0: a negative count would make `str.replace` substitute every dot.
        kohya_key = kohya_key.replace(".", "_", max(kohya_key.count(".") - 2, 0))
        kohya_ss_state_dict[kohya_key] = weight

        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            # alpha is set to the size of the first dimension of the down-projection weight
            # (presumably the LoRA rank -- TODO confirm against the PEFT weight layout).
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
|
||||
|
||||
@@ -22,6 +22,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from packaging import version
|
||||
@@ -40,6 +41,8 @@ from diffusers import (
|
||||
StableDiffusionXLPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
@@ -75,6 +78,28 @@ def state_dicts_almost_equal(sd1, sd2):
|
||||
return models_are_equal
|
||||
|
||||
|
||||
def create_unet_lora_layers(unet: nn.Module):
    """
    Build one LoRA attention processor for every attention processor of `unet`.

    Args:
        unet: A model exposing `attn_processors` (a mapping keyed by processor name) and a `config` with
            `block_out_channels` and `cross_attention_dim`.

    Returns:
        A tuple `(lora_attn_procs, unet_lora_layers)`: the dict of processors keyed by processor name, and the
        same dict wrapped in `AttnProcsLayers`.
    """
    # The implementation depends only on whether torch provides `scaled_dot_product_attention`,
    # so it is chosen once up front.
    if hasattr(F, "scaled_dot_product_attention"):
        processor_cls = LoRAAttnProcessor2_0
    else:
        processor_cls = LoRAAttnProcessor

    processors = {}
    for proc_name in unet.attn_processors.keys():
        # Processors whose name ends in "attn1.processor" get no cross-attention dimension.
        if proc_name.endswith("attn1.processor"):
            cross_dim = None
        else:
            cross_dim = unet.config.cross_attention_dim

        # NOTE: the block index is read from the single character right after the prefix,
        # so only one-digit block indices are parsed.
        if proc_name.startswith("down_blocks"):
            block_idx = int(proc_name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_idx]
        elif proc_name.startswith("up_blocks"):
            block_idx = int(proc_name[len("up_blocks.")])
            # Up blocks index the channel list in reverse order.
            hidden_size = list(reversed(unet.config.block_out_channels))[block_idx]
        elif proc_name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]

        processors[proc_name] = processor_cls(hidden_size=hidden_size, cross_attention_dim=cross_dim)

    return processors, AttnProcsLayers(processors)
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class PeftLoraLoaderMixinTests:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -115,6 +140,8 @@ class PeftLoraLoaderMixinTests:
|
||||
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
pipeline_components = {
|
||||
"unet": unet,
|
||||
@@ -138,8 +165,11 @@ class PeftLoraLoaderMixinTests:
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
|
||||
return pipeline_components, text_lora_config, unet_lora_config
|
||||
lora_components = {
|
||||
"unet_lora_layers": unet_lora_layers,
|
||||
"unet_lora_attn_procs": unet_lora_attn_procs,
|
||||
}
|
||||
return pipeline_components, lora_components, text_lora_config, unet_lora_config
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
@@ -186,7 +216,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple inference and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -201,7 +231,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -232,7 +262,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -279,7 +309,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -321,7 +351,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -364,7 +394,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple usecase where users could use saving utilities for LoRA.
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -429,7 +459,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -480,7 +510,7 @@ class PeftLoraLoaderMixinTests:
|
||||
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -553,7 +583,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -607,7 +637,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected - with unet
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -653,7 +683,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -700,7 +730,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -750,7 +780,7 @@ class PeftLoraLoaderMixinTests:
|
||||
multiple adapters and set them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -818,7 +848,7 @@ class PeftLoraLoaderMixinTests:
|
||||
multiple adapters and set/delete them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -908,7 +938,7 @@ class PeftLoraLoaderMixinTests:
|
||||
multiple adapters and set them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -980,7 +1010,7 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1018,7 +1048,7 @@ class PeftLoraLoaderMixinTests:
|
||||
are the expected results
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1045,7 +1075,7 @@ class PeftLoraLoaderMixinTests:
|
||||
are the expected results
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1083,7 +1113,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected - with unet and multi-adapter case
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -1145,7 +1175,7 @@ class PeftLoraLoaderMixinTests:
|
||||
and makes sure it works as expected
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -26,7 +26,7 @@ from pytest import mark
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.models.embeddings import ImageProjection, Resampler
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
cross_attention_dim = model.config["cross_attention_dim"]
|
||||
image_projection = IPAdapterPlusImageProjection(
|
||||
image_projection = Resampler(
|
||||
embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user