|
|
|
@@ -18,6 +18,7 @@ import argparse
|
|
|
|
|
import logging
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import shutil
|
|
|
|
|
import warnings
|
|
|
|
|
from pathlib import Path
|
|
|
|
@@ -35,11 +36,12 @@ import transformers
|
|
|
|
|
from accelerate import Accelerator
|
|
|
|
|
from accelerate.logging import get_logger
|
|
|
|
|
from accelerate.utils import ProjectConfiguration, set_seed
|
|
|
|
|
from datasets import load_dataset
|
|
|
|
|
from datasets import concatenate_datasets, load_dataset
|
|
|
|
|
from huggingface_hub import create_repo, upload_folder
|
|
|
|
|
from packaging import version
|
|
|
|
|
from PIL import Image
|
|
|
|
|
from torchvision import transforms
|
|
|
|
|
from torchvision.transforms.functional import crop
|
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
from transformers import AutoTokenizer, PretrainedConfig
|
|
|
|
|
|
|
|
|
@@ -54,6 +56,10 @@ from diffusers.utils import check_min_version, deprecate, is_wandb_available, lo
|
|
|
|
|
from diffusers.utils.import_utils import is_xformers_available
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_wandb_available():
|
|
|
|
|
import wandb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
|
|
|
|
check_min_version("0.24.0.dev0")
|
|
|
|
|
|
|
|
|
@@ -62,7 +68,7 @@ logger = get_logger(__name__, log_level="INFO")
|
|
|
|
|
DATASET_NAME_MAPPING = {
|
|
|
|
|
"fusing/instructpix2pix-1000-samples": ("file_name", "edited_image", "edit_prompt"),
|
|
|
|
|
}
|
|
|
|
|
WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
|
|
|
|
|
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
|
|
|
|
|
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -86,6 +92,133 @@ def import_model_class_from_model_name_or_path(
|
|
|
|
|
raise ValueError(f"{model_class} is not supported.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tokenize_prompt(tokenizer, prompt):
|
|
|
|
|
text_inputs = tokenizer(
|
|
|
|
|
prompt,
|
|
|
|
|
padding="max_length",
|
|
|
|
|
max_length=tokenizer.model_max_length,
|
|
|
|
|
truncation=True,
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
)
|
|
|
|
|
text_input_ids = text_inputs.input_ids
|
|
|
|
|
return text_input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
|
|
|
|
|
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
|
|
|
|
|
prompt_embeds_list = []
|
|
|
|
|
|
|
|
|
|
for i, text_encoder in enumerate(text_encoders):
|
|
|
|
|
if tokenizers is not None:
|
|
|
|
|
tokenizer = tokenizers[i]
|
|
|
|
|
text_input_ids = tokenize_prompt(tokenizer, prompt)
|
|
|
|
|
else:
|
|
|
|
|
assert text_input_ids_list is not None
|
|
|
|
|
text_input_ids = text_input_ids_list[i]
|
|
|
|
|
|
|
|
|
|
prompt_embeds = text_encoder(
|
|
|
|
|
text_input_ids.to(text_encoder.device),
|
|
|
|
|
output_hidden_states=True,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# We are only ALWAYS interested in the pooled output of the final text encoder
|
|
|
|
|
pooled_prompt_embeds = prompt_embeds[0]
|
|
|
|
|
prompt_embeds = prompt_embeds.hidden_states[-2]
|
|
|
|
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
|
|
|
|
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
|
|
|
|
prompt_embeds_list.append(prompt_embeds)
|
|
|
|
|
|
|
|
|
|
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
|
|
|
|
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
|
|
|
|
return prompt_embeds, pooled_prompt_embeds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_validation(
|
|
|
|
|
vae,
|
|
|
|
|
unet,
|
|
|
|
|
text_encoder_1,
|
|
|
|
|
text_encoder_2,
|
|
|
|
|
tokenizer_1,
|
|
|
|
|
tokenizer_2,
|
|
|
|
|
args,
|
|
|
|
|
accelerator,
|
|
|
|
|
weight_dtype,
|
|
|
|
|
global_step,
|
|
|
|
|
):
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
|
|
|
|
f" {args.validation_prompt}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# The models need unwrapping because for compatibility in distributed training mode.
|
|
|
|
|
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path,
|
|
|
|
|
unet=accelerator.unwrap_model(unet),
|
|
|
|
|
text_encoder=text_encoder_1,
|
|
|
|
|
text_encoder_2=text_encoder_2,
|
|
|
|
|
tokenizer=tokenizer_1,
|
|
|
|
|
tokenizer_2=tokenizer_2,
|
|
|
|
|
vae=vae,
|
|
|
|
|
revision=args.revision,
|
|
|
|
|
torch_dtype=weight_dtype,
|
|
|
|
|
)
|
|
|
|
|
pipeline = pipeline.to(accelerator.device)
|
|
|
|
|
pipeline.set_progress_bar_config(disable=True)
|
|
|
|
|
|
|
|
|
|
if args.enable_xformers_memory_efficient_attention:
|
|
|
|
|
pipeline.enable_xformers_memory_efficient_attention()
|
|
|
|
|
|
|
|
|
|
if args.seed is None:
|
|
|
|
|
generator = None
|
|
|
|
|
else:
|
|
|
|
|
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
|
|
|
|
|
|
|
|
|
# run inference
|
|
|
|
|
# Save validation images
|
|
|
|
|
val_save_dir = os.path.join(args.output_dir, "validation_images")
|
|
|
|
|
if not os.path.exists(val_save_dir):
|
|
|
|
|
os.makedirs(val_save_dir)
|
|
|
|
|
|
|
|
|
|
original_image = (
|
|
|
|
|
lambda image_url_or_path: load_image(image_url_or_path)
|
|
|
|
|
if urlparse(image_url_or_path).scheme
|
|
|
|
|
else Image.open(image_url_or_path).convert("RGB")
|
|
|
|
|
)(args.val_image_url_or_path)
|
|
|
|
|
original_image = original_image.resize((args.resolution, args.resolution))
|
|
|
|
|
|
|
|
|
|
with torch.autocast("cuda"):
|
|
|
|
|
edited_images = []
|
|
|
|
|
for val_img_idx in range(args.num_validation_images):
|
|
|
|
|
a_val_img = pipeline(
|
|
|
|
|
args.validation_prompt,
|
|
|
|
|
height=args.resolution,
|
|
|
|
|
width=args.resolution,
|
|
|
|
|
image=original_image,
|
|
|
|
|
num_inference_steps=25,
|
|
|
|
|
image_guidance_scale=1.5,
|
|
|
|
|
guidance_scale=5.0,
|
|
|
|
|
generator=generator,
|
|
|
|
|
).images[0]
|
|
|
|
|
edited_images.append(a_val_img)
|
|
|
|
|
a_val_img.save(
|
|
|
|
|
os.path.join(
|
|
|
|
|
val_save_dir,
|
|
|
|
|
f"step_{global_step}_val_img_{val_img_idx}.png",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
formatted_images = [wandb.Image(original_image, caption="Original Image")]
|
|
|
|
|
for edited_image in edited_images:
|
|
|
|
|
formatted_images.append(wandb.Image(edited_image, caption=args.validation_prompt))
|
|
|
|
|
|
|
|
|
|
for tracker in accelerator.trackers:
|
|
|
|
|
if tracker.name == "wandb":
|
|
|
|
|
tracker.log({"validation": formatted_images})
|
|
|
|
|
|
|
|
|
|
del pipeline
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
|
|
|
|
|
parser = argparse.ArgumentParser(description="Script to train Stable Diffusion XL for InstructPix2Pix.")
|
|
|
|
|
parser.add_argument(
|
|
|
|
@@ -177,15 +310,7 @@ def parse_args():
|
|
|
|
|
default=4,
|
|
|
|
|
help="Number of images that should be generated during validation with `validation_prompt`.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--validation_steps",
|
|
|
|
|
type=int,
|
|
|
|
|
default=100,
|
|
|
|
|
help=(
|
|
|
|
|
"Run fine-tuning validation every X steps. The validation process consists of running the prompt"
|
|
|
|
|
" `args.validation_prompt` multiple times: `args.num_validation_images`."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument("--validation_epochs", type=int, default=1, help="Run fine-tuning validation every X epochs.")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--max_train_samples",
|
|
|
|
|
type=int,
|
|
|
|
@@ -198,7 +323,7 @@ def parse_args():
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--output_dir",
|
|
|
|
|
type=str,
|
|
|
|
|
default="instruct-pix2pix-model",
|
|
|
|
|
default="instruct-pix2pix-sdxl",
|
|
|
|
|
help="The output directory where the model predictions and checkpoints will be written.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
@@ -216,18 +341,6 @@ def parse_args():
|
|
|
|
|
"The resolution for input images, all the images in the train/validation dataset will be resized to this resolution."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--crops_coords_top_left_h",
|
|
|
|
|
type=int,
|
|
|
|
|
default=0,
|
|
|
|
|
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--crops_coords_top_left_w",
|
|
|
|
|
type=int,
|
|
|
|
|
default=0,
|
|
|
|
|
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--center_crop",
|
|
|
|
|
default=False,
|
|
|
|
@@ -443,7 +556,6 @@ def main():
|
|
|
|
|
if args.report_to == "wandb":
|
|
|
|
|
if not is_wandb_available():
|
|
|
|
|
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
|
|
|
|
import wandb
|
|
|
|
|
|
|
|
|
|
# Make one log on every process with the configuration for debugging.
|
|
|
|
|
logging.basicConfig(
|
|
|
|
@@ -605,6 +717,7 @@ def main():
|
|
|
|
|
args.dataset_config_name,
|
|
|
|
|
cache_dir=args.cache_dir,
|
|
|
|
|
)
|
|
|
|
|
dataset = concatenate_datasets([dataset["validation"], dataset["test"]])
|
|
|
|
|
else:
|
|
|
|
|
data_files = {}
|
|
|
|
|
if args.train_data_dir is not None:
|
|
|
|
@@ -619,7 +732,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
# Preprocessing the datasets.
|
|
|
|
|
# We need to tokenize inputs and targets.
|
|
|
|
|
column_names = dataset["train"].column_names
|
|
|
|
|
column_names = dataset.column_names
|
|
|
|
|
|
|
|
|
|
# 6. Get the column names for input/target.
|
|
|
|
|
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
|
|
|
@@ -659,40 +772,6 @@ def main():
|
|
|
|
|
weight_dtype = torch.bfloat16
|
|
|
|
|
warnings.warn(f"weight_dtype {weight_dtype} may cause nan during vae encoding", UserWarning)
|
|
|
|
|
|
|
|
|
|
# Preprocessing the datasets.
|
|
|
|
|
# We need to tokenize input captions and transform the images.
|
|
|
|
|
def tokenize_captions(captions, tokenizer):
|
|
|
|
|
inputs = tokenizer(
|
|
|
|
|
captions,
|
|
|
|
|
max_length=tokenizer.model_max_length,
|
|
|
|
|
padding="max_length",
|
|
|
|
|
truncation=True,
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
)
|
|
|
|
|
return inputs.input_ids
|
|
|
|
|
|
|
|
|
|
# Preprocessing the datasets.
|
|
|
|
|
train_transforms = transforms.Compose(
|
|
|
|
|
[
|
|
|
|
|
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
|
|
|
|
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def preprocess_images(examples):
|
|
|
|
|
original_images = np.concatenate(
|
|
|
|
|
[convert_to_np(image, args.resolution) for image in examples[original_image_column]]
|
|
|
|
|
)
|
|
|
|
|
edited_images = np.concatenate(
|
|
|
|
|
[convert_to_np(image, args.resolution) for image in examples[edited_image_column]]
|
|
|
|
|
)
|
|
|
|
|
# We need to ensure that the original and the edited images undergo the same
|
|
|
|
|
# augmentation transforms.
|
|
|
|
|
images = np.concatenate([original_images, edited_images])
|
|
|
|
|
images = torch.tensor(images)
|
|
|
|
|
images = 2 * (images / 255) - 1
|
|
|
|
|
return train_transforms(images)
|
|
|
|
|
|
|
|
|
|
# Load scheduler, tokenizer and models.
|
|
|
|
|
tokenizer_1 = AutoTokenizer.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
|
|
|
|
@@ -729,132 +808,111 @@ def main():
|
|
|
|
|
# Set UNet to trainable.
|
|
|
|
|
unet.train()
|
|
|
|
|
|
|
|
|
|
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
|
|
|
|
|
def encode_prompt(text_encoders, tokenizers, prompt):
|
|
|
|
|
prompt_embeds_list = []
|
|
|
|
|
|
|
|
|
|
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
|
|
|
|
text_inputs = tokenizer(
|
|
|
|
|
prompt,
|
|
|
|
|
padding="max_length",
|
|
|
|
|
max_length=tokenizer.model_max_length,
|
|
|
|
|
truncation=True,
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
)
|
|
|
|
|
text_input_ids = text_inputs.input_ids
|
|
|
|
|
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
|
|
|
|
|
|
|
|
|
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
|
|
|
|
text_input_ids, untruncated_ids
|
|
|
|
|
):
|
|
|
|
|
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
|
|
|
|
|
logger.warning(
|
|
|
|
|
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
|
|
|
|
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
|
|
|
|
# Preprocessing the datasets.
|
|
|
|
|
# We need to tokenize input captions and transform the images.
|
|
|
|
|
# Preprocessing the datasets.
|
|
|
|
|
def tokenize_captions(examples, is_train=True):
|
|
|
|
|
captions = []
|
|
|
|
|
for caption in examples[edit_prompt_column]:
|
|
|
|
|
if isinstance(caption, str):
|
|
|
|
|
captions.append(caption)
|
|
|
|
|
elif isinstance(caption, (list, np.ndarray)):
|
|
|
|
|
# take a random caption if there are multiple
|
|
|
|
|
captions.append(random.choice(caption) if is_train else caption[0])
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Caption column `{edit_prompt_column}` should contain either strings or lists of strings."
|
|
|
|
|
)
|
|
|
|
|
tokens_one = tokenize_prompt(tokenizer_1, captions)
|
|
|
|
|
tokens_two = tokenize_prompt(tokenizer_2, captions)
|
|
|
|
|
return tokens_one, tokens_two
|
|
|
|
|
|
|
|
|
|
prompt_embeds = text_encoder(
|
|
|
|
|
text_input_ids.to(text_encoder.device),
|
|
|
|
|
output_hidden_states=True,
|
|
|
|
|
)
|
|
|
|
|
train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
|
|
|
|
|
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
|
|
|
|
|
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
|
|
|
|
normalize = transforms.Normalize([0.5], [0.5])
|
|
|
|
|
|
|
|
|
|
# We are only ALWAYS interested in the pooled output of the final text encoder
|
|
|
|
|
pooled_prompt_embeds = prompt_embeds[0]
|
|
|
|
|
prompt_embeds = prompt_embeds.hidden_states[-2]
|
|
|
|
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
|
|
|
|
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
|
|
|
|
prompt_embeds_list.append(prompt_embeds)
|
|
|
|
|
def preprocess_train(samples):
|
|
|
|
|
orig_images = [image.convert("RGB") for image in samples[original_image_column]]
|
|
|
|
|
edited_images = [image.convert("RGB") for image in samples[edited_image_column]]
|
|
|
|
|
resized_edited_images = []
|
|
|
|
|
|
|
|
|
|
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
|
|
|
|
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
|
|
|
|
return prompt_embeds, pooled_prompt_embeds
|
|
|
|
|
# Resize edited images if necessary.
|
|
|
|
|
for edited_image, orig_image in zip(edited_images, orig_images):
|
|
|
|
|
if edited_image.size != orig_image.size:
|
|
|
|
|
edited_image = edited_image.resize(orig_image.size)
|
|
|
|
|
resized_edited_images.append(edited_image)
|
|
|
|
|
else:
|
|
|
|
|
resized_edited_images.append(edited_image)
|
|
|
|
|
|
|
|
|
|
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
|
|
|
|
|
def encode_prompts(text_encoders, tokenizers, prompts):
|
|
|
|
|
prompt_embeds_all = []
|
|
|
|
|
pooled_prompt_embeds_all = []
|
|
|
|
|
# Main image processing.
|
|
|
|
|
final_original_images = []
|
|
|
|
|
final_edited_images = []
|
|
|
|
|
original_sizes = []
|
|
|
|
|
crop_top_lefts = []
|
|
|
|
|
for edited_image, orig_image in zip(resized_edited_images, orig_images):
|
|
|
|
|
original_sizes.append((orig_image.height, orig_image.width))
|
|
|
|
|
|
|
|
|
|
for prompt in prompts:
|
|
|
|
|
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
|
|
|
|
|
prompt_embeds_all.append(prompt_embeds)
|
|
|
|
|
pooled_prompt_embeds_all.append(pooled_prompt_embeds)
|
|
|
|
|
images = torch.stack([transforms.ToTensor()(orig_image), transforms.ToTensor()(edited_image)])
|
|
|
|
|
images = train_resize(images)
|
|
|
|
|
if args.center_crop:
|
|
|
|
|
y1 = max(0, int(round((orig_image.height - args.resolution) / 2.0)))
|
|
|
|
|
x1 = max(0, int(round((orig_image.width - args.resolution) / 2.0)))
|
|
|
|
|
images = train_crop(images)
|
|
|
|
|
else:
|
|
|
|
|
y1, x1, h, w = train_crop.get_params(images, (args.resolution, args.resolution))
|
|
|
|
|
images = crop(images, y1, x1, h, w)
|
|
|
|
|
|
|
|
|
|
return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)
|
|
|
|
|
if args.random_flip and random.random() < 0.5:
|
|
|
|
|
# flip
|
|
|
|
|
x1 = orig_image.width - x1
|
|
|
|
|
images = train_flip(images)
|
|
|
|
|
crop_top_left = (y1, x1)
|
|
|
|
|
crop_top_lefts.append(crop_top_left)
|
|
|
|
|
|
|
|
|
|
# Adapted from examples.dreambooth.train_dreambooth_lora_sdxl
|
|
|
|
|
# Here, we compute not just the text embeddings but also the additional embeddings
|
|
|
|
|
# needed for the SD XL UNet to operate.
|
|
|
|
|
def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts)
|
|
|
|
|
add_text_embeds_all = pooled_prompt_embeds_all
|
|
|
|
|
transformed_images = normalize(images)
|
|
|
|
|
|
|
|
|
|
prompt_embeds_all = prompt_embeds_all.to(accelerator.device)
|
|
|
|
|
add_text_embeds_all = add_text_embeds_all.to(accelerator.device)
|
|
|
|
|
return prompt_embeds_all, add_text_embeds_all
|
|
|
|
|
# Separate the original and edited images and the edit prompt.
|
|
|
|
|
original_image, edited_image = transformed_images.chunk(2)
|
|
|
|
|
original_image = original_image.squeeze(0)
|
|
|
|
|
edited_image = edited_image.squeeze(0)
|
|
|
|
|
final_original_images.append(original_image)
|
|
|
|
|
final_edited_images.append(edited_image)
|
|
|
|
|
|
|
|
|
|
# Get null conditioning
|
|
|
|
|
def compute_null_conditioning():
|
|
|
|
|
null_conditioning_list = []
|
|
|
|
|
for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders):
|
|
|
|
|
null_conditioning_list.append(
|
|
|
|
|
a_text_encoder(
|
|
|
|
|
tokenize_captions([""], tokenizer=a_tokenizer).to(accelerator.device),
|
|
|
|
|
output_hidden_states=True,
|
|
|
|
|
).hidden_states[-2]
|
|
|
|
|
)
|
|
|
|
|
return torch.concat(null_conditioning_list, dim=-1)
|
|
|
|
|
|
|
|
|
|
null_conditioning = compute_null_conditioning()
|
|
|
|
|
|
|
|
|
|
def compute_time_ids():
|
|
|
|
|
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
|
|
|
|
|
original_size = target_size = (args.resolution, args.resolution)
|
|
|
|
|
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
|
|
|
|
add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
|
|
|
|
|
return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1)
|
|
|
|
|
|
|
|
|
|
add_time_ids = compute_time_ids()
|
|
|
|
|
|
|
|
|
|
def preprocess_train(examples):
|
|
|
|
|
# Preprocess images.
|
|
|
|
|
preprocessed_images = preprocess_images(examples)
|
|
|
|
|
# Since the original and edited images were concatenated before
|
|
|
|
|
# applying the transformations, we need to separate them and reshape
|
|
|
|
|
# them accordingly.
|
|
|
|
|
original_images, edited_images = preprocessed_images.chunk(2)
|
|
|
|
|
original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
|
|
|
|
|
edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
|
|
|
|
|
|
|
|
|
|
# Collate the preprocessed images into the `examples`.
|
|
|
|
|
examples["original_pixel_values"] = original_images
|
|
|
|
|
examples["edited_pixel_values"] = edited_images
|
|
|
|
|
|
|
|
|
|
# Preprocess the captions.
|
|
|
|
|
captions = list(examples[edit_prompt_column])
|
|
|
|
|
prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers)
|
|
|
|
|
examples["prompt_embeds"] = prompt_embeds_all
|
|
|
|
|
examples["add_text_embeds"] = add_text_embeds_all
|
|
|
|
|
return examples
|
|
|
|
|
# Pack the values.
|
|
|
|
|
samples["original_sizes"] = original_sizes
|
|
|
|
|
samples["crop_top_lefts"] = crop_top_lefts
|
|
|
|
|
samples["original_pixel_values"] = final_original_images
|
|
|
|
|
samples["edited_pixel_values"] = final_original_images
|
|
|
|
|
tokens_one, tokens_two = tokenize_captions(samples)
|
|
|
|
|
samples["input_ids_one"] = tokens_one
|
|
|
|
|
samples["input_ids_two"] = tokens_two
|
|
|
|
|
return samples
|
|
|
|
|
|
|
|
|
|
with accelerator.main_process_first():
|
|
|
|
|
if args.max_train_samples is not None:
|
|
|
|
|
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
|
|
|
|
dataset = dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
|
|
|
|
|
# Set the training transforms
|
|
|
|
|
train_dataset = dataset["train"].with_transform(preprocess_train)
|
|
|
|
|
train_dataset = dataset.with_transform(preprocess_train)
|
|
|
|
|
|
|
|
|
|
def collate_fn(examples):
|
|
|
|
|
original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples])
|
|
|
|
|
original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()
|
|
|
|
|
|
|
|
|
|
edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples])
|
|
|
|
|
edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()
|
|
|
|
|
prompt_embeds = torch.concat([example["prompt_embeds"] for example in examples], dim=0)
|
|
|
|
|
add_text_embeds = torch.concat([example["add_text_embeds"] for example in examples], dim=0)
|
|
|
|
|
|
|
|
|
|
original_sizes = [example["original_sizes"] for example in examples]
|
|
|
|
|
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
|
|
|
|
|
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
|
|
|
|
|
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
|
|
|
|
|
return {
|
|
|
|
|
"original_pixel_values": original_pixel_values,
|
|
|
|
|
"edited_pixel_values": edited_pixel_values,
|
|
|
|
|
"prompt_embeds": prompt_embeds,
|
|
|
|
|
"add_text_embeds": add_text_embeds,
|
|
|
|
|
"input_ids_one": input_ids_one,
|
|
|
|
|
"input_ids_two": input_ids_two,
|
|
|
|
|
"original_sizes": original_sizes,
|
|
|
|
|
"crop_top_lefts": crop_top_lefts,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# DataLoaders creation:
|
|
|
|
@@ -947,6 +1005,12 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
initial_global_step = 0
|
|
|
|
|
|
|
|
|
|
# Get null conditioning.
|
|
|
|
|
# Remains fixed throughout training.
|
|
|
|
|
null_conditioning_prompt_embeds, null_conditioning_pooled_prompt_embeds = encode_prompt(
|
|
|
|
|
text_encoders, tokenizers, [""]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
progress_bar = tqdm(
|
|
|
|
|
range(0, args.max_train_steps),
|
|
|
|
|
initial=initial_global_step,
|
|
|
|
@@ -982,9 +1046,13 @@ def main():
|
|
|
|
|
# (this is the forward diffusion process)
|
|
|
|
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
|
|
|
|
|
|
|
|
|
# SDXL additional inputs
|
|
|
|
|
encoder_hidden_states = batch["prompt_embeds"]
|
|
|
|
|
add_text_embeds = batch["add_text_embeds"]
|
|
|
|
|
# Encode prompts.
|
|
|
|
|
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
|
|
|
|
text_encoders=[text_encoder_1, text_encoder_2],
|
|
|
|
|
tokenizers=None,
|
|
|
|
|
prompt=None,
|
|
|
|
|
text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Get the additional image embedding for conditioning.
|
|
|
|
|
# Instead of getting a diagonal Gaussian here, we simply take the mode.
|
|
|
|
@@ -992,7 +1060,7 @@ def main():
|
|
|
|
|
original_pixel_values = batch["original_pixel_values"].to(dtype=weight_dtype)
|
|
|
|
|
else:
|
|
|
|
|
original_pixel_values = batch["original_pixel_values"]
|
|
|
|
|
original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample()
|
|
|
|
|
original_image_embeds = vae.encode(original_pixel_values).latent_dist.mode()
|
|
|
|
|
if args.pretrained_vae_model_name_or_path is None:
|
|
|
|
|
original_image_embeds = original_image_embeds.to(weight_dtype)
|
|
|
|
|
|
|
|
|
@@ -1003,8 +1071,13 @@ def main():
|
|
|
|
|
# Sample masks for the edit prompts.
|
|
|
|
|
prompt_mask = random_p < 2 * args.conditioning_dropout_prob
|
|
|
|
|
prompt_mask = prompt_mask.reshape(bsz, 1, 1)
|
|
|
|
|
pooled_prompt_mask = prompt_mask.reshape(bsz, 1)
|
|
|
|
|
|
|
|
|
|
# Final text conditioning.
|
|
|
|
|
encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
|
|
|
|
|
prompt_embeds = torch.where(prompt_mask, null_conditioning_prompt_embeds, prompt_embeds)
|
|
|
|
|
pooled_prompt_embeds = torch.where(
|
|
|
|
|
pooled_prompt_mask, null_conditioning_pooled_prompt_embeds, pooled_prompt_embeds
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Sample masks for the original images.
|
|
|
|
|
image_mask_dtype = original_image_embeds.dtype
|
|
|
|
@@ -1027,11 +1100,24 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
|
|
|
|
|
|
|
|
|
# Predict the noise residual and compute loss
|
|
|
|
|
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
|
|
|
|
# Compute additional embedding inputs.
|
|
|
|
|
# time ids
|
|
|
|
|
def compute_time_ids(original_size, crops_coords_top_left):
|
|
|
|
|
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
|
|
|
|
target_size = (args.resolution, args.resolution)
|
|
|
|
|
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
|
|
|
|
add_time_ids = torch.tensor([add_time_ids])
|
|
|
|
|
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
|
|
|
|
return add_time_ids
|
|
|
|
|
|
|
|
|
|
add_time_ids = torch.cat(
|
|
|
|
|
[compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
|
|
|
|
|
)
|
|
|
|
|
unet_added_conditions = {"time_ids": add_time_ids, "text_embeds": pooled_prompt_embeds}
|
|
|
|
|
|
|
|
|
|
# Predict the noise residual and compute loss
|
|
|
|
|
model_pred = unet(
|
|
|
|
|
concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
|
|
|
|
concatenated_noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
|
|
|
|
|
).sample
|
|
|
|
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
|
|
|
|
|
|
|
|
@@ -1056,8 +1142,8 @@ def main():
|
|
|
|
|
accelerator.log({"train_loss": train_loss}, step=global_step)
|
|
|
|
|
train_loss = 0.0
|
|
|
|
|
|
|
|
|
|
if global_step % args.checkpointing_steps == 0:
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
if global_step % args.checkpointing_steps == 0:
|
|
|
|
|
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
|
|
|
|
if args.checkpoints_total_limit is not None:
|
|
|
|
|
checkpoints = os.listdir(args.output_dir)
|
|
|
|
@@ -1085,81 +1171,37 @@ def main():
|
|
|
|
|
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
|
|
|
|
progress_bar.set_postfix(**logs)
|
|
|
|
|
|
|
|
|
|
### BEGIN: Perform validation every `validation_epochs` steps
|
|
|
|
|
if global_step % args.validation_steps == 0 or global_step == 1:
|
|
|
|
|
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
|
|
|
|
f" {args.validation_prompt}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# create pipeline
|
|
|
|
|
if args.use_ema:
|
|
|
|
|
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
|
|
|
|
ema_unet.store(unet.parameters())
|
|
|
|
|
ema_unet.copy_to(unet.parameters())
|
|
|
|
|
|
|
|
|
|
# The models need unwrapping because for compatibility in distributed training mode.
|
|
|
|
|
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path,
|
|
|
|
|
unet=accelerator.unwrap_model(unet),
|
|
|
|
|
text_encoder=text_encoder_1,
|
|
|
|
|
text_encoder_2=text_encoder_2,
|
|
|
|
|
tokenizer=tokenizer_1,
|
|
|
|
|
tokenizer_2=tokenizer_2,
|
|
|
|
|
vae=vae,
|
|
|
|
|
revision=args.revision,
|
|
|
|
|
torch_dtype=weight_dtype,
|
|
|
|
|
)
|
|
|
|
|
pipeline = pipeline.to(accelerator.device)
|
|
|
|
|
pipeline.set_progress_bar_config(disable=True)
|
|
|
|
|
|
|
|
|
|
# run inference
|
|
|
|
|
# Save validation images
|
|
|
|
|
val_save_dir = os.path.join(args.output_dir, "validation_images")
|
|
|
|
|
if not os.path.exists(val_save_dir):
|
|
|
|
|
os.makedirs(val_save_dir)
|
|
|
|
|
|
|
|
|
|
original_image = (
|
|
|
|
|
lambda image_url_or_path: load_image(image_url_or_path)
|
|
|
|
|
if urlparse(image_url_or_path).scheme
|
|
|
|
|
else Image.open(image_url_or_path).convert("RGB")
|
|
|
|
|
)(args.val_image_url_or_path)
|
|
|
|
|
with torch.autocast(
|
|
|
|
|
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
|
|
|
|
|
):
|
|
|
|
|
edited_images = []
|
|
|
|
|
for val_img_idx in range(args.num_validation_images):
|
|
|
|
|
a_val_img = pipeline(
|
|
|
|
|
args.validation_prompt,
|
|
|
|
|
image=original_image,
|
|
|
|
|
num_inference_steps=20,
|
|
|
|
|
image_guidance_scale=1.5,
|
|
|
|
|
guidance_scale=7,
|
|
|
|
|
generator=generator,
|
|
|
|
|
).images[0]
|
|
|
|
|
edited_images.append(a_val_img)
|
|
|
|
|
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
|
|
|
|
|
|
|
|
|
|
for tracker in accelerator.trackers:
|
|
|
|
|
if tracker.name == "wandb":
|
|
|
|
|
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
|
|
|
|
for edited_image in edited_images:
|
|
|
|
|
wandb_table.add_data(
|
|
|
|
|
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
|
|
|
|
)
|
|
|
|
|
tracker.log({"validation": wandb_table})
|
|
|
|
|
if args.use_ema:
|
|
|
|
|
# Switch back to the original UNet parameters.
|
|
|
|
|
ema_unet.restore(unet.parameters())
|
|
|
|
|
|
|
|
|
|
del pipeline
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
### END: Perform validation every `validation_epochs` steps
|
|
|
|
|
|
|
|
|
|
if global_step >= args.max_train_steps:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
if (
|
|
|
|
|
(args.val_image_url_or_path is not None)
|
|
|
|
|
and (args.validation_prompt is not None)
|
|
|
|
|
and (epoch % args.validation_epochs == 0)
|
|
|
|
|
):
|
|
|
|
|
if args.use_ema:
|
|
|
|
|
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
|
|
|
|
ema_unet.store(unet.parameters())
|
|
|
|
|
ema_unet.copy_to(unet.parameters())
|
|
|
|
|
|
|
|
|
|
log_validation(
|
|
|
|
|
vae=vae,
|
|
|
|
|
unet=unet,
|
|
|
|
|
text_encoder_1=text_encoder_1,
|
|
|
|
|
text_encoder_2=text_encoder_2,
|
|
|
|
|
tokenizer_1=tokenizer_1,
|
|
|
|
|
tokenizer_2=tokenizer_2,
|
|
|
|
|
args=args,
|
|
|
|
|
accelerator=accelerator,
|
|
|
|
|
weight_dtype=weight_dtype,
|
|
|
|
|
global_step=global_step,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if args.use_ema:
|
|
|
|
|
# Switch back to the original UNet parameters.
|
|
|
|
|
ema_unet.restore(unet.parameters())
|
|
|
|
|
|
|
|
|
|
# Create the pipeline using the trained modules and save it.
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
@@ -1189,8 +1231,15 @@ def main():
|
|
|
|
|
|
|
|
|
|
if args.validation_prompt is not None:
|
|
|
|
|
edited_images = []
|
|
|
|
|
original_image = (
|
|
|
|
|
lambda image_url_or_path: load_image(image_url_or_path)
|
|
|
|
|
if urlparse(image_url_or_path).scheme
|
|
|
|
|
else Image.open(image_url_or_path).convert("RGB")
|
|
|
|
|
)(args.val_image_url_or_path)
|
|
|
|
|
original_image = original_image.resize((args.resolution, args.resolution))
|
|
|
|
|
|
|
|
|
|
pipeline = pipeline.to(accelerator.device)
|
|
|
|
|
with torch.autocast(str(accelerator.device).replace(":0", "")):
|
|
|
|
|
with torch.autocast(str(accelerator.device)):
|
|
|
|
|
for _ in range(args.num_validation_images):
|
|
|
|
|
edited_images.append(
|
|
|
|
|
pipeline(
|
|
|
|
@@ -1198,7 +1247,7 @@ def main():
|
|
|
|
|
image=original_image,
|
|
|
|
|
num_inference_steps=20,
|
|
|
|
|
image_guidance_scale=1.5,
|
|
|
|
|
guidance_scale=7,
|
|
|
|
|
guidance_scale=5.0,
|
|
|
|
|
generator=generator,
|
|
|
|
|
).images[0]
|
|
|
|
|
)
|
|
|
|
|