Remove torch_dtype in to() to end deprecation (#6886)
* remove torch_dtype from to()
* remove torch_dtype from usage scripts.
* remove old lora backend
* Revert "remove old lora backend"
This reverts commit adcddf6ba4.
This commit is contained in:
parent
4a3d52850b
commit
1835510524
@ -576,6 +576,6 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.half:
|
if args.half:
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
|
|
||||||
pipe.save_pretrained(args.dump_path)
|
pipe.save_pretrained(args.dump_path)
|
||||||
|
|||||||
@ -179,7 +179,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.half:
|
if args.half:
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
|
|
||||||
if args.controlnet:
|
if args.controlnet:
|
||||||
# only save the controlnet model
|
# only save the controlnet model
|
||||||
|
|||||||
@ -801,6 +801,6 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.half:
|
if args.half:
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
|
|
||||||
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
||||||
|
|||||||
@ -775,32 +775,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
[`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
|
[`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
|
||||||
"""
|
"""
|
||||||
|
dtype = kwargs.pop("dtype", None)
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
device = kwargs.pop("device", None)
|
||||||
if torch_dtype is not None:
|
|
||||||
deprecate("torch_dtype", "0.27.0", "")
|
|
||||||
torch_device = kwargs.pop("torch_device", None)
|
|
||||||
if torch_device is not None:
|
|
||||||
deprecate("torch_device", "0.27.0", "")
|
|
||||||
|
|
||||||
dtype_kwarg = kwargs.pop("dtype", None)
|
|
||||||
device_kwarg = kwargs.pop("device", None)
|
|
||||||
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
|
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
|
||||||
|
|
||||||
if torch_dtype is not None and dtype_kwarg is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
|
|
||||||
)
|
|
||||||
|
|
||||||
dtype = torch_dtype or dtype_kwarg
|
|
||||||
|
|
||||||
if torch_device is not None and device_kwarg is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
|
|
||||||
)
|
|
||||||
|
|
||||||
device = torch_device or device_kwarg
|
|
||||||
|
|
||||||
dtype_arg = None
|
dtype_arg = None
|
||||||
device_arg = None
|
device_arg = None
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
@ -873,12 +851,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
|
|
||||||
if is_loaded_in_8bit and dtype is not None:
|
if is_loaded_in_8bit and dtype is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
|
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_loaded_in_8bit and device is not None:
|
if is_loaded_in_8bit and device is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
|
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
module.to(device, dtype)
|
module.to(device, dtype)
|
||||||
|
|||||||
@ -218,7 +218,7 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||||
|
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
|||||||
@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.Tes
|
|||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||||
|
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
|||||||
@ -483,7 +483,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
||||||
|
|
||||||
# Once we send to fp16, all params are in half-precision, including the logit scale
|
# Once we send to fp16, all params are in half-precision, including the logit scale
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
|
||||||
|
|
||||||
|
|||||||
@ -400,7 +400,7 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
||||||
|
|
||||||
# Once we send to fp16, all params are in half-precision, including the logit scale
|
# Once we send to fp16, all params are in half-precision, including the logit scale
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
|
||||||
|
|
||||||
|
|||||||
@ -231,7 +231,7 @@ class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||||
|
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
|||||||
@ -396,7 +396,7 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
|||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||||
|
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
|||||||
@ -1623,7 +1623,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||||||
sd1 = sd.to(torch.float16)
|
sd1 = sd.to(torch.float16)
|
||||||
sd2 = sd.to(None, torch.float16)
|
sd2 = sd.to(None, torch.float16)
|
||||||
sd3 = sd.to(dtype=torch.float16)
|
sd3 = sd.to(dtype=torch.float16)
|
||||||
sd4 = sd.to(torch_dtype=torch.float16)
|
sd4 = sd.to(dtype=torch.float16)
|
||||||
sd5 = sd.to(None, dtype=torch.float16)
|
sd5 = sd.to(None, dtype=torch.float16)
|
||||||
sd6 = sd.to(None, torch_dtype=torch.float16)
|
sd6 = sd.to(None, torch_dtype=torch.float16)
|
||||||
|
|
||||||
|
|||||||
@ -716,7 +716,7 @@ class PipelineTesterMixin:
|
|||||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||||
|
|
||||||
pipe.to(torch_dtype=torch.float16)
|
pipe.to(dtype=torch.float16)
|
||||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user