mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Skip prefetching consolidated safetensors when appropriate (#7225)
* Why?

Some models (e.g. anything produced by Mistral) can have both sharded safetensors and a consolidated safetensor in the same checkpoint directory. In such cases, prefetching both into memory wastes both time and memory.

* What?

This commit skips over consolidated safetensors when they are not the only safetensor file present in the checkpoint directory.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
parent 85b4ae26b7
commit 34c1e9c341
@@ -33,6 +33,7 @@ extend_skip_glob = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
 ]
 
 [tool.yapf]
@@ -63,6 +64,7 @@ ignore_patterns = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
 ]
 
 [tool.codespell]
@@ -97,6 +99,7 @@ exclude = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
 ]
 
 
@@ -140,6 +143,7 @@ include = [
     "tensorrt_llm/top_model_mixin.py",
     "tests/unittest/_torch/modeling/test_modeling_mistral.py",
     "tests/unittest/_torch/modeling/test_modeling_pixtral.py",
+    "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
 ]
 exclude = [
     "**3rdparty/**",
@@ -26,6 +26,14 @@ class HfWeightLoader(BaseWeightLoader):
 
     def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
         weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
+        # Some model checkpoint directories contain not only the sharded safetensors, but one
+        # consolidated tensor. In the presence of both, we favor the former, as there really is no need
+        # to prefetch the (usually) ridiculously large consolidated tensor into memory in such a case.
+        filtered_weight_files = [
+            x for x in weight_files if "consolidated" not in os.path.split(x)[1]
+        ]
+        if len(filtered_weight_files) > 0:
+            weight_files = filtered_weight_files
         if weight_files:
             # Prefetch the weight files to CPU memory if the size is less than 90% of the available memory.
             # This is a heuristic to avoid prefetching files that are too large and causing file cache thrashing.
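For illustration, a minimal standalone sketch of the same heuristic; the filter_consolidated helper and the example paths are hypothetical names for this note, not part of the commit:

import os

def filter_consolidated(weight_files: list[str]) -> list[str]:
    # Drop any file whose basename contains "consolidated", mirroring the
    # filter added to HfWeightLoader.load_weights above. Matching on the
    # basename (os.path.split(x)[1]) keeps files inside directories that
    # merely have "consolidated" in their name from being dropped.
    filtered = [x for x in weight_files if "consolidated" not in os.path.split(x)[1]]
    # If filtering would remove every file (i.e. the consolidated safetensor
    # is the only one present), fall back to the unfiltered list.
    return filtered if filtered else weight_files

# Sharded and consolidated files together: the consolidated one is skipped.
assert filter_consolidated(
    ["ckpt/model-00001-of-00002.safetensors", "ckpt/consolidated.safetensors"]
) == ["ckpt/model-00001-of-00002.safetensors"]

# A lone consolidated file is still used.
assert filter_consolidated(["ckpt/consolidated.safetensors"]) == [
    "ckpt/consolidated.safetensors"
]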
@@ -16,6 +16,9 @@ l0_a10:
   # ------------- PyTorch tests ---------------
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
+  # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
+  # test list either).
+  - unittest/_torch/models/checkpoints/hf/test_weight_loader.py
   - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
@@ -0,0 +1,80 @@
+from unittest import mock
+
+import pytest
+
+from tensorrt_llm._torch.models.checkpoints import HfWeightLoader
+
+
+class MyError(Exception):
+    pass
+
+
+@pytest.mark.parametrize(
+    "dir_name, safetensor_filenames, expected_safetensor_filenames",
+    [
+        (
+            "foo",
+            [
+                "model-00001-of-00002.safetensors",
+                "model-000002-of-00002.safetensors",
+                "consolidated.safetensors",
+            ],
+            ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
+        ),
+        (
+            "foo",
+            [
+                *(f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)),
+                "foo-consolidated.safetensors",
+            ],
+            [f"model-0000{i}-of-00010.safetensors" for i in range(1, 11)],
+        ),
+        # If there is only a consolidated safetensor, that one should still be used.
+        (
+            "foo",
+            ["consolidated.safetensors"],
+            ["consolidated.safetensors"],
+        ),
+        # If the directory contains "consolidated" in its name, but its contents are sharded tensors.
+        (
+            "consolidated-model",
+            [
+                "model-00001-of-00002.safetensors",
+                "model-000002-of-00002.safetensors",
+                "consolidated.safetensors",
+            ],
+            ["model-00001-of-00002.safetensors", "model-000002-of-00002.safetensors"],
+        ),
+    ],
+)
+def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
+    tmp_path,
+    dir_name: str,
+    safetensor_filenames: list[str],
+    expected_safetensor_filenames: list[str],
+):
+    checkpoint_dir = tmp_path / dir_name
+    checkpoint_dir.mkdir()
+    for filename in safetensor_filenames:
+        (checkpoint_dir / filename).touch()
+    expected_safetensor_filenames = set(
+        str(checkpoint_dir / filename) for filename in expected_safetensor_filenames
+    )
+
+    loader = HfWeightLoader()
+    with (
+        mock.patch.object(
+            loader, "_load_weights_in_parallel", side_effect=MyError
+        ) as load_weights_in_parallel,
+        mock.patch.object(loader, "prefetch_files") as prefetch_files,
+        pytest.raises(MyError),
+    ):
+        loader.load_weights(checkpoint_dir=str(checkpoint_dir))
+
+    prefetch_files.assert_called_once()
+    prefetched_files = prefetch_files.call_args[0][0]
+    assert set(prefetched_files) == expected_safetensor_filenames
+
+    load_weights_in_parallel.assert_called_once()
+    loaded_weight_files = load_weights_in_parallel.call_args[0][0]
+    assert set(loaded_weight_files) == expected_safetensor_filenames
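The test avoids any real I/O: prefetch_files is mocked out, and _load_weights_in_parallel is given side_effect=MyError so that load_weights raises immediately after selecting files, letting the assertions inspect exactly which paths each stage received. For completeness, a hypothetical direct invocation (the checkpoint path is made up for this note):

from tensorrt_llm._torch.models.checkpoints import HfWeightLoader

loader = HfWeightLoader()
# With this commit, any consolidated safetensor in the directory is skipped
# as long as sharded safetensors are present alongside it.
weights = loader.load_weights(checkpoint_dir="/path/to/mistral-checkpoint")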