Fix issue with prompt embeds and latents in SD Cascade Decoder with multiple image embeddings for a single prompt. (#7381)
* fix * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -100,8 +100,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
||||||
|
|
||||||
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
|
def prepare_latents(
|
||||||
batch_size, channels, height, width = image_embeddings.shape
|
self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler
|
||||||
|
):
|
||||||
|
_, channels, height, width = image_embeddings.shape
|
||||||
latents_shape = (
|
latents_shape = (
|
||||||
batch_size * num_images_per_prompt,
|
batch_size * num_images_per_prompt,
|
||||||
4,
|
4,
|
||||||
@@ -383,7 +385,19 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
if isinstance(image_embeddings, list):
|
if isinstance(image_embeddings, list):
|
||||||
image_embeddings = torch.cat(image_embeddings, dim=0)
|
image_embeddings = torch.cat(image_embeddings, dim=0)
|
||||||
batch_size = image_embeddings.shape[0]
|
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
# Compute the effective number of images per prompt
|
||||||
|
# We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1
|
||||||
|
# This results in a case where a single prompt is associated with multiple image embeddings
|
||||||
|
# Divide the number of image embeddings by the batch size to determine if this is the case.
|
||||||
|
num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size)
|
||||||
|
|
||||||
# 2. Encode caption
|
# 2. Encode caption
|
||||||
if prompt_embeds is None and negative_prompt_embeds is None:
|
if prompt_embeds is None and negative_prompt_embeds is None:
|
||||||
@@ -417,7 +431,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
# 5. Prepare latents
|
# 5. Prepare latents
|
||||||
latents = self.prepare_latents(
|
latents = self.prepare_latents(
|
||||||
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. Run denoising loop
|
# 6. Run denoising loop
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
|
||||||
from ..test_pipelines_common import PipelineTesterMixin
|
from ..test_pipelines_common import PipelineTesterMixin
|
||||||
|
|
||||||
@@ -246,6 +247,66 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
|||||||
|
|
||||||
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
|
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
|
||||||
|
|
||||||
|
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings(self):
|
||||||
|
device = "cpu"
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
pipe = StableCascadeDecoderPipeline(**components)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prior_num_images_per_prompt = 2
|
||||||
|
decoder_num_images_per_prompt = 2
|
||||||
|
prompt = ["a cat"]
|
||||||
|
batch_size = len(prompt)
|
||||||
|
|
||||||
|
generator = torch.Generator(device)
|
||||||
|
image_embeddings = randn_tensor(
|
||||||
|
(batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
|
||||||
|
)
|
||||||
|
decoder_output = pipe(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
prompt=prompt,
|
||||||
|
num_inference_steps=1,
|
||||||
|
output_type="np",
|
||||||
|
guidance_scale=0.0,
|
||||||
|
generator=generator.manual_seed(0),
|
||||||
|
num_images_per_prompt=decoder_num_images_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert decoder_output.images.shape[0] == (
|
||||||
|
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_stable_cascade_decoder_single_prompt_multiple_image_embeddings_with_guidance(self):
|
||||||
|
device = "cpu"
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
|
||||||
|
pipe = StableCascadeDecoderPipeline(**components)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prior_num_images_per_prompt = 2
|
||||||
|
decoder_num_images_per_prompt = 2
|
||||||
|
prompt = ["a cat"]
|
||||||
|
batch_size = len(prompt)
|
||||||
|
|
||||||
|
generator = torch.Generator(device)
|
||||||
|
image_embeddings = randn_tensor(
|
||||||
|
(batch_size * prior_num_images_per_prompt, 4, 4, 4), generator=generator.manual_seed(0)
|
||||||
|
)
|
||||||
|
decoder_output = pipe(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
prompt=prompt,
|
||||||
|
num_inference_steps=1,
|
||||||
|
output_type="np",
|
||||||
|
guidance_scale=2.0,
|
||||||
|
generator=generator.manual_seed(0),
|
||||||
|
num_images_per_prompt=decoder_num_images_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert decoder_output.images.shape[0] == (
|
||||||
|
batch_size * prior_num_images_per_prompt * decoder_num_images_per_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user