committed by
GitHub
parent
4fc7084875
commit
8267c78445
@@ -541,6 +541,15 @@ class ModelMixin(torch.nn.Module):
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file)
|
||||
# move the params from meta device to cpu
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
||||
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
||||
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
|
||||
for param_name, param in state_dict.items():
|
||||
accepts_dtype = "dtype" in set(
|
||||
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
||||
|
||||
@@ -21,11 +21,20 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.models import ModelMixin, UNet2DConditionModel
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import torch_device
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
def test_accelerate_loading_error_message(self):
|
||||
with self.assertRaises(ValueError) as error_context:
|
||||
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "conv_out.bias" in str(error_context.exception)
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
def test_from_save_pretrained(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user