Fix loading sharded checkpoints when we have variants (#9061)
* Fix loading sharded checkpoint when we have variant * add test * remote print --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -773,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
try:
|
try:
|
||||||
accelerate.load_checkpoint_and_dispatch(
|
accelerate.load_checkpoint_and_dispatch(
|
||||||
model,
|
model,
|
||||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
model_file if not is_sharded else index_file,
|
||||||
device_map,
|
device_map,
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
offload_folder=offload_folder,
|
offload_folder=offload_folder,
|
||||||
@@ -803,7 +803,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||||
accelerate.load_checkpoint_and_dispatch(
|
accelerate.load_checkpoint_and_dispatch(
|
||||||
model,
|
model,
|
||||||
model_file if not is_sharded else sharded_ckpt_cached_folder,
|
model_file if not is_sharded else index_file,
|
||||||
device_map,
|
device_map,
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
offload_folder=offload_folder,
|
offload_folder=offload_folder,
|
||||||
|
|||||||
@@ -1121,6 +1121,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
|||||||
assert loaded_model
|
assert loaded_model
|
||||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_load_sharded_checkpoint_with_variant_from_hub(self):
|
||||||
|
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
loaded_model = self.model_class.from_pretrained(
|
||||||
|
"hf-internal-testing/unet2d-sharded-with-variant-dummy", variant="fp16"
|
||||||
|
)
|
||||||
|
loaded_model = loaded_model.to(torch_device)
|
||||||
|
new_output = loaded_model(**inputs_dict)
|
||||||
|
|
||||||
|
assert loaded_model
|
||||||
|
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||||
|
|
||||||
@require_peft_backend
|
@require_peft_backend
|
||||||
def test_lora(self):
|
def test_lora(self):
|
||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user