[#11086][feat] Optimize Auto Deploy weight loading by preloading weights to CPU (#11059)

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Taylor Yeonbok Lee 2026-02-03 13:23:10 -08:00 committed by GitHub
parent f9c4bdf6cf
commit f9e6045f39
4 changed files with 128 additions and 14 deletions


@@ -263,7 +263,9 @@ class ModelFactory(ABC):
"""
return model_name_or_path
def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
def load_or_random_init(
self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
):
"""Load the checkpoint into the model or randomly initialize the model.
Args:
@@ -271,6 +273,7 @@ class ModelFactory(ABC):
the same model that is built above but it needs to have a state dict compatible with
the model built above.
device: The device to load the model on.
disable_preload: If True, disable preloading weights to CPU before moving to device.
load_factory_model: If True, will load weights for the factory model in addition to the main
gm. This is useful for the transformers model.
@@ -303,7 +306,7 @@ class ModelFactory(ABC):
if not self.skip_loading_weights:
self.prefetch_checkpoint(force=True)
self._load_checkpoint(model, device)
self._load_checkpoint(model, device, disable_preload=disable_preload)
@staticmethod
def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
@@ -323,7 +326,9 @@ class ModelFactory(ABC):
)
@abstractmethod
def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
def _load_checkpoint(
self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
):
"""Load the checkpoint into the model.
Args:
@@ -331,6 +336,7 @@ class ModelFactory(ABC):
the same model that is built above but it needs to have a state dict compatible with
the model built above.
device: The device to load the model on.
disable_preload: If True, disable preloading weights to CPU before moving to device.
"""
def get_example_inputs(self) -> Dict[str, torch.Tensor]:
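For orientation, a minimal usage sketch of the new parameter (illustrative only; `factory` stands in for a concrete ModelFactory subclass instance and `model` for the built module):

# Default path: checkpoint files are preloaded into CPU memory and applied in a
# single model.load_state_dict call before the module is moved to the device.
factory.load_or_random_init(model, device="cuda")

# Opt-out path: skip the CPU preload and stream weights via accelerate's
# load_checkpoint_in_model, as before this change.
factory.load_or_random_init(model, device="cuda", disable_preload=True)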


@@ -1,5 +1,6 @@
"""Interface to initialize and load HF models."""
import json
import os
import re
import types
@@ -7,6 +8,7 @@ from abc import abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import safetensors.torch
import torch
import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_in_model
@@ -419,7 +421,9 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
return fetched_dir
def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
def _load_checkpoint(
self, model: nn.Module, device: DeviceLikeType, disable_preload: bool = False
):
"""Load the checkpoint into the model."""
# identify the most relevant checkpoint file
ckpt_file = self._get_checkpoint_file(self.model)
@@ -434,20 +438,111 @@ class AutoModelForCausalLMFactory(AutoModelFactory):
# Ensure it's the first one.
model._state_dict_hooks.move_to_end(key=get_handle.id, last=False)
# reuse the load checkpoint utility from accelerate
try:
with hf_load_state_dict_with_device(device):
# Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
# Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
# which collects local model params, syncs weights from checkpoint, and applies them via
# model.load_state_dict.
# This sync step can interfere with load_hooks by mixing raw checkpoint weights and
# model-transformed weights, leading to unexpected key mismatches or format issues.
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
if disable_preload:
# Load the checkpoint directly to the target device using accelerate's load_checkpoint_in_model (no CPU preload)
ad_logger.info(
"disable_preload=True: Using accelerate's load_checkpoint_in_model (no CPU preload)"
)
with hf_load_state_dict_with_device(device):
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
else:
# Preload checkpoint files to CPU
ad_logger.info("Preloading checkpoint files to CPU")
self._load_checkpoint_with_preload(model, ckpt_file, device)
finally:
load_handle.remove()
get_handle.remove()
def _load_checkpoint_with_preload(
self, model: nn.Module, ckpt_file: str, device: DeviceLikeType
):
all_weights = self._load_full_checkpoint_to_cpu(ckpt_file)
ad_logger.info(f"Loading weights into model (device: {device})...")
model.load_state_dict(all_weights, strict=False)
ad_logger.info("Checkpoint loading completed")
def _load_full_checkpoint_to_cpu(self, checkpoint: str) -> dict:
"""Load the full checkpoint to CPU memory.
Args:
checkpoint: Can be:
- a path to a file containing a whole model state dict
- a path to a `.json` file containing the index to a sharded checkpoint
- a path to a folder containing a unique `.index.json` file and the shards
- a path to a folder containing a unique pytorch_model.bin or model.safetensors
"""
checkpoint_files = None
index_filename = None
# Fast path: Direct .index.json file (most common case for sharded checkpoints)
if os.path.isfile(checkpoint):
if checkpoint.endswith(".index.json"):
index_filename = checkpoint
else:
checkpoint_files = [checkpoint]
elif os.path.isdir(checkpoint):
# Check if the whole state dict is present (priority order matches accelerate)
potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
potential_state_safetensor = [
f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
]
# Case 1: pytorch_model.bin (WEIGHTS_NAME)
if len(potential_state_bin) == 1:
checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
# Case 2: model.safetensors (SAFE_WEIGHTS_NAME)
elif len(potential_state_safetensor) == 1:
checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
else:
# Case 3: Otherwise check for sharded checkpoints
potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
if len(potential_index) == 0:
raise ValueError(
f"{checkpoint} is not a folder containing a `.index.json` file or a "
f"{WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
)
elif len(potential_index) == 1:
index_filename = os.path.join(checkpoint, potential_index[0])
else:
raise ValueError(
f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
)
else:
raise ValueError(
f"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got "
f"{checkpoint}."
)
# Load checkpoint files from index if needed
if index_filename is not None:
checkpoint_folder = os.path.dirname(index_filename)
with open(index_filename, "r") as f:
index = json.load(f)
if "weight_map" in index:
index = index["weight_map"]
checkpoint_files = list(set(index.values()))
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
# Load all weights
all_weights = {}
for checkpoint_file in checkpoint_files:
ad_logger.info(f"Loading weight file: {checkpoint_file}")
if checkpoint_file.endswith(".safetensors"):
file_weights = safetensors.torch.load_file(checkpoint_file, device="cpu")
elif checkpoint_file.endswith((".bin", ".pth")):
file_weights = torch.load(checkpoint_file, map_location="cpu", weights_only=True)
else:
raise ValueError(f"Unsupported checkpoint format: {checkpoint_file}")
all_weights.update(file_weights)
return all_weights
def _load_quantization_config(self, fetched_dir: str):
"""Load the quantization config from the model directory if not done already."""
if self._quant_config_reader is not None:
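The docstring above enumerates the checkpoint layouts the preload path accepts; the sharded-index case can be exercised end to end with a small self-contained sketch (illustrative only, written against safetensors' public API rather than the factory class):

# Illustrative check of the sharded-index layout; mimics the preload logic for clarity.
import json, os, tempfile
import safetensors.torch
import torch

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Write two shards plus a weight_map index, mirroring HF's sharded layout.
    shards = {
        "model-00001-of-00002.safetensors": {"linear.weight": torch.randn(4, 4)},
        "model-00002-of-00002.safetensors": {"linear.bias": torch.zeros(4)},
    }
    weight_map = {}
    for fname, tensors in shards.items():
        safetensors.torch.save_file(tensors, os.path.join(ckpt_dir, fname))
        weight_map.update({key: fname for key in tensors})
    index_path = os.path.join(ckpt_dir, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump({"weight_map": weight_map}, f)

    # Preload: resolve the index, then load every shard into one CPU-resident dict.
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    all_weights = {}
    for fname in set(weight_map.values()):
        all_weights.update(
            safetensors.torch.load_file(os.path.join(ckpt_dir, fname), device="cpu")
        )
    assert set(all_weights) == {"linear.weight", "linear.bias"}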


@@ -25,6 +25,10 @@ class MoveDeviceConfig(TransformConfig):
default=None,
description="Optional device to init checkpoint before move to shared_config.local_device.",
)
disable_preload: bool = Field(
default=False,
description="If True, disable preloading weights.",
)
@TransformRegistry.register("load_weights")
@@ -48,7 +52,11 @@ class LoadWeightsToDevice(BaseTransform):
total_size_GB = bytes_to(params_size, unit="GB")
self._log_info(f"Estimated parameters memory: {total_size_GB:.2f} GB")
factory.load_or_random_init(mod, device=self.config.checkpoint_device or cm.device)
factory.load_or_random_init(
mod,
device=self.config.checkpoint_device or cm.device,
disable_preload=self.config.disable_preload,
)
move_to_device(mod, cm.device)
info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True)
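The new field plugs into the existing `load_weights` transform config. A hedged sketch of how it might be set (only the field names `checkpoint_device` and `disable_preload` come from this diff; the surrounding config structure is an assumption, not shown in the commit):

# Hypothetical InferenceOptimizer config snippet; field names taken from this diff,
# the dict layout around them is assumed.
optimizer_config = {
    "load_weights": {
        "checkpoint_device": None,   # default: load on shared_config.local_device
        "disable_preload": False,    # set True to keep the old accelerate streaming path
    },
}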


@@ -1,6 +1,7 @@
"""High-level entrypoint to transform a model into an efficient inference model."""
import gc
import time
from typing import Optional
import torch
@@ -10,6 +11,7 @@ import torch.nn as nn
from ..distributed import common as dist_ad
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..utils.logger import ad_logger
from .interface import (
InferenceOptimizerConfig,
SharedConfig,
@@ -64,11 +66,14 @@ class InferenceOptimizer:
mod = nn.Module()
# iterate over all transforms sorted by stage in the config
start_time = time.time()
for idx, (t_name, t_config) in enumerate(self.config.items()):
# instantiate transform
transform = TransformRegistry.get(t_name)(t_config)
# run transform
mod = transform(mod, cm, self.factory, self.shared_config, idx)
total_time = time.time() - start_time
ad_logger.info(f"Total time for all transforms: {total_time:.2f}s")
############################################################################################
# RETURN OPTIMIZED MODEL