From c9081a8abde4d9b660ff7af0ded43723c2ac9024 Mon Sep 17 00:00:00 2001
From: Haofan Wang
Date: Wed, 24 Jan 2024 11:48:12 +0800
Subject: [PATCH] [Fix bugs] pipeline_controlnet_sd_xl.py (#6653)

* Update pipeline_controlnet_sd_xl.py

* Update pipeline_controlnet_xs_sd_xl.py
---
 .../controlnetxs/pipeline_controlnet_xs_sd_xl.py      | 5 -----
 .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 5 -----
 2 files changed, 10 deletions(-)

diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
index be888d7e11..ed45b3bb5a 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
+++ b/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
@@ -1041,11 +1041,6 @@ class StableDiffusionXLControlNetXSPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
-        # manually for max memory savings
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
-            self.upcast_vae()
-            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 02e515c0ff..78793c2866 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1404,11 +1404,6 @@ class StableDiffusionXLControlNetPipeline(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
-        # manually for max memory savings
-        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
-            self.upcast_vae()
-            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
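
Context on the fix: the deleted block duplicated the conditional upcast that the
retained `if not output_type == "latent":` branch performs via `needs_upcasting`.
Because the deleted block ran first and unconditionally (even when only latents
were requested), the retained branch then saw `self.vae.dtype == torch.float32`,
computed `needs_upcasting = False`, and never cast the VAE back to float16, so
the VAE stayed in float32 for subsequent calls, the opposite of the "max memory
savings" the comment promised.

For reference, a minimal standalone sketch of the upcast-decode-downcast pattern
the retained branch implements. The helper name `decode_with_upcast` is
hypothetical, and `AutoencoderKL` stands in for the pipeline's `self.vae`; this
is an illustration of the pattern, not code from this patch.

    import torch
    from diffusers import AutoencoderKL

    def decode_with_upcast(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
        # SDXL's VAE overflows in float16, so decode in float32 when the
        # config asks for it, then restore float16 to keep the memory savings.
        needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
        if needs_upcasting:
            vae.to(dtype=torch.float32)
            latents = latents.to(torch.float32)
        image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
        if needs_upcasting:
            # cast back so later steps run the VAE in fp16 again
            vae.to(dtype=torch.float16)
        return image

In the pipelines themselves, `self.upcast_vae()` plays the role of the
`vae.to(dtype=torch.float32)` call above, additionally keeping parts of the
decoder in the original dtype when memory-efficient attention is in use.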