feat: Prefetch safetensors files before loading them (#4140)

Prefetch safetensors files so that they end up in the system file cache.
This significantly speeds up model weight loading on the very first run
after entering the Docker container.

This is beneficial because model weights are loaded layer by layer, which
means the safetensors files are read chunk by chunk. When these files live
on a network drive, such small scattered reads cannot utilize the network
bandwidth well. Reading the whole files in bulk instead achieves much
higher bandwidth utilization.
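
To make the access-pattern argument concrete, here is a minimal sketch
(hypothetical helper names, not part of this change) contrasting the two
ways of reading a checkpoint file from a network mount:

def read_layer_by_layer(path, tensor_ranges):
    # Many small, scattered reads: each tensor is fetched on demand, so the
    # network link sits mostly idle between requests.
    with open(path, 'rb') as f:
        for offset, size in tensor_ranges:
            f.seek(offset)
            f.read(size)

def read_in_bulk(path):
    # One large sequential read: saturates the link and leaves the whole file
    # in the OS page cache, so the real loader's chunked reads hit memory.
    with open(path, 'rb') as f:
        f.read()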

When running with world_size > 1, all ranks collaboratively prefetch these
files.
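
The file list is split round-robin across ranks, each rank taking every
world_size-th file starting at its own rank index. A minimal sketch of the
partitioning (the file names are made up):

files = [f"model-0000{i}-of-00005.safetensors" for i in range(1, 6)]
world_size = 2

# Rank r prefetches files r, r + world_size, r + 2*world_size, ...
for rank in range(world_size):
    print(rank, files[rank::world_size])
# 0 ['model-00001-of-00005.safetensors', 'model-00003-of-00005.safetensors',
#    'model-00005-of-00005.safetensors']
# 1 ['model-00002-of-00005.safetensors', 'model-00004-of-00005.safetensors']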

In theory, we should add heuristics to decide whether or not to prefetch
the files, but that is beyond the scope of this commit.

For example, when CPU memory is small, prefetching may cause file cache
thrashing and actually slow down weight loading.
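
One possible heuristic (hypothetical, not implemented in this commit, and
Linux-only via os.sysconf) would be to prefetch only when the checkpoint
fits comfortably in the currently available memory:

import os

def should_prefetch(file_names, headroom=0.5):
    # Hypothetical heuristic: skip prefetching when the total checkpoint size
    # would not fit comfortably in the memory that is currently free.
    total_bytes = sum(
        os.path.getsize(f) for f in file_names if os.path.exists(f))
    free_bytes = os.sysconf('SC_AVPHYS_PAGES') * os.sysconf('SC_PAGE_SIZE')
    return total_bytes < headroom * free_bytes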

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
nvpohanh 2025-05-13 13:35:30 +08:00 committed by GitHub
parent 24be357964
commit 13c8e5a8a8


@@ -5,11 +5,12 @@ import glob
 import inspect
 import itertools
 import math
+import multiprocessing
 import os
 import traceback
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import safetensors
 import torch
@@ -123,10 +124,41 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
     model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant
 
 
-def load_weights(checkpoint_dir: str):
+def prefetch_files(file_names: List[str], mapping: Mapping):
+    """
+    Prefetch safetensors files to memory so that the weight loading will be much faster.
+    When multiple ranks run in parallel, each rank will prefetch some files.
+
+    TODO: On systems with small memory, prefetching may cause file cache thrashing, so we may want to add some
+    heuristics about when to prefetch and when not to.
+    """
+
+    def _prefetch_one_file(file_name, rank):
+        if os.path.exists(file_name):
+            logger.info(f"Rank {rank} prefetching {file_name} to memory...")
+            with open(file_name, 'rb') as f:
+                f.read()
+            logger.info(f"Rank {rank} finished prefetching {file_name}.")

+    # Find out the files to prefetch for the current rank.
+    # Each rank loads files with indices rank, rank + world_size, rank + 2*world_size, etc.
+    local_file_names = file_names[mapping.rank::mapping.world_size]
+    processes = []
+    for file_name in local_file_names:
+        process = multiprocessing.Process(target=_prefetch_one_file,
+                                          args=(file_name, mapping.rank))
+        process.start()
+        processes.append(process)
+    for process in processes:
+        process.join()
+
+
+def load_weights(checkpoint_dir: str, mapping: Mapping):
     weights = {}
     weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
     if weight_files:
+        prefetch_files(weight_files, mapping)
         for file in weight_files:
             logger.info(f"Loading {file}")
             part_weights = safetensors.torch.load_file(file)
@@ -890,9 +922,10 @@ class PyTorchModelEngine(ModelEngine):
         if load_format == LoadFormat.AUTO:
             if hasattr(model, 'llm_checkpoint_dir'):
-                weights = load_weights(model.llm_checkpoint_dir)
+                weights = load_weights(model.llm_checkpoint_dir,
+                                       self.mapping)
             else:
-                weights = load_weights(checkpoint_dir)
+                weights = load_weights(checkpoint_dir, self.mapping)
 
             model.load_weights(weights)