Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b08fc2d9d5 | |||
| a031abdc89 |
@@ -198,8 +198,6 @@
|
||||
title: Outputs
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- local: api/loaders/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: api/loaders/lora
|
||||
title: LoRA
|
||||
- local: api/loaders/single_file
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# IP-Adapter
|
||||
|
||||
[IP-Adapter](https://hf.co/papers/2308.06721) is a lightweight adapter that enables prompting a diffusion model with an image. This method decouples the cross-attention layers of the image and text features. The image features are generated from an image encoder. Files generated from IP-Adapter are only ~100MBs.
|
||||
|
||||
<Tip>
|
||||
|
||||
Learn how to load an IP-Adapter checkpoint and image in the [IP-Adapter](../../using-diffusers/loading_adapters#ip-adapter) loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
@@ -179,7 +179,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--dataloader_num_workers=8 \
|
||||
--resolution=512 \
|
||||
--resolution=512
|
||||
--center_crop \
|
||||
--random_flip \
|
||||
--train_batch_size=1 \
|
||||
@@ -214,4 +214,4 @@ image = pipeline("A pokemon with blue eyes").images[0]
|
||||
Congratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:
|
||||
|
||||
- Learn how to [load different LoRA formats](../using-diffusers/loading_adapters#LoRA) trained using community trainers like Kohya and TheLastBen.
|
||||
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
|
||||
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
|
||||
@@ -112,7 +112,7 @@ def save_model_card(
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n"
|
||||
img_str = "widget:\n" if images else ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
@@ -121,10 +121,6 @@ def save_model_card(
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
if not images:
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
@@ -161,6 +157,8 @@ tags:
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
widget:
|
||||
- text: '{validation_prompt if validation_prompt else instance_prompt}'
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -2012,42 +2010,43 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
|
||||
@@ -156,7 +156,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -359,43 +359,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -859,35 +835,34 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
ddim_timesteps=args.num_ddim_timesteps,
|
||||
)
|
||||
|
||||
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
|
||||
# 2. Load tokenizers from SD-XL checkpoint.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
||||
)
|
||||
|
||||
# 3. Load text encoders from SD 1.X/2.X checkpoint.
|
||||
# 3. Load text encoders from SD-1.5 checkpoint.
|
||||
# import correct text encoder classes
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# 4. Load VAE from SD 1.X/2.X checkpoint
|
||||
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_teacher_model,
|
||||
subfolder="vae",
|
||||
revision=args.teacher_revision,
|
||||
)
|
||||
|
||||
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
|
||||
# 5. Load teacher U-Net from SD-XL checkpoint
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -897,7 +872,7 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net.
|
||||
# 7. Create online (`unet`) student U-Nets.
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -960,7 +935,6 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1037,14 +1011,13 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# 13. Dataset creation and data processing
|
||||
# Here, we compute not just the text embeddings but also the additional embeddings
|
||||
# needed for the SD XL UNet to operate.
|
||||
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
|
||||
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
dataset = SDText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1064,7 +1037,6 @@ def main(args):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 14. LR Scheduler creation
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1079,7 +1051,6 @@ def main(args):
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
||||
|
||||
@@ -1101,7 +1072,7 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
# 16. Train!
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
@@ -1152,7 +1123,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image and text conditioning
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1170,37 +1140,37 @@ def main(args):
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
||||
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1209,7 +1179,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1220,27 +1190,17 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1249,21 +1209,13 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1272,17 +1224,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
@@ -1291,7 +1238,7 @@ def main(args):
|
||||
timestep_cond=None,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1301,7 +1248,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1309,7 +1256,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
|
||||
@@ -162,7 +162,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDXLText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -346,43 +346,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -854,10 +830,9 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
@@ -911,7 +886,7 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net.
|
||||
# 7. Create online (`unet`) student U-Nets.
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -975,7 +950,6 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1083,7 +1057,7 @@ def main(args):
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
||||
|
||||
dataset = SDXLText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1201,7 +1175,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
|
||||
image, text, orig_size, crop_coords = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1223,37 +1196,37 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
||||
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1262,7 +1235,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1273,28 +1246,18 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1303,7 +1266,7 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_added_conditions = copy.deepcopy(encoded_text)
|
||||
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
|
||||
uncond_teacher_output = teacher_unet(
|
||||
@@ -1312,15 +1275,7 @@ def main(args):
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1329,17 +1284,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
@@ -1349,7 +1299,7 @@ def main(args):
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1359,7 +1309,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1367,7 +1317,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
|
||||
@@ -138,7 +138,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -336,43 +336,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -847,35 +823,34 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
ddim_timesteps=args.num_ddim_timesteps,
|
||||
)
|
||||
|
||||
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
|
||||
# 2. Load tokenizers from SD-XL checkpoint.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
||||
)
|
||||
|
||||
# 3. Load text encoders from SD 1.X/2.X checkpoint.
|
||||
# 3. Load text encoders from SD-1.5 checkpoint.
|
||||
# import correct text encoder classes
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# 4. Load VAE from SD 1.X/2.X checkpoint
|
||||
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_teacher_model,
|
||||
subfolder="vae",
|
||||
revision=args.teacher_revision,
|
||||
)
|
||||
|
||||
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
|
||||
# 5. Load teacher U-Net from SD-XL checkpoint
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -885,7 +860,7 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
@@ -894,8 +869,8 @@ def main(args):
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
@@ -912,7 +887,7 @@ def main(args):
|
||||
f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# 9. Handle mixed precision and device placement
|
||||
# 10. Handle mixed precision and device placement
|
||||
# For mixed precision training we cast all non-trainable weigths to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
@@ -939,7 +914,7 @@ def main(args):
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
# 11. Handle saving and loading of checkpoints
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -973,7 +948,7 @@ def main(args):
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# 11. Enable optimizations
|
||||
# 12. Enable optimizations
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
@@ -1019,14 +994,13 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# 13. Dataset creation and data processing
|
||||
# Here, we compute not just the text embeddings but also the additional embeddings
|
||||
# needed for the SD XL UNet to operate.
|
||||
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
|
||||
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
dataset = SDText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1046,7 +1020,6 @@ def main(args):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 14. LR Scheduler creation
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1061,7 +1034,6 @@ def main(args):
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
||||
|
||||
@@ -1083,7 +1055,7 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
# 16. Train!
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
@@ -1134,7 +1106,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image and text conditioning
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1152,28 +1123,29 @@ def main(args):
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
@@ -1181,10 +1153,10 @@ def main(args):
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1193,7 +1165,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1204,27 +1176,17 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1233,21 +1195,13 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1256,16 +1210,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
@@ -1274,7 +1224,7 @@ def main(args):
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1284,7 +1234,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1292,7 +1242,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1302,7 +1252,7 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
# 12. Make EMA update to target student model parameters (`target_unet`)
|
||||
# 20.4.15. Make EMA update to target student model parameters
|
||||
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
@@ -144,7 +144,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDXLText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -324,43 +324,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -887,10 +863,9 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
@@ -944,7 +919,7 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
@@ -953,8 +928,8 @@ def main(args):
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
@@ -996,7 +971,6 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1110,7 +1084,7 @@ def main(args):
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
||||
|
||||
dataset = SDXLText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1228,7 +1202,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
|
||||
image, text, orig_size, crop_coords = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1250,39 +1223,38 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
# Move to U-Net device and dtype
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1291,7 +1263,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1302,28 +1274,18 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1332,7 +1294,7 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_added_conditions = copy.deepcopy(encoded_text)
|
||||
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
|
||||
uncond_teacher_output = teacher_unet(
|
||||
@@ -1341,15 +1303,7 @@ def main(args):
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1358,16 +1312,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
@@ -1377,7 +1327,7 @@ def main(args):
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1387,7 +1337,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1395,7 +1345,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1405,7 +1355,7 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
# 12. Make EMA update to target student model parameters (`target_unet`)
|
||||
# 20.4.15. Make EMA update to target student model parameters
|
||||
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
@@ -880,16 +880,11 @@ def main(args):
|
||||
unet_lora_layers_to_save = None
|
||||
text_encoder_lora_layers_to_save = None
|
||||
|
||||
unet_lora_config = None
|
||||
text_encoder_lora_config = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_lora_config = model.peft_config["default"]
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -900,8 +895,6 @@ def main(args):
|
||||
output_dir,
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -918,12 +911,10 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
|
||||
)
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_, config=metadata
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1324,22 +1315,17 @@ def main(args):
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
|
||||
text_encoder_lora_config = text_encoder.peft_config["default"]
|
||||
else:
|
||||
text_encoder_state_dict = None
|
||||
text_encoder_lora_config = None
|
||||
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
|
||||
@@ -1033,20 +1033,13 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
unet_lora_config = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_one_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_two_lora_config = model.peft_config["default"]
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1058,9 +1051,6 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -1080,19 +1070,17 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
|
||||
)
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1628,29 +1616,21 @@ def main(args):
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
|
||||
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
|
||||
@@ -833,12 +833,10 @@ def main():
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
@@ -900,12 +898,10 @@ def main():
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
safe_serialization=True,
|
||||
unet_lora_config=unet_lora_config,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
|
||||
@@ -682,20 +682,13 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
unet_lora_config = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_one_lora_config = model.peft_config["default"]
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_two_lora_config = model.peft_config["default"]
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -707,9 +700,6 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -729,19 +719,17 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas, metadata = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(
|
||||
lora_state_dict, network_alphas=network_alphas, unet=unet_, config=metadata
|
||||
)
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, config=metadata
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, config=metadata
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
@@ -1206,7 +1194,6 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unet_lora_config = unet.peft_config["default"]
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
@@ -1214,23 +1201,15 @@ def main(args):
|
||||
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
|
||||
|
||||
text_encoder_one_lora_config = text_encoder_one.peft_config["default"]
|
||||
text_encoder_two_lora_config = text_encoder_two.peft_config["default"]
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
text_encoder_one_lora_config = None
|
||||
text_encoder_two_lora_config = None
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_one_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_two_lora_config,
|
||||
)
|
||||
|
||||
del unet
|
||||
|
||||
+27
-144
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
@@ -19,7 +18,6 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
@@ -104,7 +102,7 @@ class LoraLoaderMixin:
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
@@ -115,7 +113,6 @@ class LoraLoaderMixin:
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
config=metadata,
|
||||
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
@@ -127,7 +124,6 @@ class LoraLoaderMixin:
|
||||
text_encoder=getattr(self, self.text_encoder_name)
|
||||
if not hasattr(self, "text_encoder")
|
||||
else self.text_encoder,
|
||||
config=metadata,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
@@ -222,7 +218,6 @@ class LoraLoaderMixin:
|
||||
}
|
||||
|
||||
model_file = None
|
||||
metadata = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
@@ -234,9 +229,7 @@ class LoraLoaderMixin:
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
@@ -252,8 +245,6 @@ class LoraLoaderMixin:
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
@@ -264,7 +255,7 @@ class LoraLoaderMixin:
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin"
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
@@ -300,15 +291,10 @@ class LoraLoaderMixin:
|
||||
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
|
||||
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict, network_alphas, metadata
|
||||
return state_dict, network_alphas
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(
|
||||
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
@@ -376,7 +362,7 @@ class LoraLoaderMixin:
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, config=None, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
@@ -390,8 +376,6 @@ class LoraLoaderMixin:
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
config: (`Dict`):
|
||||
LoRA configuration parsed from the state dict.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
@@ -451,9 +435,7 @@ class LoraLoaderMixin:
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if config is not None and isinstance(config, dict) and len(config) > 0:
|
||||
config = json.loads(config["unet"])
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, config=config, is_unet=True)
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -494,7 +476,6 @@ class LoraLoaderMixin:
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
config=None,
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
adapter_name=None,
|
||||
@@ -513,8 +494,6 @@ class LoraLoaderMixin:
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
Expected prefix of the `text_encoder` in the `state_dict`.
|
||||
config (`Dict`):
|
||||
LoRA configuration parsed from state dict.
|
||||
lora_scale (`float`):
|
||||
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
||||
lora layer.
|
||||
@@ -588,11 +567,10 @@ class LoraLoaderMixin:
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft import LoraConfig
|
||||
|
||||
if config is not None and len(config) > 0:
|
||||
config = json.loads(config[prefix])
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alphas, text_encoder_lora_state_dict, config=config, is_unet=False
|
||||
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
|
||||
)
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -800,8 +778,6 @@ class LoraLoaderMixin:
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
unet_lora_config=None,
|
||||
text_encoder_lora_config=None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -829,54 +805,21 @@ class LoraLoaderMixin:
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND and not safe_serialization:
|
||||
if unet_lora_config or text_encoder_lora_config:
|
||||
raise ValueError(
|
||||
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` is not possible. Please install `peft`."
|
||||
)
|
||||
elif USE_PEFT_BACKEND and safe_serialization:
|
||||
from peft import LoraConfig
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
state_dict = {}
|
||||
metadata = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
|
||||
return layers_state_dict
|
||||
|
||||
def pack_metadata(config, prefix):
|
||||
local_metadata = {}
|
||||
if config is not None:
|
||||
if isinstance(config, LoraConfig):
|
||||
config = config.to_dict()
|
||||
for key, value in config.items():
|
||||
if isinstance(value, set):
|
||||
config[key] = list(value)
|
||||
|
||||
config_as_string = json.dumps(config, indent=2, sort_keys=True)
|
||||
local_metadata[prefix] = config_as_string
|
||||
return local_metadata
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
|
||||
|
||||
if unet_lora_layers:
|
||||
prefix = "unet"
|
||||
unet_state_dict = pack_weights(unet_lora_layers, prefix)
|
||||
state_dict.update(unet_state_dict)
|
||||
if unet_lora_config is not None:
|
||||
unet_metadata = pack_metadata(unet_lora_config, prefix)
|
||||
metadata.update(unet_metadata)
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
prefix = "text_encoder"
|
||||
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
|
||||
state_dict.update(text_encoder_state_dict)
|
||||
if text_encoder_lora_config is not None:
|
||||
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
|
||||
metadata.update(text_encoder_metadata)
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
@@ -886,7 +829,6 @@ class LoraLoaderMixin:
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -897,11 +839,7 @@ class LoraLoaderMixin:
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
metadata=None,
|
||||
):
|
||||
if not safe_serialization and isinstance(metadata, dict) and len(metadata) > 0:
|
||||
raise ValueError("Passing `metadata` is not possible when `safe_serialization` is False.")
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@@ -909,10 +847,8 @@ class LoraLoaderMixin:
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename, metadata):
|
||||
if metadata is None:
|
||||
metadata = {"format": "pt"}
|
||||
return safetensors.torch.save_file(weights, filename, metadata=metadata)
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
@@ -925,10 +861,7 @@ class LoraLoaderMixin:
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
if save_function != torch.save:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name), metadata)
|
||||
else:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
def unload_lora_weights(self):
|
||||
@@ -1360,7 +1293,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
# pipeline.
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas, metadata = self.lora_state_dict(
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
@@ -1370,12 +1303,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
config=metadata,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
@@ -1383,7 +1311,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
config=metadata,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
@@ -1396,7 +1323,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
text_encoder_2_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder_2,
|
||||
config=metadata,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
@@ -1410,9 +1336,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
unet_lora_config=None,
|
||||
text_encoder_lora_config=None,
|
||||
text_encoder_2_lora_config=None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -1440,63 +1363,24 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND and not safe_serialization:
|
||||
if unet_lora_config or text_encoder_lora_config or text_encoder_2_lora_config:
|
||||
raise ValueError(
|
||||
"Without `peft`, passing `unet_lora_config` or `text_encoder_lora_config` or `text_encoder_2_lora_config` is not possible. Please install `peft`."
|
||||
)
|
||||
elif USE_PEFT_BACKEND and safe_serialization:
|
||||
from peft import LoraConfig
|
||||
state_dict = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
||||
)
|
||||
|
||||
state_dict = {}
|
||||
metadata = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
|
||||
return layers_state_dict
|
||||
|
||||
def pack_metadata(config, prefix):
|
||||
local_metadata = {}
|
||||
if config is not None:
|
||||
if isinstance(config, LoraConfig):
|
||||
config = config.to_dict()
|
||||
for key, value in config.items():
|
||||
if isinstance(value, set):
|
||||
config[key] = list(value)
|
||||
|
||||
config_as_string = json.dumps(config, indent=2, sort_keys=True)
|
||||
local_metadata[prefix] = config_as_string
|
||||
return local_metadata
|
||||
|
||||
if unet_lora_layers:
|
||||
prefix = "unet"
|
||||
unet_state_dict = pack_weights(unet_lora_layers, prefix)
|
||||
state_dict.update(unet_state_dict)
|
||||
if unet_lora_config is not None:
|
||||
unet_metadata = pack_metadata(unet_lora_config, prefix)
|
||||
metadata.update(unet_metadata)
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
|
||||
if text_encoder_lora_layers and text_encoder_2_lora_layers:
|
||||
prefix = "text_encoder"
|
||||
text_encoder_state_dict = pack_weights(text_encoder_lora_layers, "text_encoder")
|
||||
state_dict.update(text_encoder_state_dict)
|
||||
if text_encoder_lora_config is not None:
|
||||
text_encoder_metadata = pack_metadata(text_encoder_lora_config, prefix)
|
||||
metadata.update(text_encoder_metadata)
|
||||
|
||||
prefix = "text_encoder_2"
|
||||
text_encoder_2_state_dict = pack_weights(text_encoder_2_lora_layers, prefix)
|
||||
state_dict.update(text_encoder_2_state_dict)
|
||||
if text_encoder_2_lora_config is not None:
|
||||
text_encoder_2_metadata = pack_metadata(text_encoder_2_lora_config, prefix)
|
||||
metadata.update(text_encoder_2_metadata)
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
@@ -1505,7 +1389,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
|
||||
@@ -84,12 +84,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -147,9 +147,6 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
|
||||
@@ -140,11 +140,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -251,9 +251,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -148,10 +148,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -129,10 +129,8 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -155,10 +155,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -98,9 +98,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -102,9 +102,8 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -143,11 +143,6 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
-3
@@ -177,9 +177,6 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -232,7 +232,6 @@ class StableDiffusionInpaintPipeline(
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
|
||||
|
||||
@@ -54,11 +54,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental pipeline and is likely to change in the future.
|
||||
|
||||
@@ -67,9 +67,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixi
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -43,11 +43,6 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -282,7 +282,7 @@ class Pix2PixZeroAttnProcessor:
|
||||
|
||||
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
|
||||
Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
@@ -76,12 +76,6 @@ class StableDiffusionUpscalePipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -65,11 +65,6 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
prior_tokenizer ([`CLIPTokenizer`]):
|
||||
A [`CLIPTokenizer`].
|
||||
|
||||
@@ -76,11 +76,6 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Feature extractor for image pre-processing before being encoded.
|
||||
|
||||
@@ -595,11 +595,10 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin):
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipelineSafe
|
||||
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
|
||||
|
||||
pipeline = StableDiffusionPipelineSafe.from_pretrained(
|
||||
"AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
)
|
||||
prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
|
||||
image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0]
|
||||
```
|
||||
|
||||
@@ -159,12 +159,12 @@ class StableDiffusionXLPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -176,12 +176,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -321,12 +321,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
+5
-5
@@ -126,11 +126,11 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -178,12 +178,6 @@ class StableDiffusionXLAdapterPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
|
||||
Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a
|
||||
|
||||
@@ -83,11 +83,6 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
-5
@@ -159,11 +159,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
|
||||
@@ -69,10 +69,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
prior ([`Prior`]):
|
||||
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
||||
|
||||
@@ -98,7 +98,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -231,7 +230,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
def _convert_to_karras(self, ramp):
|
||||
|
||||
@@ -187,7 +187,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -255,7 +254,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -214,7 +214,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -291,7 +290,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -209,7 +209,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
@property
|
||||
@@ -290,7 +289,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -198,7 +198,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.noise_sampler = None
|
||||
self.noise_sampler_seed = noise_sampler_seed
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -348,7 +347,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.mid_point_sigma = None
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.noise_sampler = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
|
||||
@@ -197,7 +197,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -289,7 +288,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -166,7 +166,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -250,7 +249,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
|
||||
@@ -237,7 +237,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -342,7 +341,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
|
||||
@@ -148,7 +148,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -270,7 +269,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
|
||||
@@ -140,7 +140,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# set all values
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -296,7 +295,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
|
||||
@@ -140,7 +140,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -285,7 +284,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
|
||||
@@ -168,7 +168,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -280,7 +279,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
self.derivatives = []
|
||||
|
||||
|
||||
@@ -198,7 +198,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -269,7 +268,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -138,17 +138,11 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
|
||||
module.set_scale(adapter_name, 1.0)
|
||||
|
||||
|
||||
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None, is_unet=True):
|
||||
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
|
||||
rank_pattern = {}
|
||||
alpha_pattern = {}
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
|
||||
# Try to retrive config.
|
||||
alpha_retrieved = False
|
||||
if config is not None:
|
||||
lora_alpha = config["lora_alpha"]
|
||||
alpha_retrieved = True
|
||||
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occuring the most number of times
|
||||
r = collections.Counter(rank_dict.values()).most_common()[0][0]
|
||||
@@ -160,8 +154,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None,
|
||||
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
|
||||
if len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occuring the most number of times
|
||||
if not alpha_retrieved:
|
||||
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
|
||||
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
|
||||
|
||||
# for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
@@ -172,7 +165,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, config=None,
|
||||
}
|
||||
else:
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
elif not alpha_retrieved:
|
||||
else:
|
||||
lora_alpha = set(network_alpha_dict.values()).pop()
|
||||
|
||||
# layer names without the Diffusers specific
|
||||
|
||||
@@ -820,9 +820,7 @@ def _is_torch_fp16_available(device):
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float16).to(device)
|
||||
_ = torch.mul(x, x)
|
||||
return True
|
||||
|
||||
_ = x @ x
|
||||
except Exception as e:
|
||||
if device.type == "cuda":
|
||||
raise ValueError(
|
||||
@@ -840,9 +838,7 @@ def _is_torch_fp64_available(device):
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
|
||||
_ = torch.mul(x, x)
|
||||
return True
|
||||
|
||||
_ = x @ x
|
||||
except Exception as e:
|
||||
if device.type == "cuda":
|
||||
raise ValueError(
|
||||
|
||||
@@ -343,21 +343,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
image = sd_pipe(**inputs).images
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda")
|
||||
def test_stable_diffusion_set_xformers_attn_processors(self):
|
||||
# disable_full_determinism()
|
||||
device = "cuda" # ensure determinism for the device-dependent torch.Generator
|
||||
components, _ = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs()
|
||||
|
||||
# run normal sd pipe
|
||||
image = sd_pipe(**inputs).images
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
# run lora xformers attention
|
||||
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
|
||||
attn_processors = {
|
||||
@@ -622,7 +607,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
orig_image_slice, orig_image_slice_two, atol=1e-3
|
||||
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_lora_unet_attn_processors_with_xformers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.create_lora_weight_file(tmpdirname)
|
||||
@@ -659,7 +644,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
if isinstance(module, Attention):
|
||||
self.assertIsInstance(module.processor, XFormersAttnProcessor)
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_lora_save_load_with_xformers(self):
|
||||
pipeline_components, lora_components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**pipeline_components)
|
||||
@@ -2285,8 +2270,8 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
start_time = time.time()
|
||||
@@ -2299,13 +2284,13 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
|
||||
del pipe
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.fuse_lora()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
start_time = time.time()
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
for _ in range(3):
|
||||
pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
|
||||
@@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
load_image,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
@@ -107,9 +106,8 @@ class PeftLoraLoaderMixinTests:
|
||||
unet_kwargs = None
|
||||
vae_kwargs = None
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None, lora_alpha=None):
|
||||
def get_dummy_components(self, scheduler_cls=None):
|
||||
scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
|
||||
lora_alpha = 4 if lora_alpha is None else lora_alpha
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(**self.unet_kwargs)
|
||||
@@ -124,14 +122,11 @@ class PeftLoraLoaderMixinTests:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
|
||||
|
||||
text_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
init_lora_weights=False,
|
||||
r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=4, lora_alpha=lora_alpha, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
|
||||
)
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
@@ -718,68 +713,6 @@ class PeftLoraLoaderMixinTests:
|
||||
"Fused lora should change the output",
|
||||
)
|
||||
|
||||
def test_if_lora_alpha_is_correctly_parsed(self):
|
||||
lora_alpha = 8
|
||||
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.unet.add_adapter(unet_lora_config)
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
if self.has_two_text_encoders:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
|
||||
# Inference works?
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
unet_state_dict = get_peft_model_state_dict(pipe.unet)
|
||||
unet_lora_config = pipe.unet.peft_config["default"]
|
||||
|
||||
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
|
||||
text_encoder_lora_config = pipe.text_encoder.peft_config["default"]
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
|
||||
text_encoder_2_lora_config = pipe.text_encoder_2.peft_config["default"]
|
||||
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
unet_lora_layers=unet_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
text_encoder_2_lora_layers=text_encoder_2_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
text_encoder_2_lora_config=text_encoder_2_lora_config,
|
||||
)
|
||||
else:
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
unet_lora_layers=unet_state_dict,
|
||||
text_encoder_lora_layers=text_encoder_state_dict,
|
||||
unet_lora_config=unet_lora_config,
|
||||
text_encoder_lora_config=text_encoder_lora_config,
|
||||
)
|
||||
loaded_pipe = self.pipeline_class(**components)
|
||||
loaded_pipe.load_lora_weights(tmpdirname)
|
||||
|
||||
# Inference works?
|
||||
_ = loaded_pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
assert (
|
||||
loaded_pipe.unet.peft_config["default"].lora_alpha == lora_alpha
|
||||
), "LoRA alpha not correctly loaded for UNet."
|
||||
assert (
|
||||
loaded_pipe.text_encoder.peft_config["default"].lora_alpha == lora_alpha
|
||||
), "LoRA alpha not correctly loaded for text encoder."
|
||||
if self.has_two_text_encoders:
|
||||
assert (
|
||||
loaded_pipe.text_encoder_2.peft_config["default"].lora_alpha == lora_alpha
|
||||
), "LoRA alpha not correctly loaded for text encoder 2."
|
||||
|
||||
def test_simple_inference_with_text_unet_lora_unfused(self):
|
||||
"""
|
||||
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
|
||||
@@ -1780,7 +1713,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_1_0_lora(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -1803,7 +1736,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
lora_model_id = "latent-consistency/lcm-lora-sdxl"
|
||||
|
||||
@@ -1820,8 +1753,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
image_np = pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
|
||||
assert max_diff < 1e-4
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
@@ -1832,7 +1764,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
pipe.to("cuda")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
pipe.load_lora_weights(lora_model_id)
|
||||
@@ -1848,8 +1780,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
image_np = pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
|
||||
assert max_diff < 1e-4
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
@@ -1864,7 +1795,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.png"
|
||||
)
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
pipe.load_lora_weights(lora_model_id)
|
||||
@@ -1885,8 +1816,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
image_np = pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
|
||||
assert max_diff < 1e-4
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
@@ -1919,7 +1849,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_1_0_lora_unfusion(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
@@ -1930,16 +1860,16 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=3
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
images_with_fusion = images.flatten()
|
||||
images_with_fusion = images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
pipe.unfuse_lora()
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=3
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
images_without_fusion = images.flatten()
|
||||
images_without_fusion = images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
|
||||
release_memory(pipe)
|
||||
@@ -1983,8 +1913,10 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
start_time = time.time()
|
||||
@@ -1997,17 +1929,19 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
|
||||
del pipe
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
|
||||
pipe.fuse_lora()
|
||||
|
||||
# We need to unload the lora weights since in the previous API `fuse_lora` led to lora weights being
|
||||
# silently deleted - otherwise this will CPU OOM
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
start_time = time.time()
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
for _ in range(3):
|
||||
pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
|
||||
Reference in New Issue
Block a user