diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 7db587a11b..06ffd32998 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -5,11 +5,12 @@ import glob
 import inspect
 import itertools
 import math
+import multiprocessing
 import os
 import traceback
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import safetensors
 import torch
@@ -123,10 +124,41 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
         model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant
 
 
-def load_weights(checkpoint_dir: str):
+def prefetch_files(file_names: List[str], mapping: Mapping):
+    """
+    Prefetch safetensors files into memory so that weight loading is much faster.
+    When multiple ranks run in parallel, each rank prefetches a subset of the files.
+    TODO: On systems with limited memory, prefetching may cause file cache thrashing, so we may want to add some
+    heuristics about when to prefetch and when not to.
+    """
+
+    def _prefetch_one_file(file_name, rank):
+        if os.path.exists(file_name):
+            logger.info(f"Rank {rank} prefetching {file_name} to memory...")
+            with open(file_name, 'rb') as f:
+                f.read()
+            logger.info(f"Rank {rank} finished prefetching {file_name}.")
+
+    # Determine which files the current rank should prefetch.
+    # Each rank loads files with indices rank, rank + world_size, rank + 2*world_size, etc.
+    local_file_names = file_names[mapping.rank::mapping.world_size]
+
+    processes = []
+    for file_name in local_file_names:
+        process = multiprocessing.Process(target=_prefetch_one_file,
+                                          args=(file_name, mapping.rank))
+        process.start()
+        processes.append(process)
+
+    for process in processes:
+        process.join()
+
+
+def load_weights(checkpoint_dir: str, mapping: Mapping):
     weights = {}
     weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
     if weight_files:
+        prefetch_files(weight_files, mapping)
         for file in weight_files:
             logger.info(f"Loading {file}")
             part_weights = safetensors.torch.load_file(file)
@@ -890,9 +922,10 @@ class PyTorchModelEngine(ModelEngine):
 
         if load_format == LoadFormat.AUTO:
             if hasattr(model, 'llm_checkpoint_dir'):
-                weights = load_weights(model.llm_checkpoint_dir)
+                weights = load_weights(model.llm_checkpoint_dir,
+                                       self.mapping)
             else:
-                weights = load_weights(checkpoint_dir)
+                weights = load_weights(checkpoint_dir, self.mapping)
             model.load_weights(weights)
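
For reference, a quick sanity check of the striding slice file_names[mapping.rank::mapping.world_size] used in prefetch_files above; it shows how the checkpoint shards are divided across ranks (the file names here are made up for illustration):

    >>> files = [f"model-{i:05d}-of-00005.safetensors" for i in range(5)]
    >>> files[0::2]  # rank 0 with world_size == 2
    ['model-00000-of-00005.safetensors', 'model-00002-of-00005.safetensors', 'model-00004-of-00005.safetensors']
    >>> files[1::2]  # rank 1 with world_size == 2
    ['model-00001-of-00005.safetensors', 'model-00003-of-00005.safetensors']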