Fix image upcasting (#7858)

Fix image's upcasting before `vae.encode()` when using `fp16`

Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Tolga Cangöz
2024-05-08 05:45:02 +03:00
committed by GitHub
parent c2217142bd
commit d50baf0c63
2 changed files with 1 additions and 2 deletions
@@ -1419,7 +1419,6 @@ class LEditsPPPipelineStableDiffusionXL(
if needs_upcasting:
image = image.float()
self.upcast_vae()
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
x0 = self.vae.encode(image).latent_dist.mode()
x0 = x0.to(dtype)
@@ -525,8 +525,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
image = image.float()
self.upcast_vae()
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")