Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
parent f9c4bdf6cf
commit f9e6045f39
@@ -263,7 +263,9 @@ class ModelFactory(ABC):
         """
         return model_name_or_path

-    def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
+    def load_or_random_init(
+        self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
+    ):
         """Load the checkpoint into the model or randomly initialize the model.

         Args:
@@ -271,6 +273,7 @@ class ModelFactory(ABC):
                 the same model that is built above but it needs to have a state dict compatible with
                 the model built above.
             device: The device to load the model on.
+            disable_preload: If True, disable preloading weights to CPU before moving to device.
             load_factoy_model: If True, will load weights for the factory model in addition to main
                 gm. This is useful for the transformers model.

@@ -303,7 +306,7 @@ class ModelFactory(ABC):

         if not self.skip_loading_weights:
             self.prefetch_checkpoint(force=True)
-            self._load_checkpoint(model, device)
+            self._load_checkpoint(model, device, disable_preload=disable_preload)

     @staticmethod
     def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
@@ -323,7 +326,9 @@ class ModelFactory(ABC):
         )

     @abstractmethod
-    def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
+    def _load_checkpoint(
+        self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
+    ):
         """Load the checkpoint into the model.

         Args:
@@ -331,6 +336,7 @@ class ModelFactory(ABC):
                 the same model that is built above but it needs to have a state dict compatible with
                 the model built above.
             device: The device to load the model on.
+            disable_preload: If True, disable preloading weights to CPU before moving to device.
         """

     def get_example_inputs(self) -> Dict[str, torch.Tensor]:
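
For context, a minimal usage sketch of the new flag, not part of the commit: `factory` stands for any concrete ModelFactory subclass instance, the builder call and device string are assumptions for illustration only.

    # Hypothetical sketch: exercising disable_preload on a concrete factory.
    model = factory.build_model(device="meta")  # assumed factory builder API
    # Stream weights straight to the target device, skipping the CPU staging pass:
    factory.load_or_random_init(model, device="cuda", disable_preload=True)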
@@ -1,5 +1,6 @@
 """Interface to initialize and load HF models."""

+import json
 import os
 import re
 import types
@@ -7,6 +8,7 @@ from abc import abstractmethod
 from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

+import safetensors.torch
 import torch
 import torch.nn as nn
 from accelerate import init_empty_weights, load_checkpoint_in_model
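
The new safetensors import backs the CPU-preload path added further down; as a tiny stand-alone sketch, these are the two loaders that path switches between (the file names here are hypothetical):

    import safetensors.torch
    import torch

    weights = safetensors.torch.load_file("model.safetensors", device="cpu")  # safetensors shard
    legacy = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)  # .bin/.pth shard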
@@ -419,7 +421,9 @@ class AutoModelForCausalLMFactory(AutoModelFactory):

         return fetched_dir

-    def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
+    def _load_checkpoint(
+        self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
+    ):
         """Load the checkpoint into the model."""
         # identify the most relevant checkpoint file
         ckpt_file = self._get_checkpoint_file(self.model)
@@ -434,20 +438,111 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
         # Ensure it's the first one.
         model._state_dict_hooks.move_to_end(key=get_handle.id, last=False)

-        # reuse the load checkpoint utility from accelerate
         try:
-            with hf_load_state_dict_with_device(device):
-                # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
-                # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
-                # which collects local model params, syncs weights from checkpoint, and applies them via
-                # model.load_state_dict.
-                # This sync step can interfere with load_hooks by mixing raw checkpoint weights and
-                # model-transformed weights, leading to unexpected key mismatches or format issues.
-                load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+            if disable_preload:
+                # Load checkpoint directly to GPU using accelerate's load_checkpoint_in_model (no CPU preload)
+                ad_logger.info(
+                    "disable_preload=True: Using accelerate's load_checkpoint_in_model (no CPU preload)"
+                )
+                with hf_load_state_dict_with_device(device):
+                    load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+            else:
+                # Preload checkpoint files to CPU
+                ad_logger.info("Preloading checkpoint files to CPU")
+                self._load_checkpoint_with_preload(model, ckpt_file, device)
         finally:
             load_handle.remove()
             get_handle.remove()
+
+    def _load_checkpoint_with_preload(
+        self, model: nn.Module, ckpt_file: str, device: DeviceLikeType
+    ):
+        all_weights = self._load_full_checkpoint_to_cpu(ckpt_file)
+
+        ad_logger.info(f"Loading weights into model (device: {device})...")
+        model.load_state_dict(all_weights, strict=False)
+
+        ad_logger.info("Checkpoint loading completed")
+
+    def _load_full_checkpoint_to_cpu(self, checkpoint: str) -> dict:
+        """Load the full checkpoint to CPU memory.
+
+        Args:
+            checkpoint: Can be:
+                - a path to a file containing a whole model state dict
+                - a path to a `.json` file containing the index to a sharded checkpoint
+                - a path to a folder containing a unique `.index.json` file and the shards
+                - a path to a folder containing a unique pytorch_model.bin or model.safetensors
+        """
+        checkpoint_files = None
+        index_filename = None
+
+        # Fast path: Direct .index.json file (most common case for sharded checkpoints)
+        if os.path.isfile(checkpoint):
+            if checkpoint.endswith(".index.json"):
+                index_filename = checkpoint
+            else:
+                checkpoint_files = [checkpoint]
+        elif os.path.isdir(checkpoint):
+            # Check if the whole state dict is present (priority order matches accelerate)
+            potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
+            potential_state_safetensor = [
+                f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
+            ]
+
+            # Case 1: pytorch_model.bin (WEIGHTS_NAME)
+            if len(potential_state_bin) == 1:
+                checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
+            # Case 2: model.safetensors (SAFE_WEIGHTS_NAME)
+            elif len(potential_state_safetensor) == 1:
+                checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
+            else:
+                # Case 3: Otherwise check for sharded checkpoints
+                potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
+                if len(potential_index) == 0:
+                    raise ValueError(
+                        f"{checkpoint} is not a folder containing a `.index.json` file or a "
+                        f"{WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
+                    )
+                elif len(potential_index) == 1:
+                    index_filename = os.path.join(checkpoint, potential_index[0])
+                else:
+                    raise ValueError(
+                        f"{checkpoint} contains more than one `.index.json` file; delete the irrelevant ones."
+                    )
+        else:
+            raise ValueError(
+                f"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
+                f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got "
+                f"{checkpoint}."
+            )
+
+        # Load checkpoint files from index if needed
+        if index_filename is not None:
+            checkpoint_folder = os.path.dirname(index_filename)
+            with open(index_filename, "r") as f:
+                index = json.load(f)
+
+            if "weight_map" in index:
+                index = index["weight_map"]
+            checkpoint_files = list(set(index.values()))
+            checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
+
+        # Load all weights
+        all_weights = {}
+        for checkpoint_file in checkpoint_files:
+            ad_logger.info(f"Loading weight file: {checkpoint_file}")
+            if checkpoint_file.endswith(".safetensors"):
+                file_weights = safetensors.torch.load_file(checkpoint_file, device="cpu")
+            elif checkpoint_file.endswith((".bin", ".pth")):
+                file_weights = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
+            else:
+                raise ValueError(f"Unsupported checkpoint format: {checkpoint_file}")
+
+            all_weights.update(file_weights)
+
+        return all_weights

     def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
         if self._quant_config_reader is not None:
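
For reference, the `.index.json` consumed by `_load_full_checkpoint_to_cpu` follows the standard Hugging Face sharded-checkpoint layout; an illustrative sketch of its shape (all names and sizes made up), which also shows why `set(index.values())` above yields the unique shard list:

    # Illustrative index shape (standard HF layout; values are hypothetical).
    index = {
        "metadata": {"total_size": 16_060_522_496},
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors",
        },
    }
    # set(index["weight_map"].values()) -> the two unique shard files to load.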
@@ -25,6 +25,10 @@ class MoveDeviceConfig(TransformConfig):
         default=None,
         description="Optional device to init checkpoint before move to shared_config.local_device.",
     )
+    disable_preload: bool = Field(
+        default=False,
+        description="If True, disable preloading weights.",
+    )


 @TransformRegistry.register("load_weights")
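
A short sketch of how the new field surfaces on the config; MoveDeviceConfig is a pydantic model per the Field declarations above, and the sketch assumes its other fields all have defaults:

    # Hypothetical usage: the flag defaults to False and can be set per transform config.
    cfg = MoveDeviceConfig()
    assert cfg.disable_preload is False
    cfg = MoveDeviceConfig(disable_preload=True)  # opt out of CPU preloading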
@@ -48,7 +52,11 @@ class LoadWeightsToDevice(BaseTransform):
         total_size_GB = bytes_to(params_size, unit="GB")
         self._log_info(f"Estimated parameters memory: {total_size_GB:.2f} GB")

-        factory.load_or_random_init(mod, device=self.config.checkpoint_device or cm.device)
+        factory.load_or_random_init(
+            mod,
+            device=self.config.checkpoint_device or cm.device,
+            disable_preload=self.config.disable_preload,
+        )
         move_to_device(mod, cm.device)

         info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
@@ -1,6 +1,7 @@
 """High-level entrypoint to transform a model into an efficient inference model."""

 import gc
+import time
 from typing import Optional

 import torch
@@ -10,6 +11,7 @@ import torch.nn as nn
 from ..distributed import common as dist_ad
 from ..models.factory import ModelFactory
 from ..shim.interface import CachedSequenceInterface
+from ..utils.logger import ad_logger
 from .interface import (
     InferenceOptimizerConfig,
     SharedConfig,
@@ -64,11 +66,14 @@ class InferenceOptimizer:
         mod = nn.Module()

         # iterate over all transforms sorted by stage in the config
+        start_time = time.time()
         for idx, (t_name, t_config) in enumerate(self.config.items()):
             # instantiate transform
             transform = TransformRegistry.get(t_name)(t_config)
             # run transform
             mod = transform(mod, cm, self.factory, self.shared_config, idx)
+        total_time = time.time() - start_time
+        ad_logger.info(f"Total time for all transforms: {total_time:.2f}s")

         ############################################################################################
         # RETURN OPTIMIZED MODEL
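
The timing added above is the plain wall-clock pattern; a self-contained sketch of the same idea under stand-in names (the transform dict here is hypothetical, standing in for self.config.items()):

    import time

    transforms = {"build_model": None, "load_weights": None}  # stand-in registry
    start_time = time.time()
    for idx, (t_name, t_config) in enumerate(transforms.items()):
        pass  # each transform would run here: mod = transform(mod, cm, factory, shared_config, idx)
    total_time = time.time() - start_time
    print(f"Total time for all transforms: {total_time:.2f}s")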