This commit is contained in:
sayakpaul
2025-09-28 16:18:35 +05:30
parent a9d50c8f2a
commit 1185f82450
4 changed files with 20 additions and 10 deletions
@@ -308,7 +308,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
def check_inputs(
@@ -309,6 +309,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
"""
print(f"{image[0].size=}")
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -322,7 +323,7 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
print(f"{prompt_embeds.shape=}, {prompt_embeds_mask.shape=}")
return prompt_embeds, prompt_embeds_mask
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
@@ -133,15 +133,17 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
else:
generator = torch.Generator(device=device).manual_seed(seed)
# Even if we specify smaller dimensions for the images, it won't work because of how
# the internal implementation enforces a minimal resolution of 1024x1024.
inputs = {
"prompt": "dance monkey",
"image": Image.new("RGB", (32, 32)),
"image": Image.new("RGB", (1024, 1024)),
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"true_cfg_scale": 1.0,
"height": 32,
"width": 32,
"height": 1024,
"width": 1024,
"max_sequence_length": 16,
"output_type": "pt",
}
@@ -240,5 +242,8 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_encode_prompt_works_in_isolation(
self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
):
keep_params = ["image"]
# We include `image` because it's needed in both `encode_prompt` and some other subsequent calculations.
# `max_sequence_length` to maintain parity between its value during all invokations of `encode_prompt`
# in the following test.
keep_params = ["image", "max_sequence_length"]
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
@@ -134,7 +134,9 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = Image.new("RGB", (32, 32))
# Even if we specify smaller dimensions for the images, it won't work because of how
# the internal implementation enforces a minimal resolution of 384*384.
image = Image.new("RGB", (384, 384))
inputs = {
"prompt": "dance monkey",
"image": [image, image],
@@ -142,8 +144,8 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
"generator": generator,
"num_inference_steps": 2,
"true_cfg_scale": 1.0,
"height": 32,
"width": 32,
"height": 384,
"width": 384,
"max_sequence_length": 16,
"output_type": "pt",
}
@@ -239,7 +241,10 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_encode_prompt_works_in_isolation(
self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
):
keep_params = ["image"]
# We include `image` because it's needed in both `encode_prompt` and some other subsequent calculations.
# `max_sequence_length` to maintain parity between its value during all invokations of `encode_prompt`
# in the following test.
keep_params = ["image", "max_sequence_length"]
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)