Fix image upcasting (#7858)
Fix image's upcasting before `vae.encode()` when using `fp16` Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -1419,7 +1419,6 @@ class LEditsPPPipelineStableDiffusionXL(
|
|||||||
if needs_upcasting:
|
if needs_upcasting:
|
||||||
image = image.float()
|
image = image.float()
|
||||||
self.upcast_vae()
|
self.upcast_vae()
|
||||||
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
|
||||||
|
|
||||||
x0 = self.vae.encode(image).latent_dist.mode()
|
x0 = self.vae.encode(image).latent_dist.mode()
|
||||||
x0 = x0.to(dtype)
|
x0 = x0.to(dtype)
|
||||||
|
|||||||
+1
-1
@@ -525,8 +525,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
|||||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||||
if needs_upcasting:
|
if needs_upcasting:
|
||||||
|
image = image.float()
|
||||||
self.upcast_vae()
|
self.upcast_vae()
|
||||||
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
|
||||||
|
|
||||||
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user