From e4325606db94defcb48f1831d0c86008808349e2 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 7 Aug 2024 01:38:44 +0200 Subject: [PATCH] Fix loading sharded checkpoints when we have variants (#9061) * Fix loading sharded checkpoint when we have variant * add test * remove print --------- Co-authored-by: Sayak Paul --- src/diffusers/models/modeling_utils.py | 4 ++-- tests/models/unets/test_models_unet_2d_condition.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f7324009f3..cfe692dcc5 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -773,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): try: accelerate.load_checkpoint_and_dispatch( model, - model_file if not is_sharded else sharded_ckpt_cached_folder, + model_file if not is_sharded else index_file, device_map, max_memory=max_memory, offload_folder=offload_folder, @@ -803,7 +803,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): model._temp_convert_self_to_deprecated_attention_blocks() accelerate.load_checkpoint_and_dispatch( model, - model_file if not is_sharded else sharded_ckpt_cached_folder, + model_file if not is_sharded else index_file, device_map, max_memory=max_memory, offload_folder=offload_folder, diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 1c688c9e9c..df88e7960b 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1121,6 +1121,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_with_variant_from_hub(self): + _, inputs_dict = 
self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16" + ) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_peft_backend def test_lora(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()