Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. (#7381)

* fix

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Dhruv Nair
2024-03-19 23:10:14 +05:30
committed by GitHub
parent b09a2aa308
commit 80ff4ba63e
2 changed files with 79 additions and 4 deletions
@@ -100,8 +100,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
) )
self.register_to_config(latent_dim_scale=latent_dim_scale) self.register_to_config(latent_dim_scale=latent_dim_scale)
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler): def prepare_latents(
batch_size, channels, height, width = image_embeddings.shape self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
):
_, channels, height, width = image_embeddings.shape
latents_shape = ( latents_shape = (
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
4, 4,
@@ -383,7 +385,19 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
) )
if isinstance(image_embeddings, list): if isinstance(image_embeddings, list):
image_embeddings = torch.cat(image_embeddings, dim=0) image_embeddings = torch.cat(image_embeddings, dim=0)
batch_size = image_embeddings.shape[0]
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Compute the effective number of images per prompt
# We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
# This results in a case where a single prompt is associated with multiple image embeddings
# Divide the number of image embeddings by the batch size to determine if this is the case.
num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
# 2. Encode caption # 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None: if prompt_embeds is None and negative_prompt_embeds is None:
@@ -417,7 +431,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
# 5. Prepare latents # 5. Prepare latents
latents = self.prepare_latents( latents = self.prepare_latents(
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
) )
# 6. Run denoising loop # 6. Run denoising loop
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
slow, slow,
torch_device, torch_device,
) )
from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
@@ -246,6 +247,66 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5 assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
    """One prompt paired with several prior image embeddings: the decoder must
    emit prior_num_images_per_prompt * decoder_num_images_per_prompt images."""
    prior_num_images_per_prompt = 2
    decoder_num_images_per_prompt = 2
    prompt = ["a cat"]
    batch_size = len(prompt)

    pipe = StableCascadeDecoderPipeline(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    # Simulate a prior run with num_images_per_prompt > 1: more embeddings than prompts.
    generator = torch.Generator("cpu")
    image_embeddings = randn_tensor(
        (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
    )

    decoder_output = pipe(
        image_embeddings=image_embeddings,
        prompt=prompt,
        num_inference_steps=1,
        output_type="np",
        guidance_scale=0.0,
        generator=generator.manual_seed(0),
        num_images_per_prompt=decoder_num_images_per_prompt,
    )

    expected_num_images = batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
    assert decoder_output.images.shape[0] == expected_num_images
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
    """Same multi-embedding scenario as the unguided test, but with classifier-free
    guidance enabled (guidance_scale > 1) to cover the doubled-batch code path."""
    prior_num_images_per_prompt = 2
    decoder_num_images_per_prompt = 2
    prompt = ["a cat"]
    batch_size = len(prompt)

    pipe = StableCascadeDecoderPipeline(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    # Simulate a prior run with num_images_per_prompt > 1: more embeddings than prompts.
    generator = torch.Generator("cpu")
    image_embeddings = randn_tensor(
        (batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
    )

    decoder_output = pipe(
        image_embeddings=image_embeddings,
        prompt=prompt,
        num_inference_steps=1,
        output_type="np",
        guidance_scale=2.0,
        generator=generator.manual_seed(0),
        num_images_per_prompt=decoder_num_images_per_prompt,
    )

    expected_num_images = batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
    assert decoder_output.images.shape[0] == expected_num_images
@slow @slow
@require_torch_gpu @require_torch_gpu