7529 do not disable autocast for cuda devices (#7530)
* 7529 do not disable autocast for cuda devices * Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue * add autocast fix to other training examples * disable native_amp for dreambooth (sdxl) * disable native_amp for pix2pix (sdxl) * remove tests from remaining files * disable native_amp on huggingface accelerator for every training example that uses it * convert more usages of autocast to nullcontext, make style fixes * make style fixes * style. * Empty-Commit --------- Co-authored-by: bghira <bghira@users.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -1844,7 +1845,12 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import hashlib
|
||||
import itertools
|
||||
@@ -26,6 +25,7 @@ import random
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -2192,13 +2192,12 @@ def main(args):
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext()
|
||||
if "playground" in args.pretrained_model_name_or_path
|
||||
else torch.cuda.amp.autocast()
|
||||
)
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with inference_ctx:
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -430,6 +430,9 @@ def main(args):
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if accelerator.is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
@@ -23,6 +23,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -238,6 +239,10 @@ class SDText2ImageDataset:
|
||||
|
||||
def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
logger.info("Running validation... ")
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
with autocast_ctx:
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1172,6 +1177,11 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
# 16. Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1300,7 +1310,7 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
with autocast_ctx:
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1359,7 +1369,7 @@ def main(args):
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
with autocast_ctx:
|
||||
target_noise_pred = unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -22,6 +22,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
|
||||
@@ -24,6 +24,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -256,6 +257,10 @@ class SDXLText2ImageDataset:
|
||||
|
||||
def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
logger.info("Running validation... ")
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
with autocast_ctx:
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1353,7 +1358,12 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1416,7 +1426,12 @@ def main(args):
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
target_noise_pred = unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -23,6 +23,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1257,7 +1263,12 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1315,7 +1326,12 @@ def main(args):
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
target_noise_pred = target_unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -24,6 +24,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1355,7 +1361,12 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1417,7 +1428,12 @@ def main(args):
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
target_noise_pred = target_unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -752,6 +752,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
@@ -22,6 +21,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
)
|
||||
|
||||
image_logs = []
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext()
|
||||
if (is_final_validation or torch.backends.mps.is_available())
|
||||
else torch.autocast("cuda")
|
||||
)
|
||||
if is_final_validation or torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
images = []
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with inference_ctx:
|
||||
with autocast_ctx:
|
||||
image = pipeline(
|
||||
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
|
||||
).images[0]
|
||||
@@ -811,6 +810,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -676,6 +676,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -821,6 +821,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -749,6 +749,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -23,6 +23,7 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -207,18 +208,12 @@ def log_validation(
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
if "playground" in args.pretrained_model_name_or_path:
|
||||
enable_autocast = False
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with torch.autocast(
|
||||
accelerator.device.type,
|
||||
enabled=enable_autocast,
|
||||
):
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
@@ -992,6 +987,10 @@ def main(args):
|
||||
kwargs_handlers=[kwargs],
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -21,6 +21,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -404,6 +405,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
@@ -943,9 +948,12 @@ def main():
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
with torch.autocast(
|
||||
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
|
||||
):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
|
||||
@@ -20,6 +20,7 @@ import math
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -70,9 +71,7 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
|
||||
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
|
||||
):
|
||||
def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
@@ -91,7 +90,12 @@ def log_validation(
|
||||
else Image.open(image_url_or_path).convert("RGB")
|
||||
)(args.val_image_url_or_path)
|
||||
|
||||
with torch.autocast(accelerator.device.type, enabled=enable_autocast):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
edited_images = []
|
||||
# Run inference
|
||||
for val_img_idx in range(args.num_validation_images):
|
||||
@@ -507,6 +511,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
@@ -983,13 +991,6 @@ def main():
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))
|
||||
|
||||
# Some configurations require autocast to be disabled.
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1202,7 +1203,6 @@ def main():
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=False,
|
||||
enable_autocast=enable_autocast,
|
||||
)
|
||||
|
||||
if args.use_ema:
|
||||
@@ -1252,7 +1252,6 @@ def main():
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=True,
|
||||
enable_autocast=enable_autocast,
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
@@ -458,6 +458,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -343,6 +343,11 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -356,6 +356,11 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -459,6 +459,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -916,6 +916,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -484,6 +484,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -526,6 +526,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -516,6 +516,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -623,6 +623,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -21,6 +21,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -410,6 +411,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
@@ -967,9 +972,12 @@ def main():
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
with torch.autocast(
|
||||
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
|
||||
):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
|
||||
@@ -378,6 +378,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
@@ -411,6 +411,11 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -698,6 +698,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -566,6 +566,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -439,6 +439,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -581,6 +581,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
+4
@@ -295,6 +295,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
if not is_tensorboard_available():
|
||||
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
|
||||
|
||||
@@ -799,6 +799,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -20,6 +20,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -164,7 +165,12 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
|
||||
images = []
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
|
||||
|
||||
images.append(image)
|
||||
@@ -523,6 +529,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -21,6 +21,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -408,6 +409,11 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
@@ -878,7 +884,12 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
with torch.cuda.amp.autocast():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
@@ -948,7 +959,12 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
with torch.cuda.amp.autocast():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
|
||||
@@ -21,6 +21,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -979,13 +980,6 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
|
||||
|
||||
# Some configurations require autocast to be disabled.
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1211,11 +1205,12 @@ def main(args):
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with torch.autocast(
|
||||
accelerator.device.type,
|
||||
enabled=enable_autocast,
|
||||
):
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -23,6 +23,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -603,6 +604,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
@@ -986,12 +991,10 @@ def main(args):
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
# Some configurations require autocast to be disabled.
|
||||
enable_autocast = True
|
||||
if torch.backends.mps.is_available() or (
|
||||
accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
|
||||
):
|
||||
enable_autocast = False
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
@@ -1226,10 +1229,7 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with torch.autocast(
|
||||
accelerator.device.type,
|
||||
enabled=enable_autocast,
|
||||
):
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
@@ -1284,7 +1284,8 @@ def main(args):
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
with torch.autocast(accelerator.device.type, enabled=enable_autocast):
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -143,7 +144,12 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
|
||||
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.autocast("cuda"):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
@@ -600,6 +606,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -605,6 +605,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -460,6 +460,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -458,6 +458,10 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -548,8 +548,15 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
pixel_values = pixel_values.to(device=device)
|
||||
# The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
|
||||
# So we use `torch.autocast` here for half precision inference.
|
||||
context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
|
||||
with context_manger:
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = contextlib.nullcontext()
|
||||
logger.warning(
|
||||
"The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS."
|
||||
)
|
||||
else:
|
||||
autocast_ctx = torch.autocast(device.type, dtype=dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
depth_map = self.depth_estimator(pixel_values).predicted_depth
|
||||
else:
|
||||
depth_map = depth_map.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -1199,10 +1199,6 @@ class StableDiffusionXLPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
@@ -1241,10 +1237,6 @@ class StableDiffusionXLPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
self.vae = self.vae.to(latents.dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
|
||||
@@ -1376,10 +1376,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
@@ -1418,10 +1414,6 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
self.vae = self.vae.to(latents.dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
|
||||
@@ -1726,10 +1726,6 @@ class StableDiffusionXLInpaintPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
if num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
@@ -1785,10 +1781,6 @@ class StableDiffusionXLInpaintPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
self.vae = self.vae.to(latents.dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
|
||||
-8
@@ -924,10 +924,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
@@ -950,10 +946,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
self.vae = self.vae.to(latents.dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
|
||||
)
|
||||
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
|
||||
Reference in New Issue
Block a user