From 0e975e5ff6b48d4a2f236059cabc28212d1f3733 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Feb 2023 17:24:40 +0200 Subject: [PATCH] [Safetensors] Make sure metadata is saved (#2506) * [Safetensors] Make sure metadata is saved * make style --- src/diffusers/models/modeling_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 913cff66c4..4108335da4 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -291,9 +291,6 @@ class ModelMixin(torch.nn.Module): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return - if save_function is None: - save_function = safetensors.torch.save_file if safe_serialization else torch.save - os.makedirs(save_directory, exist_ok=True) model_to_save = self @@ -310,7 +307,12 @@ class ModelMixin(torch.nn.Module): weights_name = _add_variant(weights_name, variant) # Save the model - save_function(state_dict, os.path.join(save_directory, weights_name)) + if safe_serialization: + safetensors.torch.save_file( + state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"} + ) + else: + torch.save(state_dict, os.path.join(save_directory, weights_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")