Change fp16 error to warning (#764)
* Swap fp16 error to warning Also remove the associated test * Formatting * warn -> warning * Update src/diffusers/pipeline_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * make style Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -169,10 +169,12 @@ class DiffusionPipeline(ConfigMixin):
|
|||||||
module = getattr(self, name)
|
module = getattr(self, name)
|
||||||
if isinstance(module, torch.nn.Module):
|
if isinstance(module, torch.nn.Module):
|
||||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
|
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
|
||||||
raise ValueError(
|
logger.warning(
|
||||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot be moved to `cpu` or `mps` "
|
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
|
||||||
"due to the lack of support for `float16` operations on those devices in PyTorch. "
|
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
|
||||||
"Please remove the `torch_dtype=torch.float16` argument, or use a `cuda` device."
|
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
|
||||||
|
" `float16` operations on those devices in PyTorch. Please remove the"
|
||||||
|
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
|
||||||
)
|
)
|
||||||
module.to(torch_device)
|
module.to(torch_device)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -247,17 +247,6 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
|
|
||||||
return extract
|
return extract
|
||||||
|
|
||||||
def test_pipeline_fp16_cpu_error(self):
|
|
||||||
model = self.dummy_uncond_unet
|
|
||||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
|
||||||
pipe = DDIMPipeline(model.half(), scheduler)
|
|
||||||
|
|
||||||
if str(torch_device) in ["cpu", "mps"]:
|
|
||||||
self.assertRaises(ValueError, pipe.to, torch_device)
|
|
||||||
else:
|
|
||||||
# moving the pipeline to GPU should work
|
|
||||||
pipe.to(torch_device)
|
|
||||||
|
|
||||||
def test_ddim(self):
|
def test_ddim(self):
|
||||||
unet = self.dummy_uncond_unet
|
unet = self.dummy_uncond_unet
|
||||||
scheduler = DDIMScheduler()
|
scheduler = DDIMScheduler()
|
||||||
|
|||||||
Reference in New Issue
Block a user