diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index ceddf70b2c..30048b6590 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -541,6 +541,17 @@ class ModelMixin(torch.nn.Module):
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file)
                 # move the params from meta device to cpu
+                # Before moving anything, verify that the checkpoint provides every parameter
+                # the model expects. Keys are sorted so the error message is deterministic.
+                missing_keys = sorted(set(model.state_dict().keys()) - set(state_dict.keys()))
+                if missing_keys:
+                    raise ValueError(
+                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+                        " those weights or else make sure your checkpoint file is correct."
+                    )
+
                 for param_name, param in state_dict.items():
                     accepts_dtype = "dtype" in set(
                         inspect.signature(set_module_tensor_to_device).parameters.keys()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 26acdd4192..db006790a2 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -21,11 +21,21 @@ from typing import Dict, List, Tuple
 import numpy as np
 import torch
 
-from diffusers.models import ModelMixin
+from diffusers.models import ModelMixin, UNet2DConditionModel
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
 
 
+class ModelUtilsTest(unittest.TestCase):
+    def test_accelerate_loading_error_message(self):
+        """Loading the `stable-diffusion-broken` UNet raises a ValueError naming the missing key."""
+        with self.assertRaises(ValueError) as error_context:
+            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
+
+        # make sure that error message states what keys are missing
+        self.assertIn("conv_out.bias", str(error_context.exception))
+
+
 class ModelTesterMixin:
     def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()