Disable test_ddpm_ddim_equality_batched until resolved (#142)
disable test_ddpm_ddim_equality_batched
This commit is contained in:
@@ -894,10 +894,10 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
|
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
|
||||||
|
|
||||||
# the values aren't exactly equal, but the images look the same upon visual inspection
|
# the values aren't exactly equal, but the images look the same visually
|
||||||
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
|
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
|
||||||
|
|
||||||
@slow
|
@unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
|
||||||
def test_ddpm_ddim_equality_batched(self):
|
def test_ddpm_ddim_equality_batched(self):
|
||||||
model_id = "google/ddpm-cifar10-32"
|
model_id = "google/ddpm-cifar10-32"
|
||||||
|
|
||||||
@@ -909,12 +909,12 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||||||
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
|
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy")["sample"]
|
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
generator = torch.manual_seed(0)
|
||||||
ddim_images = ddim(batch_size=2, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
|
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
|
||||||
"sample"
|
"sample"
|
||||||
]
|
]
|
||||||
|
|
||||||
# the values aren't exactly equal, but the images look the same upon visual inspection
|
# the values aren't exactly equal, but the images look the same visually
|
||||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||||
|
|||||||
Reference in New Issue
Block a user