[Loading] Better error message on missing keys (#2198)

* up

* finish
This commit is contained in:
Patrick von Platen
2023-02-01 15:22:39 +02:00
committed by GitHub
parent 4fc7084875
commit 8267c78445
2 changed files with 19 additions and 1 deletions
+9
View File
@@ -541,6 +541,15 @@ class ModelMixin(torch.nn.Module):
param_device = "cpu" param_device = "cpu"
state_dict = load_state_dict(model_file) state_dict = load_state_dict(model_file)
# move the params from meta device to cpu # move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
" those weights or else make sure your checkpoint file is correct."
)
for param_name, param in state_dict.items(): for param_name, param in state_dict.items():
accepts_dtype = "dtype" in set( accepts_dtype = "dtype" in set(
inspect.signature(set_module_tensor_to_device).parameters.keys() inspect.signature(set_module_tensor_to_device).parameters.keys()
+10 -1
View File
@@ -21,11 +21,20 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from diffusers.models import ModelMixin from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device from diffusers.utils import torch_device
class ModelUtilsTest(unittest.TestCase):
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
class ModelTesterMixin: class ModelTesterMixin:
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()