Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b08fc2d9d5 | |||
| a031abdc89 |
@@ -162,25 +162,6 @@ class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
def benchmark(self, args):
|
||||
flush()
|
||||
|
||||
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
|
||||
|
||||
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
|
||||
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
|
||||
benchmark_info = BenchmarkInfo(time=time, memory=memory)
|
||||
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
flush()
|
||||
csv_dict = generate_csv_dict(
|
||||
pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
|
||||
)
|
||||
filepath = self.get_result_filepath(args)
|
||||
write_to_csv(filepath, csv_dict)
|
||||
print(f"Logs written to: {filepath}")
|
||||
flush()
|
||||
|
||||
|
||||
class ImageToImageBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = AutoPipelineForImage2Image
|
||||
|
||||
@@ -198,8 +198,6 @@
|
||||
title: Outputs
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- local: api/loaders/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: api/loaders/lora
|
||||
title: LoRA
|
||||
- local: api/loaders/single_file
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# IP-Adapter
|
||||
|
||||
[IP-Adapter](https://hf.co/papers/2308.06721) is a lightweight adapter that enables prompting a diffusion model with an image. This method decouples the cross-attention layers of the image and text features. The image features are generated from an image encoder. Files generated from IP-Adapter are only ~100MBs.
|
||||
|
||||
<Tip>
|
||||
|
||||
Learn how to load an IP-Adapter checkpoint and image in the [IP-Adapter](../../using-diffusers/loading_adapters#ip-adapter) loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
@@ -49,12 +49,12 @@ make_image_grid([original_image, mask_image, image], rows=1, cols=3)
|
||||
|
||||
## AsymmetricAutoencoderKL
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL
|
||||
[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
[[autodoc]] models.vae.DecoderOutput
|
||||
|
||||
@@ -54,4 +54,4 @@ image
|
||||
|
||||
## AutoencoderTinyOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput
|
||||
[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
|
||||
|
||||
@@ -36,11 +36,11 @@ model = AutoencoderKL.from_single_file(url)
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
[[autodoc]] models.vae.DecoderOutput
|
||||
|
||||
## FlaxAutoencoderKL
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--dataloader_num_workers=8 \
|
||||
--resolution=512 \
|
||||
--resolution=512
|
||||
--center_crop \
|
||||
--random_flip \
|
||||
--train_batch_size=1 \
|
||||
@@ -214,4 +214,4 @@ image = pipeline("A pokemon with blue eyes").images[0]
|
||||
Congratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:
|
||||
|
||||
- Learn how to [load different LoRA formats](../using-diffusers/loading_adapters#LoRA) trained using community trainers like Kohya and TheLastBen.
|
||||
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
|
||||
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
|
||||
@@ -186,7 +186,7 @@ accelerate launch train_unconditional.py \
|
||||
If you're training with more than one GPU, add the `--multi_gpu` parameter to the training command:
|
||||
|
||||
```bash
|
||||
accelerate launch --multi_gpu train_unconditional.py \
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
|
||||
--dataset_name="huggan/flowers-102-categories" \
|
||||
--output_dir="ddpm-ema-flowers-64" \
|
||||
--mixed_precision="fp16" \
|
||||
|
||||
@@ -112,7 +112,7 @@ def save_model_card(
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n"
|
||||
img_str = "widget:\n" if images else ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
@@ -121,10 +121,6 @@ def save_model_card(
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
if not images:
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
@@ -161,6 +157,8 @@ tags:
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
widget:
|
||||
- text: '{validation_prompt if validation_prompt else instance_prompt}'
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -2012,42 +2010,43 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
|
||||
@@ -156,7 +156,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -359,43 +359,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -859,35 +835,34 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
ddim_timesteps=args.num_ddim_timesteps,
|
||||
)
|
||||
|
||||
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
|
||||
# 2. Load tokenizers from SD-XL checkpoint.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
||||
)
|
||||
|
||||
# 3. Load text encoders from SD 1.X/2.X checkpoint.
|
||||
# 3. Load text encoders from SD-1.5 checkpoint.
|
||||
# import correct text encoder classes
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# 4. Load VAE from SD 1.X/2.X checkpoint
|
||||
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_teacher_model,
|
||||
subfolder="vae",
|
||||
revision=args.teacher_revision,
|
||||
)
|
||||
|
||||
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
|
||||
# 5. Load teacher U-Net from SD-XL checkpoint
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -897,7 +872,7 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net.
|
||||
# 7. Create online (`unet`) student U-Nets.
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -960,7 +935,6 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1037,14 +1011,13 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# 13. Dataset creation and data processing
|
||||
# Here, we compute not just the text embeddings but also the additional embeddings
|
||||
# needed for the SD XL UNet to operate.
|
||||
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
|
||||
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
dataset = SDText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1064,7 +1037,6 @@ def main(args):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 14. LR Scheduler creation
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1079,7 +1051,6 @@ def main(args):
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
||||
|
||||
@@ -1101,7 +1072,7 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
# 16. Train!
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
@@ -1152,7 +1123,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image and text conditioning
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1170,37 +1140,37 @@ def main(args):
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
||||
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1209,7 +1179,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1220,27 +1190,17 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1249,21 +1209,13 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1272,17 +1224,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
@@ -1291,7 +1238,7 @@ def main(args):
|
||||
timestep_cond=None,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1301,7 +1248,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1309,7 +1256,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
|
||||
@@ -162,7 +162,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDXLText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -346,43 +346,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -854,10 +830,9 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
@@ -911,7 +886,7 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net.
|
||||
# 7. Create online (`unet`) student U-Nets.
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -975,7 +950,6 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1083,7 +1057,7 @@ def main(args):
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
||||
|
||||
dataset = SDXLText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1201,7 +1175,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
|
||||
image, text, orig_size, crop_coords = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1223,37 +1196,37 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
||||
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1262,7 +1235,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1273,28 +1246,18 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1303,7 +1266,7 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_added_conditions = copy.deepcopy(encoded_text)
|
||||
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
|
||||
uncond_teacher_output = teacher_unet(
|
||||
@@ -1312,15 +1275,7 @@ def main(args):
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1329,17 +1284,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
@@ -1349,7 +1299,7 @@ def main(args):
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1359,7 +1309,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1367,7 +1317,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
|
||||
@@ -138,7 +138,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -336,43 +336,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -847,35 +823,34 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
ddim_timesteps=args.num_ddim_timesteps,
|
||||
)
|
||||
|
||||
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
|
||||
# 2. Load tokenizers from SD-XL checkpoint.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
||||
)
|
||||
|
||||
# 3. Load text encoders from SD 1.X/2.X checkpoint.
|
||||
# 3. Load text encoders from SD-1.5 checkpoint.
|
||||
# import correct text encoder classes
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# 4. Load VAE from SD 1.X/2.X checkpoint
|
||||
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_teacher_model,
|
||||
subfolder="vae",
|
||||
revision=args.teacher_revision,
|
||||
)
|
||||
|
||||
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
|
||||
# 5. Load teacher U-Net from SD-XL checkpoint
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -885,7 +860,7 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
@@ -894,8 +869,8 @@ def main(args):
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
@@ -912,7 +887,7 @@ def main(args):
|
||||
f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# 9. Handle mixed precision and device placement
|
||||
# 10. Handle mixed precision and device placement
|
||||
# For mixed precision training we cast all non-trainable weigths to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
@@ -939,7 +914,7 @@ def main(args):
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
# 11. Handle saving and loading of checkpoints
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -973,7 +948,7 @@ def main(args):
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# 11. Enable optimizations
|
||||
# 12. Enable optimizations
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
@@ -1019,14 +994,13 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# 13. Dataset creation and data processing
|
||||
# Here, we compute not just the text embeddings but also the additional embeddings
|
||||
# needed for the SD XL UNet to operate.
|
||||
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
|
||||
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
dataset = SDText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1046,7 +1020,6 @@ def main(args):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 14. LR Scheduler creation
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1061,7 +1034,6 @@ def main(args):
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
||||
|
||||
@@ -1083,7 +1055,7 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
# 16. Train!
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
@@ -1134,7 +1106,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image and text conditioning
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1152,28 +1123,29 @@ def main(args):
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
@@ -1181,10 +1153,10 @@ def main(args):
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1193,7 +1165,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1204,27 +1176,17 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1233,21 +1195,13 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1256,16 +1210,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
@@ -1274,7 +1224,7 @@ def main(args):
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1284,7 +1234,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1292,7 +1242,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1302,7 +1252,7 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
# 12. Make EMA update to target student model parameters (`target_unet`)
|
||||
# 20.4.15. Make EMA update to target student model parameters
|
||||
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
@@ -144,7 +144,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class SDXLText2ImageDataset:
|
||||
class Text2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -324,43 +324,19 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -887,10 +863,9 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
@@ -944,7 +919,7 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
@@ -953,8 +928,8 @@ def main(args):
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
@@ -996,7 +971,6 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1110,7 +1084,7 @@ def main(args):
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
||||
|
||||
dataset = SDXLText2ImageDataset(
|
||||
dataset = Text2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1228,7 +1202,6 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
|
||||
image, text, orig_size, crop_coords = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1250,39 +1223,38 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
# Move to U-Net device and dtype
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1291,7 +1263,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1302,28 +1274,18 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
|
||||
).sample
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1332,7 +1294,7 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
uncond_added_conditions = copy.deepcopy(encoded_text)
|
||||
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
|
||||
uncond_teacher_output = teacher_unet(
|
||||
@@ -1341,15 +1303,7 @@ def main(args):
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
|
||||
).sample
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1358,16 +1312,12 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
@@ -1377,7 +1327,7 @@ def main(args):
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
pred_x_0 = predicted_origin(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1387,7 +1337,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 10. Calculate loss
|
||||
# 20.4.13. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1395,7 +1345,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1405,7 +1355,7 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
# 12. Make EMA update to target student model parameters (`target_unet`)
|
||||
# 20.4.15. Make EMA update to target student model parameters
|
||||
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
@@ -64,6 +64,39 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
|
||||
@@ -64,6 +64,39 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
|
||||
@@ -101,8 +101,8 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
|
||||
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
@@ -114,13 +114,12 @@ image.save("yoda-pokemon.png")
|
||||
```
|
||||
|
||||
Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("<initial model>", unet=unet, torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
@@ -54,6 +54,39 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
|
||||
@@ -63,6 +63,39 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
|
||||
@@ -12,9 +12,9 @@ from safetensors.torch import load_file as stl
|
||||
from tqdm import tqdm
|
||||
|
||||
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
|
||||
from diffusers.models.autoencoders.vae import Encoder
|
||||
from diffusers.models.embeddings import TimestepEmbedding
|
||||
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
from diffusers.models.vae import Encoder
|
||||
|
||||
|
||||
args = ArgumentParser()
|
||||
|
||||
@@ -159,14 +159,6 @@ vae_conversion_map_attn = [
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
|
||||
# This is probably not the most ideal solution, but it does work.
|
||||
vae_extra_conversion_map = [
|
||||
("to_q", "q"),
|
||||
("to_k", "k"),
|
||||
("to_v", "v"),
|
||||
("to_out.0", "proj_out"),
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
@@ -186,20 +178,11 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
||||
weights_to_convert = ["q", "k", "v", "proj_out"]
|
||||
keys_to_rename = {}
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
print(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
for weight_name, real_weight_name in vae_extra_conversion_map:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
|
||||
keys_to_rename[k] = k.replace(weight_name, real_weight_name)
|
||||
for k, v in keys_to_rename.items():
|
||||
if k in new_state_dict:
|
||||
print(f"Renaming {k} to {v}")
|
||||
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
|
||||
del new_state_dict[k]
|
||||
return new_state_dict
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
@@ -230,9 +229,7 @@ class LoraLoaderMixin:
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
@@ -258,7 +255,7 @@ class LoraLoaderMixin:
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin"
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
@@ -297,12 +294,7 @@ class LoraLoaderMixin:
|
||||
return state_dict, network_alphas
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(
|
||||
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
|
||||
@@ -169,12 +169,10 @@ class FromSingleFileMixin:
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
||||
vae = kwargs.pop("vae", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
adapter = kwargs.pop("adapter", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
tokenizer_2 = kwargs.pop("tokenizer_2", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
@@ -276,10 +274,8 @@ class FromSingleFileMixin:
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
local_files_only=local_files_only,
|
||||
|
||||
@@ -26,11 +26,11 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
@@ -58,13 +58,11 @@ if is_flax_available():
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
)
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnetxs import ControlNetXSModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
|
||||
+4
-4
@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .modeling_outputs import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
||||
|
||||
|
||||
+6
-6
@@ -16,10 +16,10 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
@@ -27,8 +27,8 @@ from ..attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .modeling_outputs import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
+8
-8
@@ -16,14 +16,14 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils import is_torch_version
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from .modeling_outputs import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
+4
-4
@@ -18,10 +18,10 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
|
||||
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
+14
-14
@@ -18,20 +18,20 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...schedulers import ConsistencyDecoderScheduler
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..attention_processor import (
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..schedulers import ConsistencyDecoderScheduler
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_2d import UNet2DModel
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d import UNet2DModel
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
@@ -162,7 +162,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.use_tiling = use_tiling
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
|
||||
def disable_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
@@ -170,7 +170,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.enable_tiling(False)
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
|
||||
return DecoderOutput(sample=x_0)
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
@@ -26,7 +26,7 @@ from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
AttentionProcessor,
|
||||
)
|
||||
from .autoencoders import AutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .lora import LoRACompatibleConv
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
|
||||
@@ -1334,7 +1334,7 @@ class AlphaBlender(nn.Module):
|
||||
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device, dtype=self.mix_factor.dtype),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
torch.sigmoid(self.mix_factor)[..., None],
|
||||
)
|
||||
|
||||
|
||||
@@ -18,11 +18,11 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import BaseOutput, is_torch_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SpatialNorm
|
||||
from ..unet_2d_blocks import (
|
||||
from ..utils import BaseOutput, is_torch_version
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .unet_2d_blocks import (
|
||||
AutoencoderTinyBlock,
|
||||
UNetMidBlock2D,
|
||||
get_down_block,
|
||||
@@ -20,8 +20,8 @@ import torch.nn as nn
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -23,7 +23,6 @@ from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import FusedAttnProcessor2_0
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -656,65 +655,6 @@ class AltDiffusionPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
@@ -25,7 +25,6 @@ from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import FusedAttnProcessor2_0
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -716,65 +715,6 @@ class AltDiffusionImg2ImgPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
||||
|
||||
@@ -84,12 +84,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -147,9 +147,6 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
|
||||
@@ -140,11 +140,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -251,9 +251,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -148,10 +148,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -129,10 +129,8 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -155,10 +155,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -98,9 +98,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -102,9 +102,8 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -1153,9 +1153,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
vae_path=None,
|
||||
vae=None,
|
||||
text_encoder=None,
|
||||
text_encoder_2=None,
|
||||
tokenizer=None,
|
||||
tokenizer_2=None,
|
||||
config_files=None,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
@@ -1234,9 +1232,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
@@ -1343,11 +1339,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
else:
|
||||
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
|
||||
|
||||
if num_in_channels is None and pipeline_class in [
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
]:
|
||||
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
|
||||
num_in_channels = 9
|
||||
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
|
||||
num_in_channels = 7
|
||||
@@ -1694,9 +1686,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
|
||||
if (is_refiner is False) and (tokenizer is None):
|
||||
if model_type == "SDXL":
|
||||
try:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||
@@ -1705,11 +1695,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
|
||||
if (is_refiner is False) and (text_encoder is None):
|
||||
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
||||
|
||||
if tokenizer_2 is None:
|
||||
try:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||
@@ -1719,69 +1705,95 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
||||
)
|
||||
|
||||
if text_encoder_2 is None:
|
||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
config_kwargs = {"projection_dim": 1280}
|
||||
prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."
|
||||
|
||||
text_encoder_2 = convert_open_clip_checkpoint(
|
||||
checkpoint,
|
||||
config_name,
|
||||
prefix=prefix,
|
||||
prefix="conditioner.embedders.1.model.",
|
||||
has_projection=True,
|
||||
local_files_only=local_files_only,
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
if controlnet:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
elif adapter:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
adapter=adapter,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
if controlnet:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
elif adapter:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
adapter=adapter,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
pipeline_kwargs = {
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
tokenizer = None
|
||||
text_encoder = None
|
||||
try:
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
||||
)
|
||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
config_kwargs = {"projection_dim": 1280}
|
||||
text_encoder_2 = convert_open_clip_checkpoint(
|
||||
checkpoint,
|
||||
config_name,
|
||||
prefix="conditioner.embedders.0.model.",
|
||||
has_projection=True,
|
||||
local_files_only=local_files_only,
|
||||
**config_kwargs,
|
||||
)
|
||||
|
||||
if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
|
||||
pipeline_class == StableDiffusionXLInpaintPipeline
|
||||
):
|
||||
pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
|
||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
if is_refiner:
|
||||
pipeline_kwargs.update({"force_zeros_for_empty_prompt": False})
|
||||
|
||||
pipe = pipeline_class(**pipeline_kwargs)
|
||||
pipe = StableDiffusionXLImg2ImgPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
requires_aesthetics_score=True,
|
||||
force_zeros_for_empty_prompt=False,
|
||||
)
|
||||
else:
|
||||
text_config = create_ldm_bert_config(original_config)
|
||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||
|
||||
@@ -143,11 +143,6 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -23,7 +23,6 @@ from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import FusedAttnProcessor2_0
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -651,67 +650,6 @@ class StableDiffusionPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
|
||||
-3
@@ -177,9 +177,6 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -25,7 +25,6 @@ from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import FusedAttnProcessor2_0
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -719,67 +718,6 @@ class StableDiffusionImg2ImgPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
|
||||
@@ -25,7 +25,6 @@ from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import FusedAttnProcessor2_0
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
@@ -233,7 +232,6 @@ class StableDiffusionInpaintPipeline(
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
|
||||
@@ -845,67 +843,6 @@ class StableDiffusionInpaintPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
|
||||
@@ -54,11 +54,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental pipeline and is likely to change in the future.
|
||||
|
||||
@@ -67,9 +67,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixi
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -43,11 +43,6 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -282,7 +282,7 @@ class Pix2PixZeroAttnProcessor:
|
||||
|
||||
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
|
||||
Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
@@ -76,12 +76,6 @@ class StableDiffusionUpscalePipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -65,11 +65,6 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
prior_tokenizer ([`CLIPTokenizer`]):
|
||||
A [`CLIPTokenizer`].
|
||||
|
||||
@@ -76,11 +76,6 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
feature_extractor ([`CLIPImageProcessor`]):
|
||||
Feature extractor for image pre-processing before being encoded.
|
||||
|
||||
@@ -595,11 +595,10 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin):
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipelineSafe
|
||||
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
|
||||
|
||||
pipeline = StableDiffusionPipelineSafe.from_pretrained(
|
||||
"AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
)
|
||||
prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
|
||||
image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0]
|
||||
```
|
||||
|
||||
@@ -159,12 +159,12 @@ class StableDiffusionXLPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
+6
-68
@@ -35,7 +35,6 @@ from ...loaders import (
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -177,12 +176,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -865,67 +864,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
|
||||
+6
-68
@@ -36,7 +36,6 @@ from ...loaders import (
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
@@ -322,12 +321,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -1085,67 +1084,6 @@ class StableDiffusionXLInpaintPipeline(
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
"""
|
||||
self.fusing_unet = False
|
||||
self.fusing_vae = False
|
||||
|
||||
if unet:
|
||||
self.fusing_unet = True
|
||||
self.unet.fuse_qkv_projections()
|
||||
self.unet.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
if vae:
|
||||
if not isinstance(self.vae, AutoencoderKL):
|
||||
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
|
||||
|
||||
self.fusing_vae = True
|
||||
self.vae.fuse_qkv_projections()
|
||||
self.vae.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
||||
"""Disable QKV projection fusion if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
|
||||
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
||||
|
||||
"""
|
||||
if unet:
|
||||
if not self.fusing_unet:
|
||||
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.unet.unfuse_qkv_projections()
|
||||
self.fusing_unet = False
|
||||
|
||||
if vae:
|
||||
if not self.fusing_vae:
|
||||
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.vae.unfuse_qkv_projections()
|
||||
self.fusing_vae = False
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
"""
|
||||
|
||||
+5
-5
@@ -126,11 +126,11 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -178,12 +178,6 @@ class StableDiffusionXLAdapterPipeline(
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
|
||||
Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a
|
||||
|
||||
@@ -83,11 +83,6 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
-5
@@ -159,11 +159,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
|
||||
@@ -19,8 +19,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.vae import DecoderOutput, VectorQuantizer
|
||||
from ...models.vq_model import VQEncoderOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
|
||||
|
||||
@@ -69,10 +69,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
|
||||
Args:
|
||||
prior ([`Prior`]):
|
||||
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
||||
|
||||
@@ -98,7 +98,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -231,7 +230,6 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
def _convert_to_karras(self, ramp):
|
||||
|
||||
@@ -187,7 +187,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -255,7 +254,6 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -214,7 +214,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -291,7 +290,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -209,7 +209,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
@property
|
||||
@@ -290,7 +289,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -198,7 +198,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.noise_sampler = None
|
||||
self.noise_sampler_seed = noise_sampler_seed
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -348,7 +347,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.mid_point_sigma = None
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.noise_sampler = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
|
||||
@@ -197,7 +197,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -289,7 +288,6 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -166,7 +166,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -250,7 +249,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
|
||||
@@ -237,7 +237,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -342,7 +341,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
|
||||
@@ -148,7 +148,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -270,7 +269,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
|
||||
@@ -140,7 +140,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# set all values
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -296,7 +295,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
|
||||
@@ -140,7 +140,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -285,7 +284,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
|
||||
@@ -168,7 +168,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -280,7 +279,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
self.derivatives = []
|
||||
|
||||
|
||||
@@ -198,7 +198,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.solver_p = solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -269,7 +268,6 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -820,9 +820,7 @@ def _is_torch_fp16_available(device):
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float16).to(device)
|
||||
_ = torch.mul(x, x)
|
||||
return True
|
||||
|
||||
_ = x @ x
|
||||
except Exception as e:
|
||||
if device.type == "cuda":
|
||||
raise ValueError(
|
||||
@@ -840,9 +838,7 @@ def _is_torch_fp64_available(device):
|
||||
|
||||
try:
|
||||
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
|
||||
_ = torch.mul(x, x)
|
||||
return True
|
||||
|
||||
_ = x @ x
|
||||
except Exception as e:
|
||||
if device.type == "cuda":
|
||||
raise ValueError(
|
||||
|
||||
@@ -343,21 +343,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
image = sd_pipe(**inputs).images
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda")
|
||||
def test_stable_diffusion_set_xformers_attn_processors(self):
|
||||
# disable_full_determinism()
|
||||
device = "cuda" # ensure determinism for the device-dependent torch.Generator
|
||||
components, _ = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs()
|
||||
|
||||
# run normal sd pipe
|
||||
image = sd_pipe(**inputs).images
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
# run lora xformers attention
|
||||
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
|
||||
attn_processors = {
|
||||
@@ -622,7 +607,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
orig_image_slice, orig_image_slice_two, atol=1e-3
|
||||
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_lora_unet_attn_processors_with_xformers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.create_lora_weight_file(tmpdirname)
|
||||
@@ -659,7 +644,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
if isinstance(module, Attention):
|
||||
self.assertIsInstance(module.processor, XFormersAttnProcessor)
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda" or not is_xformers_available(), "This test is supposed to run on GPU")
|
||||
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
|
||||
def test_lora_save_load_with_xformers(self):
|
||||
pipeline_components, lora_components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**pipeline_components)
|
||||
@@ -2285,8 +2270,8 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
start_time = time.time()
|
||||
@@ -2299,13 +2284,13 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
|
||||
del pipe
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.fuse_lora()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
start_time = time.time()
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
for _ in range(3):
|
||||
pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
|
||||
@@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
load_image,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
@@ -1714,7 +1713,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_1_0_lora(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -1737,7 +1736,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
lora_model_id = "latent-consistency/lcm-lora-sdxl"
|
||||
|
||||
@@ -1754,8 +1753,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
image_np = pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
|
||||
assert max_diff < 1e-4
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
@@ -1766,7 +1764,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
pipe.to("cuda")
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
pipe.load_lora_weights(lora_model_id)
|
||||
@@ -1782,8 +1780,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
image_np = pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
|
||||
assert max_diff < 1e-4
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
@@ -1798,7 +1795,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.png"
|
||||
)
|
||||
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
pipe.load_lora_weights(lora_model_id)
|
||||
@@ -1819,8 +1816,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
image_np = pipe.image_processor.pil_to_numpy(image)
|
||||
expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_np.flatten(), expected_image_np.flatten())
|
||||
assert max_diff < 1e-4
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
@@ -1853,7 +1849,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_1_0_lora_unfusion(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
@@ -1864,16 +1860,16 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=3
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
images_with_fusion = images.flatten()
|
||||
images_with_fusion = images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
pipe.unfuse_lora()
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=3
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
images_without_fusion = images.flatten()
|
||||
images_without_fusion = images[0, -3:, -3:, -1].flatten()
|
||||
|
||||
self.assertTrue(np.allclose(images_with_fusion, images_without_fusion, atol=1e-3))
|
||||
release_memory(pipe)
|
||||
@@ -1917,8 +1913,10 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
start_time = time.time()
|
||||
@@ -1931,17 +1929,19 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
|
||||
del pipe
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.float16)
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename, torch_dtype=torch.bfloat16)
|
||||
pipe.fuse_lora()
|
||||
|
||||
# We need to unload the lora weights since in the previous API `fuse_lora` led to lora weights being
|
||||
# silently deleted - otherwise this will CPU OOM
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
start_time = time.time()
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
for _ in range(3):
|
||||
pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
|
||||
@@ -661,37 +661,6 @@ class StableDiffusionPipelineFastTests(
|
||||
output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
|
||||
), "Disabling of FreeU should lead to results similar to the default pipeline results."
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
sd_pipe.fuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
sd_pipe.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
|
||||
), "Fusion of QKV projections shouldn't affect the outputs."
|
||||
assert np.allclose(
|
||||
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user