From 54bc882d96db1d581f66883cde8f2a5c7eccb919 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 24 Feb 2023 15:19:53 +0100 Subject: [PATCH] `mps` test fixes (#2470) * Skip variant tests (UNet1d, UNetRL) on mps. mish op not yet supported. * Exclude a couple of panorama tests on mps They are too slow for fast CI. * Exclude mps panorama from more tests. * mps: exclude all fast panorama tests as they keep failing. --- tests/models/test_models_unet_1d.py | 8 ++++++++ .../stable_diffusion/test_stable_diffusion_panorama.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index b494c231b5..17dd46496e 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -66,6 +66,10 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): def test_from_save_pretrained(self): super().test_from_save_pretrained() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_save_pretrained_variant(self): + super().test_from_save_pretrained_variant() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_model_from_pretrained(self): super().test_model_from_pretrained() @@ -186,6 +190,10 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): def test_from_save_pretrained(self): super().test_from_save_pretrained() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_save_pretrained_variant(self): + super().test_from_save_pretrained_variant() + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_model_from_pretrained(self): super().test_model_from_pretrained() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 72835953d1..366cd39da5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -30,7 +30,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from ...test_pipelines_common import PipelineTesterMixin @@ -38,6 +38,7 @@ from ...test_pipelines_common import PipelineTesterMixin torch.backends.cuda.matmul.allow_tf32 = False +@skip_mps class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPanoramaPipeline