|
|
|
@@ -19,6 +19,9 @@ from diffusers.utils.testing_utils import require_torch, torch_device
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@require_torch
|
|
|
|
|
class PipelineTesterMixin:
|
|
|
|
|
"""
|
|
|
|
@@ -115,10 +118,138 @@ class PipelineTesterMixin:
|
|
|
|
|
self.assertLess(max_diff, 1e-5)
|
|
|
|
|
|
|
|
|
|
def test_pipeline_call_implements_required_args(self):
|
|
|
|
|
required_args = ["num_inference_steps", "generator", "return_dict"]
|
|
|
|
|
assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
|
|
|
|
|
parameters = inspect.signature(self.pipeline_class.__call__).parameters
|
|
|
|
|
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
|
|
|
|
required_parameters.pop("self")
|
|
|
|
|
required_parameters = set(required_parameters)
|
|
|
|
|
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
|
|
|
|
|
|
|
|
|
for arg in required_args:
|
|
|
|
|
self.assertTrue(arg in inspect.signature(self.pipeline_class.__call__).parameters)
|
|
|
|
|
for param in required_parameters:
|
|
|
|
|
assert param in ALLOWED_REQUIRED_ARGS
|
|
|
|
|
|
|
|
|
|
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
|
|
|
|
|
|
|
|
|
required_optional_params = ["generator", "num_inference_steps", "return_dict"]
|
|
|
|
|
for param in required_optional_params:
|
|
|
|
|
assert param in optional_parameters
|
|
|
|
|
|
|
|
|
|
def test_inference_batch_image_pil_torch(self):
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
|
|
|
|
|
allowed_image_args = [v for v in ALLOWED_REQUIRED_ARGS if v != "prompt"]
|
|
|
|
|
|
|
|
|
|
if set(allowed_image_args) - set(inputs.keys()) == set(allowed_image_args):
|
|
|
|
|
# pipeline has no allowed required image args, so no need to test
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
components = self.get_dummy_components()
|
|
|
|
|
pipe = self.pipeline_class(**components)
|
|
|
|
|
pipe.to(torch_device)
|
|
|
|
|
pipe.set_progress_bar_config(disable=None)
|
|
|
|
|
|
|
|
|
|
# batchify inputs
|
|
|
|
|
for batch_size in [2, 4, 13]:
|
|
|
|
|
batched_inputs = {}
|
|
|
|
|
for name, value in inputs.items():
|
|
|
|
|
if name in allowed_image_args:
|
|
|
|
|
batched_inputs[name] = batch_size * [value]
|
|
|
|
|
else:
|
|
|
|
|
batched_inputs[name] = value
|
|
|
|
|
|
|
|
|
|
batched_inputs["num_inference_steps"] = 2
|
|
|
|
|
batched_inputs["output_type"] = "np"
|
|
|
|
|
batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
|
|
|
|
|
output = pipe(**batched_inputs)
|
|
|
|
|
|
|
|
|
|
for name in allowed_image_args:
|
|
|
|
|
# convert pil to torch
|
|
|
|
|
if name in batched_inputs:
|
|
|
|
|
batched_inputs = torch.tensor(pipe.pil_to_numpy(value), dtype=torch.float32, device=torch_device)
|
|
|
|
|
batched_inputs["num_inference_steps"] = 2
|
|
|
|
|
batched_inputs["output_type"] = "np"
|
|
|
|
|
|
|
|
|
|
batched_inputs["generator"] = torch.Generator(torch_device).manual_seed(33)
|
|
|
|
|
output_torch_image = pipe(**batched_inputs)
|
|
|
|
|
|
|
|
|
|
max_diff = np.abs(output - output_torch_image).max()
|
|
|
|
|
self.assertLess(max_diff, 1e-4)
|
|
|
|
|
|
|
|
|
|
def test_inference_batch_consistent(self):
|
|
|
|
|
components = self.get_dummy_components()
|
|
|
|
|
pipe = self.pipeline_class(**components)
|
|
|
|
|
pipe.to(torch_device)
|
|
|
|
|
pipe.set_progress_bar_config(disable=None)
|
|
|
|
|
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
|
|
|
|
|
# batchify inputs
|
|
|
|
|
for batch_size in [2, 4, 13]:
|
|
|
|
|
batched_inputs = {}
|
|
|
|
|
for name, value in inputs.items():
|
|
|
|
|
if name in ALLOWED_REQUIRED_ARGS:
|
|
|
|
|
# prompt is string
|
|
|
|
|
if name == "prompt":
|
|
|
|
|
len_prompt = len(value)
|
|
|
|
|
# make unequal batch sizes
|
|
|
|
|
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
|
|
|
|
# or else we have images
|
|
|
|
|
else:
|
|
|
|
|
batched_inputs[name] = batch_size * [value]
|
|
|
|
|
else:
|
|
|
|
|
batched_inputs[name] = value
|
|
|
|
|
|
|
|
|
|
batched_inputs["num_inference_steps"] = 2
|
|
|
|
|
batched_inputs["output_type"] = None
|
|
|
|
|
output = pipe(**batched_inputs)
|
|
|
|
|
|
|
|
|
|
assert len(output[0]) == batch_size
|
|
|
|
|
|
|
|
|
|
batched_inputs["output_type"] = "np"
|
|
|
|
|
output = pipe(**batched_inputs)[0]
|
|
|
|
|
|
|
|
|
|
assert output.shape[0] == batch_size
|
|
|
|
|
|
|
|
|
|
def test_inference_generator_equality(self):
|
|
|
|
|
components = self.get_dummy_components()
|
|
|
|
|
pipe = self.pipeline_class(**components)
|
|
|
|
|
pipe.to(torch_device)
|
|
|
|
|
pipe.set_progress_bar_config(disable=None)
|
|
|
|
|
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device)
|
|
|
|
|
|
|
|
|
|
# batchify inputs
|
|
|
|
|
batched_inputs = {}
|
|
|
|
|
batch_size = 2
|
|
|
|
|
for name, value in inputs.items():
|
|
|
|
|
if name in ALLOWED_REQUIRED_ARGS:
|
|
|
|
|
# prompt is string
|
|
|
|
|
if name == "prompt":
|
|
|
|
|
len_prompt = len(value)
|
|
|
|
|
# make unequal batch sizes
|
|
|
|
|
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
|
|
|
|
# or else we have images
|
|
|
|
|
else:
|
|
|
|
|
batched_inputs[name] = batch_size * [value]
|
|
|
|
|
else:
|
|
|
|
|
batched_inputs[name] = value
|
|
|
|
|
|
|
|
|
|
batched_inputs["num_inference_steps"] = 2
|
|
|
|
|
batched_inputs["output_type"] = "np"
|
|
|
|
|
seeds = [0, 44] # make sure length is equal to batch_size
|
|
|
|
|
generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
|
|
|
|
|
batched_inputs["generator"] = generators
|
|
|
|
|
|
|
|
|
|
output_batch = pipe(**batched_inputs)[0]
|
|
|
|
|
|
|
|
|
|
generators = [torch.Generator(device=torch_device).manual_seed(s) for s in seeds]
|
|
|
|
|
for i, seed in enumerate(seeds):
|
|
|
|
|
inputs = {k: v[i] if isinstance(v, list) else v for k, v in batched_inputs.items()}
|
|
|
|
|
inputs["generator"] = generators[i]
|
|
|
|
|
|
|
|
|
|
output_single = pipe(**inputs)[0]
|
|
|
|
|
|
|
|
|
|
assert np.abs(output_single - output_batch[i]).sum() < 1e-4
|
|
|
|
|
|
|
|
|
|
def test_num_inference_steps_consistent(self):
|
|
|
|
|
components = self.get_dummy_components()
|
|
|
|
|