Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul b08fc2d9d5 Merge branch 'main' into examples-test-fix 2023-12-14 21:49:52 +05:30
Dhruv Nair a031abdc89 add peft to training deps 2023-12-12 14:07:45 +00:00
58 changed files with 302 additions and 906 deletions
-2
View File
@@ -198,8 +198,6 @@
title: Outputs
title: Main Classes
- sections:
- local: api/loaders/ip_adapter
title: IP-Adapter
- local: api/loaders/lora
title: LoRA
- local: api/loaders/single_file
-25
View File
@@ -1,25 +0,0 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# IP-Adapter
[IP-Adapter](https://hf.co/papers/2308.06721) is a lightweight adapter that enables prompting a diffusion model with an image. This method decouples the cross-attention layers of the image and text features. The image features are generated from an image encoder. Files generated from IP-Adapter are only ~100 MB.
<Tip>
Learn how to load an IP-Adapter checkpoint and image in the [IP-Adapter](../../using-diffusers/loading_adapters#ip-adapter) loading guide.
</Tip>
## IPAdapterMixin
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
+2 -2
View File
@@ -179,7 +179,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--dataloader_num_workers=8 \
--resolution=512 \
--resolution=512
--center_crop \
--random_flip \
--train_batch_size=1 \
@@ -214,4 +214,4 @@ image = pipeline("A pokemon with blue eyes").images[0]
Congratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:
- Learn how to [load different LoRA formats](../using-diffusers/loading_adapters#LoRA) trained using community trainers like Kohya and TheLastBen.
- Learn how to use and [combine multiple LoRAs](../tutorials/using_peft_for_inference) with PEFT for inference.
- Learn how to use and [combine multiple LoRAs](../tutorials/using_peft_for_inference) with PEFT for inference.
@@ -112,7 +112,7 @@ def save_model_card(
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n"
img_str = "widget:\n" if images else ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
@@ -121,10 +121,6 @@ def save_model_card(
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
@@ -161,6 +157,8 @@ tags:
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""
@@ -2012,42 +2010,43 @@ def main(args):
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
scheduler_args["variance_type"] = variance_type
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
# load attention processors
pipeline.load_lora_weights(args.output_dir)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
@@ -156,7 +156,7 @@ class WebdatasetFilter:
return False
class SDText2ImageDataset:
class Text2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
@@ -359,43 +359,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas * sample - sigmas * model_output
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_epsilon
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
@@ -859,35 +835,34 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# The scheduler calculates the alpha and sigma schedule for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps,
)
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
# 2. Load tokenizers from SD-XL checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
)
# 3. Load text encoders from SD 1.X/2.X checkpoint.
# 3. Load text encoders from SD-1.5 checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
)
# 4. Load VAE from SD 1.X/2.X checkpoint
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model,
subfolder="vae",
revision=args.teacher_revision,
)
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
# 5. Load teacher U-Net from SD-XL checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
@@ -897,7 +872,7 @@ def main(args):
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)
# 7. Create online student U-Net.
# 7. Create online (`unet`) student U-Nets.
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
@@ -960,7 +935,6 @@ def main(args):
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)
# 10. Handle saving and loading of checkpoints
@@ -1037,14 +1011,13 @@ def main(args):
eps=args.adam_epsilon,
)
# 13. Dataset creation and data processing
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
return {"prompt_embeds": prompt_embeds}
dataset = SDText2ImageDataset(
dataset = Text2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
@@ -1064,7 +1037,6 @@ def main(args):
tokenizer=tokenizer,
)
# 14. LR Scheduler creation
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
@@ -1079,7 +1051,6 @@ def main(args):
num_training_steps=args.max_train_steps,
)
# 15. Prepare for training
# Prepare everything with our `accelerator`.
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
@@ -1101,7 +1072,7 @@ def main(args):
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
# 16. Train!
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
@@ -1152,7 +1123,6 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image and text conditioning
image, text = batch
image = image.to(accelerator.device, non_blocking=True)
@@ -1170,37 +1140,37 @@ def main(args):
latents = latents * vae.config.scaling_factor
latents = latents.to(weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 5. Sample a random guidance scale w from U[w_min, w_max]
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
# 6. Prepare prompt embeds and unet_added_conditions
# 20.4.8. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
noisy_model_input,
start_timesteps,
@@ -1209,7 +1179,7 @@ def main(args):
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
noise_pred,
start_timesteps,
noisy_model_input,
@@ -1220,27 +1190,17 @@ def main(args):
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
).sample
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_pred_x0 = predicted_origin(
cond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1249,21 +1209,13 @@ def main(args):
sigma_schedule,
)
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
).sample
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_pred_x0 = predicted_origin(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1272,17 +1224,12 @@ def main(args):
sigma_schedule,
)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = unet(
@@ -1291,7 +1238,7 @@ def main(args):
timestep_cond=None,
encoder_hidden_states=prompt_embeds.float(),
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
target_noise_pred,
timesteps,
x_prev,
@@ -1301,7 +1248,7 @@ def main(args):
)
target = c_skip * x_prev + c_out * pred_x_0
# 10. Calculate loss
# 20.4.13. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
@@ -1309,7 +1256,7 @@ def main(args):
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
# 11. Backpropagate on the online student model (`unet`)
# 20.4.14. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
@@ -162,7 +162,7 @@ class WebdatasetFilter:
return False
class SDXLText2ImageDataset:
class Text2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
@@ -346,43 +346,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas * sample - sigmas * model_output
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_epsilon
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
@@ -854,10 +830,9 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# The scheduler calculates the alpha and sigma schedule for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
@@ -911,7 +886,7 @@ def main(args):
text_encoder_two.requires_grad_(False)
teacher_unet.requires_grad_(False)
# 7. Create online student U-Net.
# 7. Create online (`unet`) student U-Nets.
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
@@ -975,7 +950,6 @@ def main(args):
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)
# 10. Handle saving and loading of checkpoints
@@ -1083,7 +1057,7 @@ def main(args):
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
dataset = SDXLText2ImageDataset(
dataset = Text2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
@@ -1201,7 +1175,6 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
image, text, orig_size, crop_coords = batch
image = image.to(accelerator.device, non_blocking=True)
@@ -1223,37 +1196,37 @@ def main(args):
latents = latents * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
latents = latents.to(weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 5. Sample a random guidance scale w from U[w_min, w_max]
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
# 6. Prepare prompt embeds and unet_added_conditions
# 20.4.8. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
noisy_model_input,
start_timesteps,
@@ -1262,7 +1235,7 @@ def main(args):
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
noise_pred,
start_timesteps,
noisy_model_input,
@@ -1273,28 +1246,18 @@ def main(args):
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
).sample
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_pred_x0 = predicted_origin(
cond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1303,7 +1266,7 @@ def main(args):
sigma_schedule,
)
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_added_conditions = copy.deepcopy(encoded_text)
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
uncond_teacher_output = teacher_unet(
@@ -1312,15 +1275,7 @@ def main(args):
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
).sample
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_pred_x0 = predicted_origin(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1329,17 +1284,12 @@ def main(args):
sigma_schedule,
)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
target_noise_pred = unet(
@@ -1349,7 +1299,7 @@ def main(args):
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
target_noise_pred,
timesteps,
x_prev,
@@ -1359,7 +1309,7 @@ def main(args):
)
target = c_skip * x_prev + c_out * pred_x_0
# 10. Calculate loss
# 20.4.13. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
@@ -1367,7 +1317,7 @@ def main(args):
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
# 11. Backpropagate on the online student model (`unet`)
# 20.4.14. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
@@ -138,7 +138,7 @@ class WebdatasetFilter:
return False
class SDText2ImageDataset:
class Text2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
@@ -336,43 +336,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas * sample - sigmas * model_output
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
return pred_epsilon
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
@@ -847,35 +823,34 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# The scheduler calculates the alpha and sigma schedule for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps,
)
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
# 2. Load tokenizers from SD-XL checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
)
# 3. Load text encoders from SD 1.X/2.X checkpoint.
# 3. Load text encoders from SD-1.5 checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
)
# 4. Load VAE from SD 1.X/2.X checkpoint
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model,
subfolder="vae",
revision=args.teacher_revision,
)
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
# 5. Load teacher U-Net from SD-XL checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
@@ -885,7 +860,7 @@ def main(args):
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
@@ -894,8 +869,8 @@ def main(args):
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
# Initialize from unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
@@ -912,7 +887,7 @@ def main(args):
f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
# 9. Handle mixed precision and device placement
# 10. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
@@ -939,7 +914,7 @@ def main(args):
sigma_schedule = sigma_schedule.to(accelerator.device)
solver = solver.to(accelerator.device)
# 10. Handle saving and loading of checkpoints
# 11. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -973,7 +948,7 @@ def main(args):
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# 11. Enable optimizations
# 12. Enable optimizations
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
@@ -1019,14 +994,13 @@ def main(args):
eps=args.adam_epsilon,
)
# 13. Dataset creation and data processing
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
    """Encode a batch of prompts and return the result keyed by `"prompt_embeds"`.

    Args:
        prompt_batch: Batch of prompts forwarded to `encode_prompt`.
        proportion_empty_prompts: Forwarded to `encode_prompt` unchanged.
        text_encoder: Text encoder forwarded to `encode_prompt`.
        tokenizer: Tokenizer forwarded to `encode_prompt`.
        is_train: Forwarded to `encode_prompt` unchanged.

    Returns:
        A dict with the single key `"prompt_embeds"` holding the value
        returned by `encode_prompt`.
    """
    # Note the argument order expected by `encode_prompt`: encoder and
    # tokenizer come before `proportion_empty_prompts`.
    return {
        "prompt_embeds": encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
    }
