[None][feat] Skip prefetching consolidated safetensors when appropriate (#7225)

* Why?

Some models (e.g., anything produced by Mistral) can ship both sharded
safetensors and a consolidated safetensors file in the same checkpoint
directory. In such cases, prefetching both into memory wastes both time
and memory.

* What?

This commit skips prefetching consolidated safetensors files whenever they
are not the only safetensors files present in the checkpoint directory.
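
Condensed, the selection rule amounts to the sketch below: drop every
*.safetensors file whose basename contains "consolidated", unless doing so
would leave nothing to load. The helper name is hypothetical; the real change
lives in HfWeightLoader.load_weights in the diff further down.

import glob
import os


def select_weight_files(checkpoint_dir: str) -> list[str]:
    # Gather every safetensors file in the checkpoint directory.
    weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
    # Prefer the sharded files over any consolidated one.
    filtered = [f for f in weight_files if "consolidated" not in os.path.basename(f)]
    # Fall back to the consolidated file if it is the only one present.
    return filtered or weight_files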

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
4 changed files with 95 additions and 0 deletions


@@ -33,6 +33,7 @@ extend_skip_glob = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]
[tool.yapf]
@@ -63,6 +64,7 @@ ignore_patterns = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]
[tool.codespell]
@@ -97,6 +99,7 @@ exclude = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]
@@ -140,6 +143,7 @@ include = [
"tensorrt_llm/top_model_mixin.py",
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
]
exclude = [
"**3rdparty/**",


@@ -26,6 +26,14 @@ class HfWeightLoader(BaseWeightLoader):
    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
        weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
        # Some model checkpoint directories contain not only the sharded safetensors, but one
        # consolidated tensor. In the presence of both, we favor the former, as there really is no need
        # to prefetch the (usually) ridiculously large consolidated tensor into memory in such a case.
        filtered_weight_files = [
            x for x in weight_files if "consolidated" not in os.path.split(x)[1]
        ]
        if len(filtered_weight_files) > 0:
            weight_files = filtered_weight_files
        if weight_files:
            # Prefetch the weight files to CPU memory if the size is less than 90% of the available memory.
            # This is a heuristic to avoid prefetching files that are too large and causing file cache thrashing.
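
The prefetch threshold mentioned in the trailing comment above is not part of
this hunk. As a rough, assumed sketch of the kind of guard it describes
(compare the total file size against 90% of available memory; psutil is used
here only for illustration and is not implied by the diff):

import os

import psutil


def fits_in_available_memory(weight_files: list[str], threshold: float = 0.9) -> bool:
    # Sum the on-disk sizes of the candidate files and check whether they fit
    # within the given fraction of currently available system memory.
    total_size = sum(os.path.getsize(f) for f in weight_files)
    return total_size < threshold * psutil.virtual_memory().available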


@@ -16,6 +16,9 @@ l0_a10:
# ------------- PyTorch tests ---------------
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]


@@ -0,0 +1,80 @@
from unittest import mock

import pytest

from tensorrt_llm._torch.models.checkpoints import HfWeightLoader


class MyError(Exception):
    pass


@pytest.mark.parametrize(
    "dir_name, safetensor_filenames, expected_safetensor_filenames",
    [
        (
            "foo",
            [
                "model-00001-of-00002.safetensors",
                "model-000002-of-00002.safetensors",
                "consolidated.safetensors",
            ],
            ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
        ),
        (
            "foo",
            [
                *(f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)),
                "foo-consolidated.safetensors",
            ],
            [f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)],
        ),
        # If there is only a consolidated safetensor, that one should still be used.
        (
            "foo",
            ["consolidated.safetensors"],
            ["consolidated.safetensors"],
        ),
        # If the directory contains "consolidated" in its name, but its contents are sharded tensors.
        (
            "consolidated-model",
            [
                "model-00001-of-00002.safetensors",
                "model-000002-of-00002.safetensors",
                "consolidated.safetensors",
            ],
            ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
        ),
    ],
)
def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
    tmp_path,
    dir_name: str,
    safetensor_filenames: list[str],
    expected_safetensor_filenames: list[str],
):
    checkpoint_dir = tmp_path / dir_name
    checkpoint_dir.mkdir()
    for filename in safetensor_filenames:
        (checkpoint_dir / filename).touch()
    expected_safetensor_filenames = set(
        str(checkpoint_dir / filename) for filename in expected_safetensor_filenames
    )

    loader = HfWeightLoader()
    with (
        mock.patch.object(
            loader, "_load_weights_in_parallel", side_effect=MyError
        ) as load_weights_in_parallel,
        mock.patch.object(loader, "prefetch_files") as prefetch_files,
        pytest.raises(MyError),
    ):
        loader.load_weights(checkpoint_dir=str(checkpoint_dir))

    prefetch_files.assert_called_once()
    prefetched_files = prefetch_files.call_args[0][0]
    assert set(prefetched_files) == expected_safetensor_filenames

    load_weights_in_parallel.assert_called_once()
    loaded_weight_files = load_weights_in_parallel.call_args[0][0]
    assert set(loaded_weight_files) == expected_safetensor_filenames
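
The new test is CPU-only (per the test-list comment above) and exercises only
the file-selection logic, with prefetching and the parallel load mocked out.
Assuming a standard development checkout, it should run locally with something
like:

pytest tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py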