fix prompt isolation test.

This commit is contained in:
sayakpaul
2025-09-26 18:50:26 +05:30
parent ec5449f3a1
commit f82c1523e5
3 changed files with 39 additions and 31 deletions
@@ -15,7 +15,6 @@
import unittest
import numpy as np
import pytest
import torch
from PIL import Image
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
@@ -238,6 +237,8 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"VAE tiling should not affect the inference results",
)
@pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
def test_encode_prompt_works_in_isolation(
self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
):
keep_params = ["image"]
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
@@ -236,9 +236,11 @@ class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
"VAE tiling should not affect the inference results",
)
@pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
def test_encode_prompt_works_in_isolation(
self, extra_required_param_value_dict=None, keep_params=None, atol=1e-4, rtol=1e-4
):
keep_params = ["image"]
super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, keep_params, atol, rtol)
@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt():
+29 -24
View File
@@ -5,7 +5,7 @@ import os
import tempfile
import unittest
import uuid
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import PIL.Image
@@ -2069,20 +2069,26 @@ class PipelineTesterMixin:
assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
def test_encode_prompt_works_in_isolation(
self,
extra_required_param_value_dict: Optional[dict] = None,
keep_params: Optional[list] = None,
atol=1e-4,
rtol=1e-4,
):
if not hasattr(self.pipeline_class, "encode_prompt"):
return
components = self.get_dummy_components()
def _contains_text_key(name):
return any(token in name for token in ("text", "tokenizer", "processor"))
# We initialize the pipeline with only text encoders and tokenizers,
# mimicking a real-world scenario.
components_with_text_encoders = {}
for k in components:
if "text" in k or "tokenizer" in k:
components_with_text_encoders[k] = components[k]
else:
components_with_text_encoders[k] = None
components_with_text_encoders = {
name: component if _contains_text_key(name) else None for name, component in components.items()
}
pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
@@ -2092,17 +2098,19 @@ class PipelineTesterMixin:
encode_prompt_parameters = list(encode_prompt_signature.parameters.values())
# Required args in encode_prompt with those with no default.
required_params = []
for param in encode_prompt_parameters:
if param.name == "self" or param.name == "kwargs":
continue
if param.default is inspect.Parameter.empty:
required_params.append(param.name)
required_params = [
param.name
for param in encode_prompt_parameters
if param.name not in {"self", "kwargs"} and param.default is inspect.Parameter.empty
]
# Craft inputs for the `encode_prompt()` method to run in isolation.
encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
input_keys = list(inputs.keys())
encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}
encode_prompt_inputs = {name: inputs[name] for name in encode_prompt_param_names if name in inputs}
if keep_params:
for name in encode_prompt_param_names:
if name in inputs and name not in keep_params:
inputs.pop(name)
pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
pipe_call_parameters = pipe_call_signature.parameters
@@ -2137,18 +2145,15 @@ class PipelineTesterMixin:
# Pack the outputs of `encode_prompt`.
adapted_prompt_embeds_kwargs = {
k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
name: prompt_embeds_kwargs[name] for name in prompt_embeds_kwargs if name in pipe_call_parameters
}
# now initialize a pipeline without text encoders and compute outputs with the
# `encode_prompt()` outputs and other relevant inputs.
components_with_text_encoders = {}
for k in components:
if "text" in k or "tokenizer" in k:
components_with_text_encoders[k] = None
else:
components_with_text_encoders[k] = components[k]
pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)
components_without_text_encoders = {
name: None if _contains_text_key(name) else component for name, component in components.items()
}
pipe_without_text_encoders = self.pipeline_class(**components_without_text_encoders).to(torch_device)
# Set `negative_prompt` to None as we have already calculated its embeds
# if it was present in `inputs`. This is because otherwise we will interfere wrongly