@@ -19,7 +19,7 @@ import torch
|
|||||||
from huggingface_hub.utils import validate_hf_hub_args
|
from huggingface_hub.utils import validate_hf_hub_args
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
_get_model_file,
|
_get_model_file,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
@@ -182,7 +182,7 @@ class IPAdapterMixin:
|
|||||||
elif key.startswith("ip_adapter."):
|
elif key.startswith("ip_adapter."):
|
||||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||||
else:
|
else:
|
||||||
state_dict = torch.load(model_file, map_location="cpu")
|
state_dict = load_state_dict(model_file)
|
||||||
else:
|
else:
|
||||||
state_dict = pretrained_model_name_or_path_or_dict
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from packaging import version
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
USE_PEFT_BACKEND,
|
USE_PEFT_BACKEND,
|
||||||
_get_model_file,
|
_get_model_file,
|
||||||
@@ -281,7 +281,7 @@ class LoraLoaderMixin:
|
|||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
state_dict = torch.load(model_file, map_location="cpu")
|
state_dict = load_state_dict(model_file)
|
||||||
else:
|
else:
|
||||||
state_dict = pretrained_model_name_or_path_or_dict
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import torch
|
|||||||
from huggingface_hub.utils import validate_hf_hub_args
|
from huggingface_hub.utils import validate_hf_hub_args
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from ..models.modeling_utils import load_state_dict
|
||||||
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
|
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
|
||||||
|
|
||||||
|
|
||||||
@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
|
|||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
state_dict = torch.load(model_file, map_location="cpu")
|
state_dict = load_state_dict(model_file)
|
||||||
else:
|
else:
|
||||||
state_dict = pretrained_model_name_or_path
|
state_dict = pretrained_model_name_or_path
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from ..models.embeddings import (
|
|||||||
IPAdapterPlusImageProjection,
|
IPAdapterPlusImageProjection,
|
||||||
MultiIPAdapterImageProjection,
|
MultiIPAdapterImageProjection,
|
||||||
)
|
)
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
USE_PEFT_BACKEND,
|
USE_PEFT_BACKEND,
|
||||||
_get_model_file,
|
_get_model_file,
|
||||||
@@ -214,7 +214,7 @@ class UNet2DConditionLoadersMixin:
|
|||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
state_dict = torch.load(model_file, map_location="cpu")
|
state_dict = load_state_dict(model_file)
|
||||||
else:
|
else:
|
||||||
state_dict = pretrained_model_name_or_path_or_dict
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
|||||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||||
else:
|
else:
|
||||||
return torch.load(checkpoint_file, map_location="cpu")
|
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
||||||
|
return torch.load(
|
||||||
|
checkpoint_file,
|
||||||
|
map_location="cpu",
|
||||||
|
**weights_only_kwarg,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
try:
|
||||||
with open(checkpoint_file) as f:
|
with open(checkpoint_file) as f:
|
||||||
|
|||||||
Reference in New Issue
Block a user