Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
feat: Prefetch safetensors files before loading them (#4140)
Prefetch safetensors files so that they are stored in the system file cache. This significantly speeds up model weight loading for the very first run after entering the Docker container.

Model weights are loaded layer by layer, which means the safetensors files are read chunk by chunk. When those files are stored on a network drive, the small reads cannot utilize the network bandwidth well. Reading the whole files in bulk beforehand achieves much higher bandwidth utilization. When running with world_size > 1, all ranks prefetch these files collaboratively.

In theory, we should add heuristics to decide whether to prefetch the files at all, but that is beyond the scope of this commit. For example, when CPU memory is small, prefetching may cause file cache thrashing and actually slow down weight loading.

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
parent 24be357964
commit 13c8e5a8a8
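
A minimal sketch of the round-robin file assignment described in the message above. The shard names and world size are made up for illustration; the actual code slices with the rank and world_size fields of TensorRT-LLM's Mapping:

    # Hypothetical shard names; the commit globs *.safetensors from the checkpoint dir.
    file_names = [f"model-{i:05d}-of-00007.safetensors" for i in range(7)]
    world_size = 4  # assumed number of ranks

    for rank in range(world_size):
        # Rank r prefetches files r, r + world_size, r + 2*world_size, ...
        print(rank, file_names[rank::world_size])
    # rank 0 gets shards 0 and 4, rank 1 gets 1 and 5, rank 2 gets 2 and 6, rank 3 gets 3.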
@@ -5,11 +5,12 @@ import glob
 import inspect
 import itertools
 import math
+import multiprocessing
 import os
 import traceback
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import safetensors
 import torch
@@ -123,10 +124,41 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
     model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant
 
 
-def load_weights(checkpoint_dir: str):
+def prefetch_files(file_names: List[str], mapping: Mapping):
+    """
+    Prefetch safetensors files to memory so that the weight loading will be much faster.
+    When multiple ranks run in parallel, each rank will prefetch some files.
+    TODO: On systems with small memory, prefetching may cause file cache thrashing, so we may want to add some
+    heuristics about when to prefetch and when not to.
+    """
+
+    def _prefetch_one_file(file_name, rank):
+        if os.path.exists(file_name):
+            logger.info(f"Rank {rank} prefetching {file_name} to memory...")
+            with open(file_name, 'rb') as f:
+                f.read()
+            logger.info(f"Rank {rank} finished prefetching {file_name}.")
+
+    # Find out the files to prefetch for the current rank.
+    # Each rank loads files with indices rank, rank + world_size, rank + 2*world_size, etc.
+    local_file_names = file_names[mapping.rank::mapping.world_size]
+
+    processes = []
+    for file_name in local_file_names:
+        process = multiprocessing.Process(target=_prefetch_one_file,
+                                          args=(file_name, mapping.rank))
+        process.start()
+        processes.append(process)
+
+    for process in processes:
+        process.join()
+
+
+def load_weights(checkpoint_dir: str, mapping: Mapping):
     weights = {}
     weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
+    if weight_files:
+        prefetch_files(weight_files, mapping)
     for file in weight_files:
         logger.info(f"Loading {file}")
         part_weights = safetensors.torch.load_file(file)
@@ -890,9 +922,10 @@ class PyTorchModelEngine(ModelEngine):
 
         if load_format == LoadFormat.AUTO:
             if hasattr(model, 'llm_checkpoint_dir'):
-                weights = load_weights(model.llm_checkpoint_dir)
+                weights = load_weights(model.llm_checkpoint_dir,
+                                       self.mapping)
             else:
-                weights = load_weights(checkpoint_dir)
+                weights = load_weights(checkpoint_dir, self.mapping)
 
         model.load_weights(weights)
 
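
As a closing note, the file-cache effect the commit relies on is easy to observe with a small timing sketch. The shard name below is hypothetical and the sketch is not part of the commit; it only shows that a second bulk read of the same file is served from the OS file cache and finishes far faster than the first when the file lives on a slow or remote drive:

    import time

    def timed_read(path: str) -> float:
        # Bulk-read the whole file and return the elapsed time in seconds.
        start = time.perf_counter()
        with open(path, "rb") as f:
            f.read()  # sequential read; leaves the pages in the OS file cache
        return time.perf_counter() - start

    path = "model-00001-of-00007.safetensors"  # hypothetical large shard
    print(f"cold read: {timed_read(path):.2f}s")  # limited by the backing store
    print(f"warm read: {timed_read(path):.2f}s")  # served from the file cache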