Change fp16 error to warning (#764)

* Swap fp16 error to warning

Also remove the associated test

* Formatting

* warn -> warning

* Update src/diffusers/pipeline_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: apolinario
Date: 2022-10-07 10:31:52 +02:00
Committed by: GitHub
Parent: d3f1a4c0f0
Commit: fdfa7c8f15
2 changed files with 6 additions and 15 deletions
src/diffusers/pipeline_utils.py  +6 -4
@@ -169,10 +169,12 @@ class DiffusionPipeline(ConfigMixin):
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
                 if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
-                    raise ValueError(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
-                        "due to the lack of support for `float16` operations on those devices in PyTorch. "
-                        "Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
+                    logger.warning(
+                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
+                        " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
+                        " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
+                        " `float16` operations on those devices in PyTorch. Please remove the"
+                        " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
                     )
                 module.to(torch_device)
         return self
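Behaviourally, calling `pipe.to("cpu")` (or `"mps"`) on a pipeline loaded with `torch_dtype=torch.float16` now logs the warning above and still moves the modules, instead of raising `ValueError`. A minimal sketch of the new behaviour; the tiny `UNet2DModel` configuration below is an illustrative assumption (modelled on the test suite's dummy UNet), not something taken from this diff:

import torch
from diffusers import DDIMPipeline, DDPMScheduler, UNet2DModel

# Illustrative tiny UNet; this exact config is an assumption, not part of the commit.
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(num_train_timesteps=10)
pipe = DDIMPipeline(unet.half(), scheduler)

# Before this commit: raised ValueError.
# After this commit: logs the warning and returns the pipeline with its modules on CPU.
pipe.to("cpu")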
tests (PipelineFastTests)  -11
@@ -247,17 +247,6 @@ class PipelineFastTests(unittest.TestCase):
         return extract
 
-    def test_pipeline_fp16_cpu_error(self):
-        model = self.dummy_uncond_unet
-        scheduler = DDPMScheduler(num_train_timesteps=10)
-        pipe = DDIMPipeline(model.half(), scheduler)
-
-        if str(torch_device) in ["cpu", "mps"]:
-            self.assertRaises(ValueError, pipe.to, torch_device)
-        else:
-            # moving the pipeline to GPU should work
-            pipe.to(torch_device)
-
     def test_ddim(self):
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
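The deleted test asserted the old `ValueError`, which is no longer raised, so it is removed rather than rewritten in this commit. For illustration only, a warning-capture check along these lines could provide equivalent coverage; it assumes the warning is emitted on the `diffusers.pipeline_utils` logger and reuses the same dummy fixtures as the removed test:

import logging

def test_pipeline_fp16_cpu_warning(self):
    # Hypothetical replacement test -- not included in this commit.
    model = self.dummy_uncond_unet
    scheduler = DDPMScheduler(num_train_timesteps=10)
    pipe = DDIMPipeline(model.half(), scheduler)

    if str(torch_device) in ["cpu", "mps"]:
        # the move now succeeds but should log the fp16 warning
        with self.assertLogs("diffusers.pipeline_utils", level=logging.WARNING) as logs:
            pipe.to(torch_device)
        self.assertTrue(any("torch_dtype=torch.float16" in line for line in logs.output))
    else:
        # moving the pipeline to GPU should still work
        pipe.to(torch_device)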