From f9e6045f392710407b7397d67ec0e5020a5b7f45 Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Date: Tue, 3 Feb 2026 13:23:10 -0800
Subject: [PATCH] [#11086][feat] Optimize Auto Deploy weight loading by
 preloading weights to CPU (#11059)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
---
 .../_torch/auto_deploy/models/factory.py      |  12 +-
 tensorrt_llm/_torch/auto_deploy/models/hf.py  | 115 ++++++++++++++++--
 .../transform/library/load_weights.py         |  10 +-
 .../_torch/auto_deploy/transform/optimizer.py |   5 +
 4 files changed, 128 insertions(+), 14 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py
index ecfc11889b..d04c35e5fe 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/factory.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -263,7 +263,9 @@ class ModelFactory(ABC):
         """
         return model_name_or_path
 
-    def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
+    def load_or_random_init(
+        self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
+    ):
         """Load the checkpoint into the model or randomly initialize the model.
 
         Args:
@@ -271,6 +273,7 @@ class ModelFactory(ABC):
                 the same model that is built above but it needs to have a state dict compatible with
                 the model built above.
             device: The device to load the model on.
+            disable_preload: If True, disable preloading weights to CPU before moving to device.
             load_factoy_model: If True, will load weights for the factory model in addition to main
                 gm. This is useful for the transformers model.
 
@@ -303,7 +306,7 @@ class ModelFactory(ABC):
 
         if not self.skip_loading_weights:
             self.prefetch_checkpoint(force=True)
-            self._load_checkpoint(model, device)
+            self._load_checkpoint(model, device, disable_preload=disable_preload)
 
     @staticmethod
     def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
@@ -323,7 +326,9 @@ class ModelFactory(ABC):
         )
 
     @abstractmethod
-    def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
+    def _load_checkpoint(
+        self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
+    ):
         """Load the checkpoint into the model.
 
         Args:
@@ -331,6 +336,7 @@ class ModelFactory(ABC):
                 the same model that is built above but it needs to have a state dict compatible with
                 the model built above.
             device: The device to load the model on.
+            disable_preload: If True, disable preloading weights to CPU before moving to device.
""" def get_example_inputs(self) -> Dict[str, torch.Tensor]: diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index fad26fa6e2..4bd525df85 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -1,5 +1,6 @@ """Interface to initialize and load HF models.""" +import json import os import re import types @@ -7,6 +8,7 @@ from abc import abstractmethod from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Optional, Tuple, Type, Union +import safetensors.torch import torch import torch.nn as nn from accelerate import init_empty_weights, load_checkpoint_in_model @@ -419,7 +421,9 @@ class AutoModelForCausalLMFactory(AutoModelFactory): return fetched_dir - def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType): + def _load_checkpoint( + self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False + ): """Load the checkpoint into the model.""" # identify the most relevant checkpoint file ckpt_file = self._get_checkpoint_file(self.model) @@ -434,20 +438,111 @@ class AutoModelForCausalLMFactory(AutoModelFactory): # Ensure it's the first one. model._state_dict_hooks.move_to_end(key=get_handle.id, last=False) - # reuse the load checkpoint utility from accelerate try: - with hf_load_state_dict_with_device(device): - # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic. - # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict, - # which collects local model params, syncs weights from checkpoint, and applies them via - # model.load_state_dict. - # This sync step can interfere with load_hooks by mixing raw checkpoint weights and - # model-transformed weights,leading to unexpected key mismatches or format issues. - load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False) + if disable_preload: + # Load checkpoint directly to GPU using accelerate's load_checkpoint_in_model (no CPU preload) + ad_logger.info( + "disable_preload=True: Using accelerate's load_checkpoint_in_model (no CPU preload)" + ) + with hf_load_state_dict_with_device(device): + load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False) + else: + # Preload checkpoint files to CPU + ad_logger.info("Preloading checkpoint files to CPU") + self._load_checkpoint_with_preload(model, ckpt_file, device) finally: load_handle.remove() get_handle.remove() + def _load_checkpoint_with_preload( + self, model: nn.Module, ckpt_file: str, device: DeviceLikeType + ): + all_weights = self._load_full_checkpoint_to_cpu(ckpt_file) + + ad_logger.info(f"Loading weights into model (device: {device})...") + model.load_state_dict(all_weights, strict=False) + + ad_logger.info("Checkpoint loading completed") + + def _load_full_checkpoint_to_cpu(self, checkpoint: str) -> dict: + """Load the full checkpoint to CPU memory. 
+
+        Args:
+            checkpoint: Can be:
+                - a path to a file containing a whole model state dict
+                - a path to a `.json` file containing the index to a sharded checkpoint
+                - a path to a folder containing a unique `.index.json` file and the shards
+                - a path to a folder containing a unique pytorch_model.bin or model.safetensors
+        """
+        checkpoint_files = None
+        index_filename = None
+
+        # Fast path: Direct .index.json file (most common case for sharded checkpoints)
+        if os.path.isfile(checkpoint):
+            if checkpoint.endswith(".index.json"):
+                index_filename = checkpoint
+            else:
+                checkpoint_files = [checkpoint]
+        elif os.path.isdir(checkpoint):
+            # Check if the whole state dict is present (priority order matches accelerate)
+            potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
+            potential_state_safetensor = [
+                f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
+            ]
+
+            # Case 1: pytorch_model.bin (WEIGHTS_NAME)
+            if len(potential_state_bin) == 1:
+                checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
+            # Case 2: model.safetensors (SAFE_WEIGHTS_NAME)
+            elif len(potential_state_safetensor) == 1:
+                checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
+            else:
+                # Case 3: Otherwise check for sharded checkpoints
+                potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
+                if len(potential_index) == 0:
+                    raise ValueError(
+                        f"{checkpoint} is not a folder containing a `.index.json` file or a "
+                        f"{WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
+                    )
+                elif len(potential_index) == 1:
+                    index_filename = os.path.join(checkpoint, potential_index[0])
+                else:
+                    raise ValueError(
+                        f"{checkpoint} contains more than one `.index.json` file; delete the irrelevant ones."
+                    )
+        else:
+            raise ValueError(
+                f"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
+                f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got "
+                f"{checkpoint}."
+            )
+
+        # Load checkpoint files from index if needed
+        if index_filename is not None:
+            checkpoint_folder = os.path.dirname(index_filename)
+            with open(index_filename, "r") as f:
+                index = json.load(f)
+
+            if "weight_map" in index:
+                index = index["weight_map"]
+            checkpoint_files = list(set(index.values()))
+            checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
+
+        # Load all weights
+        all_weights = {}
+        for checkpoint_file in checkpoint_files:
+            ad_logger.info(f"Loading weight file: {checkpoint_file}")
+            if checkpoint_file.endswith(".safetensors"):
+                file_weights = safetensors.torch.load_file(checkpoint_file, device="cpu")
+            elif checkpoint_file.endswith((".bin", ".pth")):
+                file_weights = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
+            else:
+                raise ValueError(f"Unsupported checkpoint format: {checkpoint_file}")
+
+            all_weights.update(file_weights)
+
+        return all_weights
+
     def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
         if self._quant_config_reader is not None:
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
index 0579543054..77ef3fc6f1 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py
@@ -25,6 +25,10 @@ class MoveDeviceConfig(TransformConfig):
         default=None,
         description="Optional device to init checkpoint before move to shared_config.local_device.",
     )
+    disable_preload: bool = Field(
+        default=False,
+        description="If True, disable preloading weights to CPU before moving to device.",
+    )
 
 
 @TransformRegistry.register("load_weights")
@@ -48,7 +52,11 @@ class LoadWeightsToDevice(BaseTransform):
         total_size_GB = bytes_to(params_size, unit="GB")
         self._log_info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
 
-        factory.load_or_random_init(mod, device=self.config.checkpoint_device or cm.device)
+        factory.load_or_random_init(
+            mod,
+            device=self.config.checkpoint_device or cm.device,
+            disable_preload=self.config.disable_preload,
+        )
         move_to_device(mod, cm.device)
 
         info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
index ce43730928..397d83494c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
@@ -1,6 +1,7 @@
 """High-level entrypoint to transform a model into an efficient inference model."""
 
 import gc
+import time
 from typing import Optional
 
 import torch
@@ -10,6 +11,7 @@ import torch.nn as nn
 from ..distributed import common as dist_ad
 from ..models.factory import ModelFactory
 from ..shim.interface import CachedSequenceInterface
+from ..utils.logger import ad_logger
 from .interface import (
     InferenceOptimizerConfig,
     SharedConfig,
@@ -64,11 +66,14 @@ class InferenceOptimizer:
         mod = nn.Module()
 
         # iterate over all transforms sorted by stage in the config
+        start_time = time.time()
         for idx, (t_name, t_config) in enumerate(self.config.items()):
             # instantiate transform
             transform = TransformRegistry.get(t_name)(t_config)
             # run transform
             mod = transform(mod, cm, self.factory, self.shared_config, idx)
+        total_time = time.time() - start_time
+        ad_logger.info(f"Total time for all transforms: {total_time:.2f}s")
 
         ############################################################################################
         # RETURN OPTIMIZED MODEL
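
For readers who want to try the preload strategy outside AutoDeploy, the sketch below distills the logic added in `_load_full_checkpoint_to_cpu` and `_load_checkpoint_with_preload` into a standalone helper. It is a minimal sketch, not the shipped implementation: the helper name, the hard-coded `model.safetensors.index.json` filename, and the toy model at the bottom are illustrative assumptions, and only the single-file and safetensors-index layouts are handled.

# Minimal sketch (illustrative names, not the TensorRT-LLM API): stage every
# shard of a checkpoint on CPU, then apply it with a single load_state_dict,
# mirroring the preload path added in hf.py. Assumes a standard HF layout where
# a sharded checkpoint ships a "weight_map" in model.safetensors.index.json.
import json
import os

import safetensors.torch
import torch
import torch.nn as nn


def load_full_checkpoint_to_cpu(checkpoint: str) -> dict:
    """Resolve shard files for `checkpoint` and load all tensors into CPU memory."""
    if os.path.isfile(checkpoint) and not checkpoint.endswith(".index.json"):
        files = [checkpoint]
    else:
        index_file = (
            checkpoint
            if checkpoint.endswith(".index.json")
            else os.path.join(checkpoint, "model.safetensors.index.json")
        )
        with open(index_file) as f:
            weight_map = json.load(f)["weight_map"]
        folder = os.path.dirname(index_file)
        files = [os.path.join(folder, name) for name in sorted(set(weight_map.values()))]

    state_dict = {}
    for file in files:
        if file.endswith(".safetensors"):
            state_dict.update(safetensors.torch.load_file(file, device="cpu"))
        else:  # .bin / .pth checkpoints
            state_dict.update(torch.load(file, map_location="cpu", weights_only=True))
    return state_dict


if __name__ == "__main__":
    # Toy round trip through a single-file safetensors checkpoint.
    model = nn.Linear(4, 4)
    safetensors.torch.save_file(model.state_dict(), "toy.safetensors")
    weights = load_full_checkpoint_to_cpu("toy.safetensors")
    model.load_state_dict(weights, strict=False)  # mirrors _load_checkpoint_with_preload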
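
To make the plumbing of the new flag easy to follow end to end, the toy mock below (a hypothetical class, not the real `ModelFactory` or `AutoModelForCausalLMFactory`) traces how `disable_preload`, set on the `load_weights` transform config, reaches `load_or_random_init` and then `_load_checkpoint`, where it selects between accelerate's direct-to-device path and the new CPU-preload path.

# Control-flow mock only: MockFactory is hypothetical and stands in for the
# factory classes touched by this patch; real loading calls are replaced by
# print statements so the snippet runs anywhere.
import torch.nn as nn


class MockFactory:
    def load_or_random_init(self, model: nn.Module, device: str, disable_preload: bool = False):
        # In the patch, LoadWeightsToDevice passes self.config.disable_preload here.
        self._load_checkpoint(model, device, disable_preload=disable_preload)

    def _load_checkpoint(self, model: nn.Module, device: str, disable_preload: bool = False):
        if disable_preload:
            # Patch behavior: accelerate's load_checkpoint_in_model under
            # hf_load_state_dict_with_device(device), with no CPU staging.
            print(f"direct path: stream checkpoint straight to {device}")
        else:
            # Patch behavior (default): stage the full state dict on CPU first,
            # then apply it with a single model.load_state_dict call.
            print("preload path: stage full state dict on CPU, then load_state_dict")


model = nn.Linear(2, 2)
MockFactory().load_or_random_init(model, device="cuda")                        # default: CPU preload
MockFactory().load_or_random_init(model, device="cuda", disable_preload=True)  # opt out via config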