[PixArt-Alpha] Fix PixArt-Alpha pipeline when number of images to generate is more than 1 (#5752)
* does this fix things? * attention mask use * attention mask order * better masking. * add: tesrt * remove mask_featur * test * debug * fix: tests * deprecate mask_feature * add deprecation test * add slow test * add print statements to retrieve the assertion values. * fix for the 1024 fast tes * fix tesy * fix the remaining * Apply suggestions from code review * more debug --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, Transformer2DModel
|
|||||||
from ...schedulers import DPMSolverMultistepScheduler
|
from ...schedulers import DPMSolverMultistepScheduler
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
BACKENDS_MAPPING,
|
BACKENDS_MAPPING,
|
||||||
|
deprecate,
|
||||||
is_bs4_available,
|
is_bs4_available,
|
||||||
is_ftfy_available,
|
is_ftfy_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -162,8 +163,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
clean_caption: bool = False,
|
clean_caption: bool = False,
|
||||||
mask_feature: bool = True,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Encodes the prompt into text encoder hidden states.
|
Encodes the prompt into text encoder hidden states.
|
||||||
@@ -189,10 +192,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
string.
|
string.
|
||||||
clean_caption (bool, defaults to `False`):
|
clean_caption (bool, defaults to `False`):
|
||||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||||
mask_feature: (bool, defaults to `True`):
|
|
||||||
If `True`, the function will mask the text embeddings.
|
|
||||||
"""
|
"""
|
||||||
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
|
|
||||||
|
if "mask_feature" in kwargs:
|
||||||
|
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
||||||
|
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = self._execution_device
|
device = self._execution_device
|
||||||
@@ -229,13 +233,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
f" {max_length} tokens: {removed_text}"
|
f" {max_length} tokens: {removed_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
attention_mask = text_inputs.attention_mask.to(device)
|
prompt_attention_mask = text_inputs.attention_mask
|
||||||
prompt_embeds_attention_mask = attention_mask
|
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||||
|
|
||||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||||
prompt_embeds = prompt_embeds[0]
|
prompt_embeds = prompt_embeds[0]
|
||||||
else:
|
|
||||||
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
|
|
||||||
|
|
||||||
if self.text_encoder is not None:
|
if self.text_encoder is not None:
|
||||||
dtype = self.text_encoder.dtype
|
dtype = self.text_encoder.dtype
|
||||||
@@ -250,8 +252,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
|
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||||
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)
|
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
# get unconditional embeddings for classifier free guidance
|
||||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||||
@@ -267,11 +269,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
add_special_tokens=True,
|
add_special_tokens=True,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
attention_mask = uncond_input.attention_mask.to(device)
|
negative_prompt_attention_mask = uncond_input.attention_mask
|
||||||
|
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
||||||
|
|
||||||
negative_prompt_embeds = self.text_encoder(
|
negative_prompt_embeds = self.text_encoder(
|
||||||
uncond_input.input_ids.to(device),
|
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
||||||
attention_mask=attention_mask,
|
|
||||||
)
|
)
|
||||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||||
|
|
||||||
@@ -284,23 +286,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
# For classifier free guidance, we need to do two forward passes.
|
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
||||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||||
# to avoid doing two forward passes
|
|
||||||
else:
|
else:
|
||||||
negative_prompt_embeds = None
|
negative_prompt_embeds = None
|
||||||
|
negative_prompt_attention_mask = None
|
||||||
|
|
||||||
# Perform additional masking.
|
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||||
if mask_feature and not embeds_initially_provided:
|
|
||||||
prompt_embeds = prompt_embeds.unsqueeze(1)
|
|
||||||
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
|
|
||||||
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
|
|
||||||
masked_negative_prompt_embeds = (
|
|
||||||
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
|
|
||||||
)
|
|
||||||
return masked_prompt_embeds, masked_negative_prompt_embeds
|
|
||||||
|
|
||||||
return prompt_embeds, negative_prompt_embeds
|
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||||
def prepare_extra_step_kwargs(self, generator, eta):
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
@@ -329,6 +321,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
callback_steps,
|
callback_steps,
|
||||||
prompt_embeds=None,
|
prompt_embeds=None,
|
||||||
negative_prompt_embeds=None,
|
negative_prompt_embeds=None,
|
||||||
|
prompt_attention_mask=None,
|
||||||
|
negative_prompt_attention_mask=None,
|
||||||
):
|
):
|
||||||
if height % 8 != 0 or width % 8 != 0:
|
if height % 8 != 0 or width % 8 != 0:
|
||||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||||
@@ -365,6 +359,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||||
|
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||||
|
|
||||||
|
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||||
|
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||||
|
|
||||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -372,6 +372,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||||
f" {negative_prompt_embeds.shape}."
|
f" {negative_prompt_embeds.shape}."
|
||||||
)
|
)
|
||||||
|
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||||
|
raise ValueError(
|
||||||
|
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||||
|
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||||
|
f" {negative_prompt_attention_mask.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||||
def _text_preprocessing(self, text, clean_caption=False):
|
def _text_preprocessing(self, text, clean_caption=False):
|
||||||
@@ -579,14 +585,16 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
latents: Optional[torch.FloatTensor] = None,
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||||
callback_steps: int = 1,
|
callback_steps: int = 1,
|
||||||
clean_caption: bool = True,
|
clean_caption: bool = True,
|
||||||
mask_feature: bool = True,
|
|
||||||
use_resolution_binning: bool = True,
|
use_resolution_binning: bool = True,
|
||||||
|
**kwargs,
|
||||||
) -> Union[ImagePipelineOutput, Tuple]:
|
) -> Union[ImagePipelineOutput, Tuple]:
|
||||||
"""
|
"""
|
||||||
Function invoked when calling the pipeline for generation.
|
Function invoked when calling the pipeline for generation.
|
||||||
@@ -630,9 +638,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
provided, text embeddings will be generated from `prompt` input argument.
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
||||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||||
|
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated attention mask for negative text embeddings.
|
||||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
The output format of the generate image. Choose between
|
The output format of the generate image. Choose between
|
||||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
@@ -648,11 +659,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||||
prompt.
|
prompt.
|
||||||
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
use_resolution_binning (`bool` defaults to `True`):
|
||||||
use_resolution_binning:
|
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||||
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the
|
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
|
||||||
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images,
|
the requested resolution. Useful for generating non-square images.
|
||||||
they are resized back to the requested resolution. Useful for generating non-square images.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -661,6 +671,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||||
returned where the first element is a list with the generated images
|
returned where the first element is a list with the generated images
|
||||||
"""
|
"""
|
||||||
|
if "mask_feature" in kwargs:
|
||||||
|
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
||||||
|
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
# 1. Check inputs. Raise error if not correct
|
# 1. Check inputs. Raise error if not correct
|
||||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||||
@@ -669,7 +682,15 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
|
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
|
||||||
|
|
||||||
self.check_inputs(
|
self.check_inputs(
|
||||||
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
|
prompt,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
negative_prompt,
|
||||||
|
callback_steps,
|
||||||
|
prompt_embeds,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
prompt_attention_mask,
|
||||||
|
negative_prompt_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Default height and width to transformer
|
# 2. Default height and width to transformer
|
||||||
@@ -688,7 +709,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
# 3. Encode input prompt
|
# 3. Encode input prompt
|
||||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
prompt_attention_mask,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
negative_prompt_attention_mask,
|
||||||
|
) = self.encode_prompt(
|
||||||
prompt,
|
prompt,
|
||||||
do_classifier_free_guidance,
|
do_classifier_free_guidance,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
@@ -696,11 +722,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
device=device,
|
device=device,
|
||||||
prompt_embeds=prompt_embeds,
|
prompt_embeds=prompt_embeds,
|
||||||
negative_prompt_embeds=negative_prompt_embeds,
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
prompt_attention_mask=prompt_attention_mask,
|
||||||
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
clean_caption=clean_caption,
|
clean_caption=clean_caption,
|
||||||
mask_feature=mask_feature,
|
|
||||||
)
|
)
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||||
|
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||||
|
|
||||||
# 4. Prepare timesteps
|
# 4. Prepare timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
@@ -758,6 +786,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|||||||
noise_pred = self.transformer(
|
noise_pred = self.transformer(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
encoder_hidden_states=prompt_embeds,
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
encoder_attention_mask=prompt_attention_mask,
|
||||||
timestep=current_timestep,
|
timestep=current_timestep,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
|
|||||||
@@ -111,13 +111,20 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
num_inference_steps = inputs["num_inference_steps"]
|
num_inference_steps = inputs["num_inference_steps"]
|
||||||
output_type = inputs["output_type"]
|
output_type = inputs["output_type"]
|
||||||
|
|
||||||
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, mask_feature=False)
|
(
|
||||||
|
prompt_embeds,
|
||||||
|
prompt_attention_mask,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
negative_prompt_attention_mask,
|
||||||
|
) = pipe.encode_prompt(prompt)
|
||||||
|
|
||||||
# inputs with prompt converted to embeddings
|
# inputs with prompt converted to embeddings
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt_embeds": prompt_embeds,
|
"prompt_embeds": prompt_embeds,
|
||||||
|
"prompt_attention_mask": prompt_attention_mask,
|
||||||
"negative_prompt": None,
|
"negative_prompt": None,
|
||||||
"negative_prompt_embeds": negative_prompt_embeds,
|
"negative_prompt_embeds": negative_prompt_embeds,
|
||||||
|
"negative_prompt_attention_mask": negative_prompt_attention_mask,
|
||||||
"generator": generator,
|
"generator": generator,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
"output_type": output_type,
|
"output_type": output_type,
|
||||||
@@ -151,8 +158,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
# inputs with prompt converted to embeddings
|
# inputs with prompt converted to embeddings
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt_embeds": prompt_embeds,
|
"prompt_embeds": prompt_embeds,
|
||||||
|
"prompt_attention_mask": prompt_attention_mask,
|
||||||
"negative_prompt": None,
|
"negative_prompt": None,
|
||||||
"negative_prompt_embeds": negative_prompt_embeds,
|
"negative_prompt_embeds": negative_prompt_embeds,
|
||||||
|
"negative_prompt_attention_mask": negative_prompt_attention_mask,
|
||||||
"generator": generator,
|
"generator": generator,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
"output_type": output_type,
|
"output_type": output_type,
|
||||||
@@ -211,13 +220,15 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
num_inference_steps = inputs["num_inference_steps"]
|
num_inference_steps = inputs["num_inference_steps"]
|
||||||
output_type = inputs["output_type"]
|
output_type = inputs["output_type"]
|
||||||
|
|
||||||
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
|
prompt_embeds, prompt_attn_mask, negative_prompt_embeds, neg_prompt_attn_mask = pipe.encode_prompt(prompt)
|
||||||
|
|
||||||
# inputs with prompt converted to embeddings
|
# inputs with prompt converted to embeddings
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt_embeds": prompt_embeds,
|
"prompt_embeds": prompt_embeds,
|
||||||
|
"prompt_attention_mask": prompt_attn_mask,
|
||||||
"negative_prompt": None,
|
"negative_prompt": None,
|
||||||
"negative_prompt_embeds": negative_prompt_embeds,
|
"negative_prompt_embeds": negative_prompt_embeds,
|
||||||
|
"negative_prompt_attention_mask": neg_prompt_attn_mask,
|
||||||
"generator": generator,
|
"generator": generator,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
"output_type": output_type,
|
"output_type": output_type,
|
||||||
@@ -252,8 +263,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
# inputs with prompt converted to embeddings
|
# inputs with prompt converted to embeddings
|
||||||
inputs = {
|
inputs = {
|
||||||
"prompt_embeds": prompt_embeds,
|
"prompt_embeds": prompt_embeds,
|
||||||
|
"prompt_attention_mask": prompt_attn_mask,
|
||||||
"negative_prompt": None,
|
"negative_prompt": None,
|
||||||
"negative_prompt_embeds": negative_prompt_embeds,
|
"negative_prompt_embeds": negative_prompt_embeds,
|
||||||
|
"negative_prompt_attention_mask": neg_prompt_attn_mask,
|
||||||
"generator": generator,
|
"generator": generator,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
"output_type": output_type,
|
"output_type": output_type,
|
||||||
@@ -266,6 +279,40 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
|
||||||
self.assertLess(max_diff, 1e-4)
|
self.assertLess(max_diff, 1e-4)
|
||||||
|
|
||||||
|
def test_inference_with_multiple_images_per_prompt(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs["num_images_per_prompt"] = 2
|
||||||
|
image = pipe(**inputs).images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
self.assertEqual(image.shape, (2, 8, 8, 3))
|
||||||
|
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
|
||||||
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
|
def test_raises_warning_for_mask_feature(self):
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
components = self.get_dummy_components()
|
||||||
|
pipe = self.pipeline_class(**components)
|
||||||
|
pipe.to(device)
|
||||||
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
inputs = self.get_dummy_inputs(device)
|
||||||
|
inputs.update({"mask_feature": True})
|
||||||
|
|
||||||
|
with self.assertWarns(FutureWarning) as warning_ctx:
|
||||||
|
_ = pipe(**inputs).images
|
||||||
|
|
||||||
|
assert "mask_feature" in str(warning_ctx.warning)
|
||||||
|
|
||||||
def test_inference_batch_single_identical(self):
|
def test_inference_batch_single_identical(self):
|
||||||
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
|
||||||
|
|
||||||
@@ -290,7 +337,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1323])
|
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
self.assertLessEqual(max_diff, 1e-3)
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
@@ -307,7 +354,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0266])
|
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||||
|
|
||||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
self.assertLessEqual(max_diff, 1e-3)
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
@@ -323,7 +370,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
expected_slice = np.array([0.1501, 0.1755, 0.1877, 0.1445, 0.1665, 0.1763, 0.1389, 0.176, 0.2031])
|
expected_slice = np.array([0.1941, 0.2117, 0.2188, 0.1946, 0.218, 0.2124, 0.199, 0.2437, 0.2583])
|
||||||
|
|
||||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
self.assertLessEqual(max_diff, 1e-3)
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
@@ -340,7 +387,26 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
image_slice = image[0, -3:, -3:, -1]
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
expected_slice = np.array([0.2515, 0.2593, 0.2593, 0.2544, 0.2759, 0.2788, 0.2812, 0.3169, 0.332])
|
expected_slice = np.array([0.2637, 0.291, 0.2939, 0.207, 0.2512, 0.2783, 0.2168, 0.2324, 0.2817])
|
||||||
|
|
||||||
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
|
||||||
self.assertLessEqual(max_diff, 1e-3)
|
self.assertLessEqual(max_diff, 1e-3)
|
||||||
|
|
||||||
|
def test_pixart_1024_without_resolution_binning(self):
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
|
||||||
|
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
||||||
|
pipe.enable_model_cpu_offload()
|
||||||
|
|
||||||
|
prompt = "A small cactus with a happy face in the Sahara desert."
|
||||||
|
|
||||||
|
image = pipe(prompt, generator=generator, num_inference_steps=5, output_type="np").images
|
||||||
|
image_slice = image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
no_res_bin_image = pipe(
|
||||||
|
prompt, generator=generator, num_inference_steps=5, output_type="np", use_resolution_binning=False
|
||||||
|
).images
|
||||||
|
no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert not np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4)
|
||||||
|
|||||||
Reference in New Issue
Block a user