Fix image upcasting (#7858)
Fix image's upcasting before `vae.encode()` when using `fp16` Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -1419,7 +1419,6 @@ class LEditsPPPipelineStableDiffusionXL(
|
||||
if needs_upcasting:
|
||||
image = image.float()
|
||||
self.upcast_vae()
|
||||
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
x0 = self.vae.encode(image).latent_dist.mode()
|
||||
x0 = x0.to(dtype)
|
||||
|
||||
+1
-1
@@ -525,8 +525,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
if needs_upcasting:
|
||||
image = image.float()
|
||||
self.upcast_vae()
|
||||
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user