dataset = SDText2ImageDataset(
dataset = Text2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
@@ -1046,7 +1020,6 @@ def main(args):
tokenizer=tokenizer,
)
# 14. LR Scheduler creation
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
@@ -1061,7 +1034,6 @@ def main(args):
num_training_steps=args.max_train_steps,
)
# 15. Prepare for training
# Prepare everything with our `accelerator`.
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
@@ -1083,7 +1055,7 @@ def main(args):
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
# 16. Train!
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
@@ -1134,7 +1106,6 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image and text conditioning
image, text = batch
image = image.to(accelerator.device, non_blocking=True)
@@ -1152,28 +1123,29 @@ def main(args):
latents = latents * vae.config.scaling_factor
latents = latents.to(weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
@@ -1181,10 +1153,10 @@ def main(args):
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
# 6. Prepare prompt embeds and unet_added_conditions
# 20.4.8. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
noisy_model_input,
start_timesteps,
@@ -1193,7 +1165,7 @@ def main(args):
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
noise_pred,
start_timesteps,
noisy_model_input,
@@ -1204,27 +1176,17 @@ def main(args):
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
).sample
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_pred_x0 = predicted_origin(
cond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1233,21 +1195,13 @@ def main(args):
sigma_schedule,
)
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
).sample
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_pred_x0 = predicted_origin(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1256,16 +1210,12 @@ def main(args):
sigma_schedule,
)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = target_unet(
@@ -1274,7 +1224,7 @@ def main(args):
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
target_noise_pred,
timesteps,
x_prev,
@@ -1284,7 +1234,7 @@ def main(args):
)
target = c_skip * x_prev + c_out * pred_x_0
# 10. Calculate loss
# 20.4.13. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
@@ -1292,7 +1242,7 @@ def main(args):
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
# 11. Backpropagate on the online student model (`unet`)
# 20.4.14. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
@@ -1302,7 +1252,7 @@ def main(args):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 12. Make EMA update to target student model parameters (`target_unet`)
# 20.4.15. Make EMA update to target student model parameters
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1
@@ -144,7 +144,7 @@ class WebdatasetFilter:
return False
class SDXLText2ImageDataset:
class Text2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
@@ -324,43 +324,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas * sample - sigmas * model_output
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    """Convert a denoising model's output into a noise (epsilon) prediction.

    Args:
        model_output: Raw output of the model; its meaning depends on
            `prediction_type`.
        timesteps: Timesteps used to look up the per-sample schedule values.
        sample: The noisy sample the model was evaluated on.
        prediction_type: One of `"epsilon"`, `"sample"` or `"v_prediction"`.
        alphas: Alpha schedule; the entries for `timesteps` are gathered with
            `extract_into_tensor` using `sample.shape`.
        sigmas: Sigma schedule, gathered the same way as `alphas`.

    Returns:
        The predicted noise.

    Raises:
        ValueError: If `prediction_type` is not one of the supported values.
    """
    # Schedule values are looked up before dispatching, for every prediction type.
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        # The model already predicts the noise.
        return model_output
    if prediction_type == "sample":
        # Solve sample = alpha_t * x_0 + sigma_t * eps for eps, with x_0 = model_output.
        return (sample - alpha_t * model_output) / sigma_t
    if prediction_type == "v_prediction":
        return alpha_t * model_output + sigma_t * sample

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
@@ -887,10 +863,9 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
# The scheduler calculates the alpha and sigma schedule for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
@@ -944,7 +919,7 @@ def main(args):
text_encoder_two.requires_grad_(False)
teacher_unet.requires_grad_(False)
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
@@ -953,8 +928,8 @@ def main(args):
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
# Initialize from unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
@@ -996,7 +971,6 @@ def main(args):
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)
# 10. Handle saving and loading of checkpoints
@@ -1110,7 +1084,7 @@ def main(args):
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
dataset = SDXLText2ImageDataset(
dataset = Text2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
@@ -1228,7 +1202,6 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
image, text, orig_size, crop_coords = batch
image = image.to(accelerator.device, non_blocking=True)
@@ -1250,39 +1223,38 @@ def main(args):
latents = latents * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
latents = latents.to(weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
# 6. Prepare prompt embeds and unet_added_conditions
# 20.4.8. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
noisy_model_input,
start_timesteps,
@@ -1291,7 +1263,7 @@ def main(args):
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
noise_pred,
start_timesteps,
noisy_model_input,
@@ -1302,28 +1274,18 @@ def main(args):
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
).sample
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_pred_x0 = predicted_origin(
cond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1332,7 +1294,7 @@ def main(args):
sigma_schedule,
)
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_added_conditions = copy.deepcopy(encoded_text)
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
uncond_teacher_output = teacher_unet(
@@ -1341,15 +1303,7 @@ def main(args):
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
).sample
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_pred_x0 = predicted_origin(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1358,16 +1312,12 @@ def main(args):
sigma_schedule,
)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = target_unet(
@@ -1377,7 +1327,7 @@ def main(args):
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
pred_x_0 = get_predicted_original_sample(
pred_x_0 = predicted_origin(
target_noise_pred,
timesteps,
x_prev,
@@ -1387,7 +1337,7 @@ def main(args):
)
target = c_skip * x_prev + c_out * pred_x_0
# 10. Calculate loss
# 20.4.13. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
@@ -1395,7 +1345,7 @@ def main(args):
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
# 11. Backpropagate on the online student model (`unet`)
# 20.4.14. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
@@ -1405,7 +1355,7 @@ def main(args):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 12. Make EMA update to target student model parameters (`target_unet`)
# 20.4.15. Make EMA update to target student model parameters
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1
+3 -17
View File
@@ -880,16 +880,11 @@ def main(args):
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
unet_lora_config = None
text_encoder_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -900,8 +895,6 @@ def main(args):
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
def load_model_hook(models, input_dir):
@@ -918,12 +911,10 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1324,22 +1315,17 @@ def main(args):
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder_lora_config = text_encoder.peft_config["default"]
else:
text_encoder_state_dict = None
text_encoder_lora_config = None
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
# Final inference
@@ -1033,20 +1033,13 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1058,9 +1051,6 @@ def main(args):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
def load_model_hook(models, input_dir):
@@ -1080,19 +1070,17 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1628,29 +1616,21 @@ def main(args):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
# Final inference
@@ -833,12 +833,10 @@ def main():
accelerator.save_state(save_path)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
unet_lora_config=unet_lora_config,
safe_serialization=True,
)
@@ -900,12 +898,10 @@ def main():
unet = unet.to(torch.float32)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
unet_lora_config=unet_lora_config,
)
if args.push_to_hub:
@@ -682,20 +682,13 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
unet_lora_config = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_config = model.peft_config["default"]
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_config = model.peft_config["default"]
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -707,9 +700,6 @@ def main(args):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
def load_model_hook(models, input_dir):
@@ -729,19 +719,17 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1206,7 +1194,6 @@ def main(args):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_config = unet.peft_config["default"]
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
@@ -1214,23 +1201,15 @@ def main(args):
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
text_encoder_one_lora_config = None
text_encoder_two_lora_config = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_one_lora_config,
text_encoder_2_lora_config=text_encoder_two_lora_config,
)
del unet
+27 -144
View File
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -19,7 +18,6 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
@@ -104,7 +102,7 @@ class LoraLoaderMixin:
`default_{i}` where i is the total number of adapters being loaded.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -115,7 +113,6 @@ class LoraLoaderMixin:
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
config=metadata,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -127,7 +124,6 @@ class LoraLoaderMixin:
text_encoder=getattr(self, self.text_encoder_name)
if not hasattr(self, "text_encoder")
else self.text_encoder,
config=metadata,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
@@ -222,7 +218,6 @@ class LoraLoaderMixin:
}
model_file = None
metadata = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
@@ -234,9 +229,7 @@ class LoraLoaderMixin:
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
@@ -252,8 +245,6 @@ class LoraLoaderMixin:
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata()
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
@@ -264,7 +255,7 @@ class LoraLoaderMixin:
if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
pretrained_model_name_or_path_or_dict, file_extension=".bin"
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
@@ -300,15 +291,10 @@ class LoraLoaderMixin:
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alphas, metadata
return state_dict, network_alphas
@classmethod
def _best_guess_weight_name(
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
@@ -376,7 +362,7 @@ class LoraLoaderMixin:
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -390,8 +376,6 @@ class LoraLoaderMixin:
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
config: (`Dict`):
LoRA configuration parsed from the state dict.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
@@ -451,9 +435,7 @@ class LoraLoaderMixin:
if "lora_B" in key:
rank[key] = val.shape[1]
if config is not None and isinstance(config, dict) and len(config) > 0:
config = json.loads(config["unet"])
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -494,7 +476,6 @@ class LoraLoaderMixin:
network_alphas,
text_encoder,
prefix=None,
config=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
@@ -513,8 +494,6 @@ class LoraLoaderMixin:
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
config (`Dict`):
LoRA configuration parsed from state dict.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added to the output of the regular
lora layer.
@@ -588,11 +567,10 @@ class LoraLoaderMixin:
if USE_PEFT_BACKEND:
from peft import LoraConfig
if config is not None and len(config) > 0:
config = json.loads(config[prefix])
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
@@ -800,8 +778,6 @@ class LoraLoaderMixin:
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
unet_lora_config=None,
text_encoder_lora_config=None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -829,54 +805,21 @@ class LoraLoaderMixin:
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
if not USE_PEFT_BACKEND and not safe_serialization:
if unet_lora_config or text_encoder_lora_config:
raise ValueError(
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
)
elif USE_PEFT_BACKEND and safe_serialization:
from peft import LoraConfig
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
state_dict = {}
metadata = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
if unet_lora_layers:
prefix = "unet"
unet_state_dict = pack_weights(unet_lora_layers, prefix)
state_dict.update(unet_state_dict)
if unet_lora_config is not None:
unet_metadata = pack_metadata(unet_lora_config, prefix)
metadata.update(unet_metadata)
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers:
prefix = "text_encoder"
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
state_dict.update(text_encoder_state_dict)
if text_encoder_lora_config is not None:
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
metadata.update(text_encoder_metadata)
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
# Save the model
cls.write_lora_layers(
@@ -886,7 +829,6 @@ class LoraLoaderMixin:
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
metadata=metadata,
)
@staticmethod
@@ -897,11 +839,7 @@ class LoraLoaderMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
metadata=None,
):
if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
@@ -909,10 +847,8 @@ class LoraLoaderMixin:
if save_function is None:
if safe_serialization:
def save_function(weights, filename, metadata):
if metadata is None:
metadata = {"format": "pt"}
return safetensors.torch.save_file(weights, filename, metadata=metadata)
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
@@ -925,10 +861,7 @@ class LoraLoaderMixin:
else:
weight_name = LORA_WEIGHT_NAME
if save_function != torch.save:
save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
else:
save_function(state_dict, os.path.join(save_directory, weight_name))
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
def unload_lora_weights(self):
@@ -1360,7 +1293,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
# pipeline.
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas, metadata = self.lora_state_dict(
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
@@ -1370,12 +1303,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=self.unet,
config=metadata,
adapter_name=adapter_name,
_pipeline=self,
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
@@ -1383,7 +1311,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
config=metadata,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
@@ -1396,7 +1323,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
config=metadata,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
@@ -1410,9 +1336,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
unet_lora_config=None,
text_encoder_lora_config=None,
text_encoder_2_lora_config=None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
@@ -1440,63 +1363,24 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
if not USE_PEFT_BACKEND and not safe_serialization:
if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
raise ValueError(
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
)
elif USE_PEFT_BACKEND and safe_serialization:
from peft import LoraConfig
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
)
state_dict = {}
metadata = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
def pack_metadata(config, prefix):
local_metadata = {}
if config is not None:
if isinstance(config, LoraConfig):
config = config.to_dict()
for key, value in config.items():
if isinstance(value, set):
config[key] = list(value)
config_as_string = json.dumps(config, indent=2, sort_keys=True)
local_metadata[prefix] = config_as_string
return local_metadata
if unet_lora_layers:
prefix = "unet"
unet_state_dict = pack_weights(unet_lora_layers, prefix)
state_dict.update(unet_state_dict)
if unet_lora_config is not None:
unet_metadata = pack_metadata(unet_lora_config, prefix)
metadata.update(unet_metadata)
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers:
prefix = "text_encoder"
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
state_dict.update(text_encoder_state_dict)
if text_encoder_lora_config is not None:
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
metadata.update(text_encoder_metadata)
prefix = "text_encoder_2"
text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
state_dict.update(text_encoder_2_state_dict)
if text_encoder_2_lora_config is not None:
text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
metadata.update(text_encoder_2_metadata)
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
cls.write_lora_layers(
state_dict=state_dict,
@@ -1505,7 +1389,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
metadata=metadata,
)
def _remove_text_encoder_monkey_patch(self):
@@ -84,12 +84,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -147,9 +147,6 @@ class StableDiffusionControlNetPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
@@ -140,11 +140,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -251,9 +251,6 @@ class StableDiffusionControlNetInpaintPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
<Tip>
@@ -148,10 +148,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
In addition the pipeline inherits the following loading methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -129,10 +129,8 @@ class StableDiffusionXLControlNetPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -155,10 +155,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -98,9 +98,7 @@ class StableDiffusionControlNetXSPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -102,9 +102,8 @@ class StableDiffusionXLControlNetXSPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -143,11 +143,6 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -177,9 +177,6 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -232,7 +232,6 @@ class StableDiffusionInpaintPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
@@ -54,11 +54,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
<Tip warning={true}>
This is an experimental pipeline and is likely to change in the future.
@@ -67,9 +67,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixi
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -43,11 +43,6 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -282,7 +282,7 @@ class Pix2PixZeroAttnProcessor:
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
r"""
Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -76,12 +76,6 @@ class StableDiffusionUpscalePipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -65,11 +65,6 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
prior_tokenizer ([`CLIPTokenizer`]):
A [`CLIPTokenizer`].
@@ -76,11 +76,6 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
feature_extractor ([`CLIPImageProcessor`]):
Feature extractor for image pre-processing before being encoded.
@@ -595,11 +595,10 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin):
```py
import torch
from diffusers import StableDiffusionPipelineSafe
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
pipeline = StableDiffusionPipelineSafe.from_pretrained(
"AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16
).to("cuda")
)
prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0]
```
@@ -159,12 +159,12 @@ class StableDiffusionXLPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
In addition the pipeline inherits the following loading methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -176,12 +176,12 @@ class StableDiffusionXLImg2ImgPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
In addition, the pipeline inherits the following loading methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -321,12 +321,12 @@ class StableDiffusionXLInpaintPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
In addition, the pipeline inherits the following loading methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -126,11 +126,11 @@ class StableDiffusionXLInstructPix2PixPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
In addition, the pipeline inherits the following loading methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
as well as the following saving methods:
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -178,12 +178,6 @@ class StableDiffusionXLAdapterPipeline(
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a
@@ -83,11 +83,6 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -159,11 +159,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
@@ -69,10 +69,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
prior ([`Prior`]):
The canonical unCLIP prior to approximate the image embedding from the text embedding.
@@ -98,7 +98,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.custom_timesteps = False
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
@@ -231,7 +230,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument
def _convert_to_karras(self, ramp):
@@ -187,7 +187,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
@@ -255,7 +254,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -214,7 +214,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
@@ -291,7 +290,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -209,7 +209,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.use_karras_sigmas = use_karras_sigmas
@property
@@ -290,7 +289,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -198,7 +198,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -348,7 +347,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.mid_point_sigma = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
@@ -197,7 +197,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def get_order_list(self, num_inference_steps: int) -> List[int]:
"""
@@ -289,7 +288,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -166,7 +166,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
@@ -250,7 +249,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
@@ -237,7 +237,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
@@ -342,7 +341,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
@@ -148,7 +148,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
@@ -270,7 +269,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.dt = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
@@ -140,7 +140,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -296,7 +295,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._index_counter = defaultdict(int)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
@@ -140,7 +140,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -285,7 +284,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self._index_counter = defaultdict(int)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def state_in_first_order(self):
@@ -168,7 +168,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
@@ -280,7 +279,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.derivatives = []
@@ -198,7 +198,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def step_index(self):
@@ -269,7 +268,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+3 -10
View File
@@ -138,17 +138,11 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
# Try to retrieve config.
alpha_retrieved = False
if config is not None:
lora_alpha = config["lora_alpha"]
alpha_retrieved = True
if len(set(rank_dict.values())) > 1:
# get the rank occurring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -160,8 +154,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None,
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occurring the most number of times
if not alpha_retrieved:
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
@@ -172,7 +165,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None,
}
else:
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
elif not alpha_retrieved:
else:
lora_alpha = set(network_alpha_dict.values()).pop()
# layer names without the Diffusers specific
+2 -6
View File
@@ -820,9 +820,7 @@ def _is_torch_fp16_available(device):
try:
x = torch.zeros((2, 2), dtype=torch.float16).to(device)
_ = torch.mul(x, x)
return True
_ = x @ x
except Exception as e:
if device.type == "cuda":
raise ValueError(
@@ -840,9 +838,7 @@ def _is_torch_fp64_available(device):
try:
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
_ = torch.mul(x, x)
return True
_ = x @ x
except Exception as e:
if device.type == "cuda":
raise ValueError(
+7 -22
View File
@@ -343,21 +343,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
@unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda")
def test_stable_diffusion_set_xformers_attn_processors(self):
# disable_full_determinism()
device = "cuda" # ensure determinism for the device-dependent torch.Generator
components, _ = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs()
# run normal sd pipe
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run lora xformers attention
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
attn_processors = {
@@ -622,7 +607,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
orig_image_slice, orig_image_slice_two, atol=1e-3
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
@unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_lora_unet_attn_processors_with_xformers(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.create_lora_weight_file(tmpdirname)
@@ -659,7 +644,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
if isinstance(module, Attention):
self.assertIsInstance(module.processor, XFormersAttnProcessor)
@unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
def test_lora_save_load_with_xformers(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -2285,8 +2270,8 @@ class LoraIntegrationTests(unittest.TestCase):
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.enable_model_cpu_offload()
start_time = time.time()
@@ -2299,13 +2284,13 @@ class LoraIntegrationTests(unittest.TestCase):
del pipe
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
pipe.fuse_lora()
pipe.enable_model_cpu_offload()
generator = torch.Generator().manual_seed(0)
start_time = time.time()
generator = torch.Generator().manual_seed(0)
for _ in range(3):
pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+26 -92
View File
@@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import (
floats_tensor,
load_image,
nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
slow,
@@ -107,9 +106,8 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None
vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
def get_dummy_components(self, scheduler_cls=None):
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
lora_alpha = 4 if lora_alpha is None else lora_alpha
torch.manual_seed(0)
unet = UNet2DConditionModel(**self.unet_kwargs)
@@ -124,14 +122,11 @@ class PeftLoraLoaderMixinTests:
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
text_lora_config = LoraConfig(
r=4,
lora_alpha=lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
init_lora_weights=False,
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
)
unet_lora_config = LoraConfig(
r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
@@ -718,68 +713,6 @@ class PeftLoraLoaderMixinTests:
"Fused lora should change the output",
)
# Round-trip test for `lora_alpha`: attach LoRA adapters created with a chosen alpha, save the
# LoRA layers together with their adapter configs, load them into a second pipeline object, and
# assert that every loaded "default" adapter config reports the same alpha.
def test_if_lora_alpha_is_correctly_parsed(self):
# Value handed to `get_dummy_components` and compared against the loaded configs at the end.
lora_alpha = 8
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
# Attach the adapters to the UNet and the text encoder(s) before saving.
pipe.unet.add_adapter(unet_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
# Inference works?
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
# Collect each component's LoRA state dict and its "default" adapter config.
# NOTE: `unet_lora_config` is rebound here to the config stored on the UNet.
unet_state_dict = get_peft_model_state_dict(pipe.unet)
unet_lora_config = pipe.unet.peft_config["default"]
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
# Pipelines with a second text encoder additionally pass that encoder's layers and config.
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
text_encoder_2_lora_config=text_encoder_2_lora_config,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=unet_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_config=unet_lora_config,
text_encoder_lora_config=text_encoder_lora_config,
)
# Build a second pipeline object from the same `components` dict and load what was just saved.
loaded_pipe = self.pipeline_class(**components)
loaded_pipe.load_lora_weights(tmpdirname)
# Inference works?
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
# The alpha on each loaded "default" adapter config must equal the value used at creation.
assert (
loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for UNet."
assert (
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder."
if self.has_two_text_encoders:
assert (
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
), "LoRA alpha not correctly loaded for text encoder 2."
def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
@@ -1780,7 +1713,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
release_memory(pipe)
def test_sdxl_1_0_lora(self):
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
@@ -1803,7 +1736,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.Generator().manual_seed(0)
lora_model_id = "latent-consistency/lcm-lora-sdxl"
@@ -1820,8 +1753,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
image_np = pipe.image_processor.pil_to_numpy(image)
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
assert max_diff < 1e-4
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
pipe.unload_lora_weights()
@@ -1832,7 +1764,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.Generator().manual_seed(0)
lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
pipe.load_lora_weights(lora_model_id)
@@ -1848,8 +1780,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
image_np = pipe.image_processor.pil_to_numpy(image)
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
assert max_diff < 1e-4
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
pipe.unload_lora_weights()
@@ -1864,7 +1795,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.png"
)
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.Generator().manual_seed(0)
lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
pipe.load_lora_weights(lora_model_id)
@@ -1885,8 +1816,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
image_np = pipe.image_processor.pil_to_numpy(image)
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
assert max_diff < 1e-4
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
pipe.unload_lora_weights()
@@ -1919,7 +1849,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
release_memory(pipe)
def test_sdxl_1_0_lora_unfusion(self):
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.Generator().manual_seed(0)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
@@ -1930,16 +1860,16 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.enable_model_cpu_offload()
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=3
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images_with_fusion = images.flatten()
images_with_fusion = images[0, -3:, -3:, -1].flatten()
pipe.unfuse_lora()
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=3
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
images_without_fusion = images.flatten()
images_without_fusion = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
release_memory(pipe)
@@ -1983,8 +1913,10 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
start_time = time.time()
@@ -1997,17 +1929,19 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
del pipe
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
pipe.fuse_lora()
# We need to unload the lora weights since in the previous API `fuse_lora` led to lora weights being
# silently deleted - otherwise this will CPU OOM
pipe.unload_lora_weights()
pipe.enable_model_cpu_offload()
generator = torch.Generator().manual_seed(0)
start_time = time.time()
generator = torch.Generator().manual_seed(0)
for _ in range(3):
pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2