[Conversion] Support convert diffusers to safetensors (#1996)
fix: support diffusers to safetensors
This commit is contained in:
@@ -8,6 +8,8 @@ import re
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
|
||||||
# =================#
|
# =================#
|
||||||
# UNet Conversion #
|
# UNet Conversion #
|
||||||
@@ -266,6 +268,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
||||||
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||||
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -306,5 +311,9 @@ if __name__ == "__main__":
|
|||||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||||
if args.half:
|
if args.half:
|
||||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||||
state_dict = {"state_dict": state_dict}
|
|
||||||
torch.save(state_dict, args.checkpoint_path)
|
if args.use_safetensors:
|
||||||
|
save_file(state_dict, args.checkpoint_path)
|
||||||
|
else:
|
||||||
|
state_dict = {"state_dict": state_dict}
|
||||||
|
torch.save(state_dict, args.checkpoint_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user