update
This commit is contained in:
@@ -354,8 +354,9 @@ def _load_shard_file(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
@@ -401,6 +402,7 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
# Do not spawn anymore workers than you need
|
||||
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
|
||||
@@ -427,6 +429,7 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
|
||||
@@ -1298,6 +1298,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
dduf_entries=dduf_entries,
|
||||
is_parallel_loading_enabled=is_parallel_loading_enabled,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
@@ -1584,6 +1585,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
is_parallel_loading_enabled: Optional[bool] = False,
|
||||
disable_mmap: bool = False,
|
||||
):
|
||||
model_state_dict = model.state_dict()
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
@@ -1652,6 +1654,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if is_parallel_loading_enabled:
|
||||
|
||||
Reference in New Issue
Block a user