Safetensor loading in AnimateDiff conversion scripts (#7764)
* update * update
This commit is contained in:
parent
a38dd79512
commit
eb96ff0d59
@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
@ -34,7 +34,10 @@ def get_args():
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if args.ckpt_path.endswith(".safetensors"):
|
||||
state_dict = load_file(args.ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
@ -38,7 +39,11 @@ def get_args():
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if args.ckpt_path.endswith(".safetensors"):
|
||||
state_dict = load_file(args.ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user