Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 0f9427f0dd | | |
| | f6844d3cf6 | | |
| | daa75665cf | | |
| | 38aece94c4 | | |
| | b03aa10375 | | |
| | 2bfdcabadc | | |
| | a837033105 | | |
| | 566aaab423 | | |
| | 9b910bdc5c | | |
| | 965b40aa17 | | |
| | d62076ac5f | | |
| | 3338ce0d40 | | |
| | 565416c2e7 | | |
| | 6338ad5b0b | | |
@@ -37,6 +37,8 @@ from accelerate.logging import get_logger
|
|||||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||||
from huggingface_hub import create_repo, upload_folder
|
from huggingface_hub import create_repo, upload_folder
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from peft import LoraConfig
|
||||||
|
from peft.utils import get_peft_model_state_dict
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.ImageOps import exif_transpose
|
from PIL.ImageOps import exif_transpose
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
@@ -54,10 +56,9 @@ from diffusers import (
|
|||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
)
|
)
|
||||||
from diffusers.loaders import LoraLoaderMixin
|
from diffusers.loaders import LoraLoaderMixin
|
||||||
from diffusers.models.lora import LoRALinearLayer
|
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
from diffusers.training_utils import compute_snr, unet_lora_state_dict
|
from diffusers.training_utils import compute_snr
|
||||||
from diffusers.utils import check_min_version, is_wandb_available
|
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
|
|
||||||
|
|
||||||
@@ -67,39 +68,6 @@ check_min_version("0.25.0.dev0")
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
|
||||||
def text_encoder_lora_state_dict(text_encoder):
|
|
||||||
state_dict = {}
|
|
||||||
|
|
||||||
def text_encoder_attn_modules(text_encoder):
|
|
||||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
|
||||||
|
|
||||||
attn_modules = []
|
|
||||||
|
|
||||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
|
||||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
|
||||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
|
||||||
mod = layer.self_attn
|
|
||||||
attn_modules.append((name, mod))
|
|
||||||
|
|
||||||
return attn_modules
|
|
||||||
|
|
||||||
for name, module in text_encoder_attn_modules(text_encoder):
|
|
||||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
|
||||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
|
||||||
|
|
||||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
|
||||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
|
||||||
|
|
||||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
|
||||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
|
||||||
|
|
||||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
|
||||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
|
||||||
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def save_model_card(
|
def save_model_card(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
images=None,
|
images=None,
|
||||||
@@ -161,8 +129,6 @@ tags:
|
|||||||
base_model: {base_model}
|
base_model: {base_model}
|
||||||
instance_prompt: {instance_prompt}
|
instance_prompt: {instance_prompt}
|
||||||
license: openrail++
|
license: openrail++
|
||||||
widget:
|
|
||||||
- text: '{validation_prompt if validation_prompt else instance_prompt}'
|
|
||||||
---
|
---
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1264,54 +1230,25 @@ def main(args):
|
|||||||
text_encoder_two.gradient_checkpointing_enable()
|
text_encoder_two.gradient_checkpointing_enable()
|
||||||
|
|
||||||
# now we will add new LoRA weights to the attention layers
|
# now we will add new LoRA weights to the attention layers
|
||||||
# Set correct lora layers
|
unet_lora_config = LoraConfig(
|
||||||
unet_lora_parameters = []
|
r=args.rank,
|
||||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
lora_alpha=args.rank,
|
||||||
# Parse the attention module.
|
init_lora_weights="gaussian",
|
||||||
attn_module = unet
|
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||||
for n in attn_processor_name.split(".")[:-1]:
|
)
|
||||||
attn_module = getattr(attn_module, n)
|
unet.add_adapter(unet_lora_config)
|
||||||
|
|
||||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
|
||||||
attn_module.to_q.set_lora_layer(
|
|
||||||
LoRALinearLayer(
|
|
||||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
|
||||||
)
|
|
||||||
)
|
|
||||||
attn_module.to_k.set_lora_layer(
|
|
||||||
LoRALinearLayer(
|
|
||||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
|
||||||
)
|
|
||||||
)
|
|
||||||
attn_module.to_v.set_lora_layer(
|
|
||||||
LoRALinearLayer(
|
|
||||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
|
||||||
)
|
|
||||||
)
|
|
||||||
attn_module.to_out[0].set_lora_layer(
|
|
||||||
LoRALinearLayer(
|
|
||||||
in_features=attn_module.to_out[0].in_features,
|
|
||||||
out_features=attn_module.to_out[0].out_features,
|
|
||||||
rank=args.rank,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Accumulate the LoRA params to optimize.
|
|
||||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
|
||||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
|
||||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
|
||||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
|
||||||
|
|
||||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
text_lora_config = LoraConfig(
|
||||||
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
|
r=args.rank,
|
||||||
text_encoder_one, dtype=torch.float32, rank=args.rank
|
lora_alpha=args.rank,
|
||||||
)
|
init_lora_weights="gaussian",
|
||||||
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
|
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||||
text_encoder_two, dtype=torch.float32, rank=args.rank
|
|
||||||
)
|
)
|
||||||
|
text_encoder_one.add_adapter(text_lora_config)
|
||||||
|
text_encoder_two.add_adapter(text_lora_config)
|
||||||
|
|
||||||
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
# if we use textual inversion, we freeze all parameters except for the token embeddings
|
||||||
# in text encoder
|
# in text encoder
|
||||||
@@ -1335,6 +1272,17 @@ def main(args):
|
|||||||
else:
|
else:
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Make sure the trainable params are in float32.
|
||||||
|
if args.mixed_precision == "fp16":
|
||||||
|
models = [unet]
|
||||||
|
if args.train_text_encoder:
|
||||||
|
models.extend([text_encoder_one, text_encoder_two])
|
||||||
|
for model in models:
|
||||||
|
for param in model.parameters():
|
||||||
|
# only upcast trainable parameters (LoRA) into fp32
|
||||||
|
if param.requires_grad:
|
||||||
|
param.data = param.to(torch.float32)
|
||||||
|
|
||||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||||
def save_model_hook(models, weights, output_dir):
|
def save_model_hook(models, weights, output_dir):
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
@@ -1346,11 +1294,15 @@ def main(args):
|
|||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||||
|
get_peft_model_state_dict(model)
|
||||||
|
)
|
||||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||||
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||||
|
get_peft_model_state_dict(model)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||||
|
|
||||||
@@ -1407,6 +1359,12 @@ def main(args):
|
|||||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||||
|
|
||||||
|
if args.train_text_encoder:
|
||||||
|
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||||
|
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
||||||
|
|
||||||
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
|
# If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
|
||||||
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
|
freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
|
||||||
|
|
||||||
@@ -1997,13 +1955,17 @@ def main(args):
|
|||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
unet = accelerator.unwrap_model(unet)
|
unet = accelerator.unwrap_model(unet)
|
||||||
unet = unet.to(torch.float32)
|
unet = unet.to(torch.float32)
|
||||||
unet_lora_layers = unet_lora_state_dict(unet)
|
unet_lora_layers = get_peft_model_state_dict(unet)
|
||||||
|
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||||
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
|
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||||
|
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||||
|
)
|
||||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||||
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
|
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||||
|
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
text_encoder_lora_layers = None
|
text_encoder_lora_layers = None
|
||||||
text_encoder_2_lora_layers = None
|
text_encoder_2_lora_layers = None
|
||||||
|
|||||||
Reference in New Issue
Block a user