[None][fix] Reduce host memory usage during model loading (#11119)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
jthomson04 2026-02-05 08:57:40 -08:00 committed by GitHub
parent e52eb82780
commit d778b26062
9 changed files with 228 additions and 20 deletions


@ -1,14 +1,99 @@
import threading
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, Iterator, Tuple, Union
from tensorrt_llm.mapping import Mapping
class ConsumableWeightsDict:
"""
Wrapper around a weights dictionary that allows marking keys as consumed
to free memory during model loading.
This reduces peak memory usage by deleting weight tensors from the dictionary
after they have been copied to the model, rather than keeping all weights
in memory until loading completes.
Thread-safe: uses a lock to protect concurrent access. Iteration methods
(keys, values, items, __iter__) return snapshot copies to allow safe
concurrent iteration while other threads may modify the dictionary.
"""
def __init__(self, weights: Dict[str, Any]):
self._weights = weights
self._lock = threading.Lock()
def __getitem__(self, key: str) -> Any:
return self._weights[key]
def __setitem__(self, key: str, value: Any) -> None:
with self._lock:
self._weights[key] = value
def __delitem__(self, key: str) -> None:
with self._lock:
del self._weights[key]
def __contains__(self, key: str) -> bool:
return key in self._weights
def __len__(self) -> int:
return len(self._weights)
def __iter__(self) -> Iterator[str]:
# Return iterator over a snapshot copy of keys to allow concurrent modification
with self._lock:
return iter(list(self._weights.keys()))
def keys(self):
# Return a snapshot copy of keys to allow concurrent modification
with self._lock:
return list(self._weights.keys())
def values(self):
# Return a snapshot copy of values to allow concurrent modification
with self._lock:
return list(self._weights.values())
def items(self) -> Iterator[Tuple[str, Any]]:
# Return a snapshot copy of items to allow concurrent modification
with self._lock:
return list(self._weights.items())
def get(self, key: str, default: Any = None) -> Any:
return self._weights.get(key, default)
def update(self, other: Dict[str, Any]) -> None:
with self._lock:
self._weights.update(other)
def mark_consumed(self, prefix: str) -> int:
"""
Delete all keys starting with the given prefix to free memory.
Args:
prefix: The prefix to match. Keys starting with "{prefix}." will be deleted.
Returns:
The number of keys deleted.
Thread-safe: uses a lock to prevent concurrent modification issues.
"""
with self._lock:
keys_to_delete = [
k for k in self._weights.keys() if k.startswith(prefix + ".")
]
for key in keys_to_delete:
del self._weights[key]
return len(keys_to_delete)
class BaseWeightLoader(ABC):
@abstractmethod
def load_weights(self, checkpoint_dir: str,
mapping: Mapping) -> dict[str, Any]:
def load_weights(
self, checkpoint_dir: str,
mapping: Mapping) -> Union[Dict[str, Any], ConsumableWeightsDict]:
"""
Loads weights from a checkpoint directory.
@ -17,7 +102,8 @@ class BaseWeightLoader(ABC):
mapping: A mapping object containing the distributed configuration.
Returns:
A dictionary where keys are tensor names and values are the tensors.
A dictionary (or ConsumableWeightsDict) where keys are tensor names
and values are the tensors.
"""
def cleanup(self) -> None:
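A minimal usage sketch of the wrapper added above, assuming the import path introduced by this commit; the key names and tensor sizes are made up. Once a module's tensors have been copied into the model, mark_consumed drops every key under that prefix so the host-side copies can be garbage-collected:

import torch
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

# Illustrative checkpoint shard; real keys come from safetensors files.
weights = ConsumableWeightsDict({
    "model.layers.0.q_proj.weight": torch.zeros(16, 16),
    "model.layers.0.k_proj.weight": torch.zeros(16, 16),
    "model.layers.1.q_proj.weight": torch.zeros(16, 16),
})

# Copy layer 0 into the model (stand-in for module.load_weights).
_ = weights["model.layers.0.q_proj.weight"].clone()

# "{prefix}." matching: this frees both layer-0 tensors but leaves layer 1.
assert weights.mark_consumed("model.layers.0") == 2
assert len(weights) == 1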


@ -74,13 +74,19 @@ class BaseWeightMapper(ABC):
pattern_mapping = {
r'(.*?)out_proj(.*)': r'\1o_proj\2'
}
weights: A dictionary of weights
weights: A dictionary of weights (or ConsumableWeightsDict)
Returns:
A dictionary of weights with renamed keys
A dictionary of weights with renamed keys (preserves ConsumableWeightsDict if input was one)
"""
import re
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
ConsumableWeightsDict
# Check if input is a ConsumableWeightsDict to preserve the type
is_consumable = isinstance(weights, ConsumableWeightsDict)
# Create a new dictionary to store the renamed weights
renamed_weights = {}
@ -103,6 +109,9 @@ class BaseWeightMapper(ABC):
if key not in matched_keys:
renamed_weights[key] = weights[key]
# Preserve ConsumableWeightsDict type if that's what was passed in
if is_consumable:
return ConsumableWeightsDict(renamed_weights)
return renamed_weights
def preprocess_weights(self, weights: dict) -> dict:
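A condensed sketch of the rename-and-preserve-type logic shown in this hunk; the helper name rename and the sample key are illustrative, and the pattern is the one from the docstring example above. The point of the change is that re-wrapping the result keeps mark_consumed available to downstream loaders:

import re
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

def rename(pattern_mapping, weights):
    # Re-key the weights via regex; re-wrap at the end so callers that passed a
    # ConsumableWeightsDict can still call mark_consumed() on the result.
    is_consumable = isinstance(weights, ConsumableWeightsDict)
    renamed, matched = {}, set()
    for pattern, repl in pattern_mapping.items():
        for key in list(weights.keys()):
            new_key, n = re.subn(pattern, repl, key)
            if n:
                matched.add(key)
                renamed[new_key] = weights[key]
    for key in weights.keys():
        if key not in matched:
            renamed[key] = weights[key]
    return ConsumableWeightsDict(renamed) if is_consumable else renamed

out = rename({r'(.*?)out_proj(.*)': r'\1o_proj\2'},
             ConsumableWeightsDict({"decoder.out_proj.weight": 0}))
assert isinstance(out, ConsumableWeightsDict) and "decoder.o_proj.weight" in out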


@ -9,8 +9,8 @@ import safetensors
import torch
import tqdm
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import (
BaseWeightLoader, ConsumableWeightsDict)
from tensorrt_llm._torch.models.modeling_utils import (
register_checkpoint_weight_loader, run_concurrently)
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
@ -70,7 +70,7 @@ class HfWeightLoader(BaseWeightLoader):
raise RuntimeError(f"No weight files found in {checkpoint_dir}.")
def _load_weights_in_parallel(self, weight_files: List[str], load_func,
description: str) -> dict[str, Any]:
description: str) -> ConsumableWeightsDict:
"""
Load weight files in parallel using the specified loading function.
@ -80,7 +80,7 @@ class HfWeightLoader(BaseWeightLoader):
description: Description for the progress bar
Returns:
Dictionary containing all loaded weights
ConsumableWeightsDict containing all loaded weights
"""
weights = {}
pbar = tqdm.tqdm(total=len(weight_files), desc=description)
@ -91,7 +91,7 @@ class HfWeightLoader(BaseWeightLoader):
reduce_func=weights.update,
pbar=pbar)
return weights
return ConsumableWeightsDict(weights)
@staticmethod
def _load_safetensors_file(file):
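Roughly what the loader change above amounts to, using a plain ThreadPoolExecutor as a stand-in for run_concurrently; the file list and load_one callable are placeholders (the real loader reads safetensors shards):

from concurrent.futures import ThreadPoolExecutor
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

def load_all(weight_files, load_one):
    weights = {}
    with ThreadPoolExecutor() as pool:
        for shard in pool.map(load_one, weight_files):
            weights.update(shard)  # mirrors reduce_func=weights.update
    # Wrap once at the end so downstream loaders can free keys as they consume them.
    return ConsumableWeightsDict(weights)

shards = load_all(["a.safetensors", "b.safetensors"],
                  load_one=lambda f: {f + ".dummy": 0})
assert isinstance(shards, ConsumableWeightsDict) and len(shards) == 2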


@ -49,6 +49,11 @@ class MistralWeightMapper(HfWeightMapper):
# Adapted from:
# https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657
def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dict:
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict
# Check if input is a ConsumableWeightsDict to preserve the type
is_consumable = isinstance(weights, ConsumableWeightsDict)
renamed_weights = {}
for key in list(weights.keys()):
@ -68,6 +73,9 @@ class MistralWeightMapper(HfWeightMapper):
renamed_weights[new_key] = weights[key]
# Preserve ConsumableWeightsDict type if that's what was passed in
if is_consumable:
return ConsumableWeightsDict(renamed_weights)
return renamed_weights
def _permute_qk(self, module: nn.Module, new_name: str, weights: dict):


@ -39,6 +39,8 @@ from transformers import PretrainedConfig
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
ConsumableWeightsDict
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.mapping import Mapping
@ -147,7 +149,9 @@ class DeepseekV3WeightLoader:
self.model_config = model.model_config
self.is_draft_model = is_draft_model
def load_weights(self, weights: Dict, skip_modules: List[str] = []):
def load_weights(self,
weights: ConsumableWeightsDict,
skip_modules: List[str] = []):
def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2,
new_scale_2, device):
@ -324,6 +328,9 @@ class DeepseekV3WeightLoader:
params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
all_named_modules = dict(self.model.named_modules())
# Check if weights supports mark_consumed (ConsumableWeightsDict)
can_mark_consumed = hasattr(weights, 'mark_consumed')
for name, module in tqdm(all_named_modules.items(),
desc="Loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
@ -399,6 +406,9 @@ class DeepseekV3WeightLoader:
-1, v_b_proj_scale.shape[-1]).cuda(),
).view(*attn_module.v_b_proj_dequant.shape).to(
attn_module.v_b_proj_dequant.dtype))
# Mark consumed kv_b_proj weights
if can_mark_consumed:
weights.mark_consumed(name)
elif names[-1] == "kv_a_proj_with_mqa":
nvfp4_fused_a = self.model_config.get_quant_config(
).layer_quant_mode.has_nvfp4() and weights[
@ -522,6 +532,13 @@ class DeepseekV3WeightLoader:
# For DeepseekV32: kv_a_proj_with_mqa is oversized
# to include indexer k weights, which is filled in post_load_weights.
module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
# Mark consumed kv_a_proj_with_mqa and q_a_proj weights
if can_mark_consumed:
parent_prefix = '.'.join(names[:-1])
weights.mark_consumed(
f"{parent_prefix}.kv_a_proj_with_mqa")
if not is_lite:
weights.mark_consumed(f"{parent_prefix}.q_a_proj")
elif names[-1] in params_map:
module_weights = []
for new_name in params_map[names[-1]]:
@ -529,6 +546,11 @@ class DeepseekV3WeightLoader:
filter_weights('.'.join(names[:-1] + [new_name]),
weights))
module.load_weights(weights=module_weights)
# Mark consumed source weights (e.g., gate_proj, up_proj)
if can_mark_consumed:
for src_name in params_map[names[-1]]:
weights.mark_consumed('.'.join(names[:-1] +
[src_name]))
elif names[-1] == "experts":
module_weights = filter_weights(name, weights)
module_weights = rename_moe_weight(module_weights, {
@ -537,6 +559,9 @@ class DeepseekV3WeightLoader:
"gate_proj": "w1",
})
module.load_weights(weights=[module_weights])
# Mark consumed experts weights
if can_mark_consumed:
weights.mark_consumed(name)
elif names[-1] == "backend" and isinstance(module, MoE):
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently saved MoE weights don't include 'backend' in their names.
@ -552,6 +577,9 @@ class DeepseekV3WeightLoader:
"gate_proj": "w1",
})
module.load_weights(weights=[module_weights])
# Mark consumed MoE weights using parent name
if can_mark_consumed:
weights.mark_consumed(parent_name)
elif names[-1] == "self_attn":
continue
elif names[-1] == "next_layer_layernorm":
@ -563,6 +591,9 @@ class DeepseekV3WeightLoader:
else:
for n, p in module.named_parameters():
p.data.copy_(module_weights[n][:])
# Mark consumed weights
if can_mark_consumed:
weights.mark_consumed(name)
class DeepseekV3MTPHead(nn.Module):
@ -1805,7 +1836,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
return_context_logits=return_context_logits,
**kwargs)
def load_weights(self, weights: Dict):
def load_weights(self, weights: ConsumableWeightsDict):
weight_loader = DeepseekV3WeightLoader(self)
weight_loader.load_weights(weights)
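The loader above only frees weights when the incoming object actually supports it, so plain dicts keep working unchanged. A stripped-down sketch of that guard, including the fused case where the unfused source prefixes (gate_proj/up_proj) are freed rather than the fused name; filter_weights, copy_into_module, and the module names are hypothetical stand-ins:

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

FUSED = {"gate_up_proj": ["gate_proj", "up_proj"]}

def filter_weights(prefix, weights):
    # Sub-dict of everything under "{prefix}.", with the prefix stripped off.
    return {k[len(prefix) + 1:]: v for k, v in weights.items()
            if k.startswith(prefix + ".")}

def load_module(name, copy_into_module, weights):
    can_mark_consumed = hasattr(weights, "mark_consumed")  # False for a plain dict
    names = name.split(".")
    if names[-1] in FUSED:
        sources = [".".join(names[:-1] + [s]) for s in FUSED[names[-1]]]
        copy_into_module([filter_weights(s, weights) for s in sources])
        if can_mark_consumed:
            # Free the unfused source tensors once the fused copy is in place.
            for s in sources:
                weights.mark_consumed(s)
    else:
        copy_into_module([filter_weights(name, weights)])
        if can_mark_consumed:
            weights.mark_consumed(name)

w = ConsumableWeightsDict({"model.layers.0.mlp.gate_proj.weight": 0,
                           "model.layers.0.mlp.up_proj.weight": 0})
load_module("model.layers.0.mlp.gate_up_proj", lambda parts: None, w)
assert len(w) == 0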


@ -9,6 +9,7 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.models.modeling_utils import QuantConfig
@ -53,7 +54,7 @@ class Glm4WeightLoader:
self.model_config = model.model_config
self.is_draft_model = is_draft_model
def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bool = False):
def rename_moe_weight(weights: Dict, rename_rules: Dict):
result = {}
for key, value in weights.items():
@ -81,6 +82,9 @@ class Glm4WeightLoader:
else self.config.num_attention_heads
)
# Check if weights supports mark_consumed (ConsumableWeightsDict)
can_mark_consumed = hasattr(weights, "mark_consumed")
for name, module in tqdm(all_named_modules.items(), desc="Loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
@ -116,6 +120,10 @@ class Glm4WeightLoader:
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
# Mark consumed source weights (e.g., q_proj, k_proj, v_proj)
if can_mark_consumed:
for src_name in params_map[names[-1]]:
weights.mark_consumed(".".join(names[:-1] + [src_name]))
elif names[-1] == "experts":
module_weights = filter_weights(name, weights)
module_weights = rename_moe_weight(
@ -129,6 +137,9 @@ class Glm4WeightLoader:
module.load_weights(
weights=[module_weights], allow_partial_loading=allow_partial_loading
)
# Mark consumed experts weights
if can_mark_consumed:
weights.mark_consumed(name)
elif names[-1] == "backend" and isinstance(module, MoE):
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
# Currently saved MoE weights don't include 'backend' in their names.
@ -149,6 +160,9 @@ class Glm4WeightLoader:
module.load_weights(
weights=[module_weights], allow_partial_loading=allow_partial_loading
)
# Mark consumed MoE weights using parent name
if can_mark_consumed:
weights.mark_consumed(parent_name)
elif names[-1] == "self_attn":
continue
elif names[-1] == "next_layer_layernorm":
@ -173,6 +187,9 @@ class Glm4WeightLoader:
assert n in module_weights
if n in module_weights:
p.data.copy_(module_weights[n][:])
# Mark consumed weights
if can_mark_consumed:
weights.mark_consumed(name)
class Glm4Attention(QKNormRoPEAttention):
@ -1030,7 +1047,7 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
**kwargs,
)
def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bool = False):
weight_loader = Glm4WeightLoader(self)
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)


@ -6,6 +6,8 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm._torch.distributed import AllReduceParams
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
ConsumableWeightsDict
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
@ -623,7 +625,7 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
vocab_size=model_config.pretrained_config.vocab_size)
self._execution_stats = None
def load_weights(self, weights: Dict):
def load_weights(self, weights: ConsumableWeightsDict):
tp_size = self.model_config.mapping.tp_size
head_dim = getattr(
self.config, "head_dim",
@ -641,6 +643,10 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'gate_up_proj': ['gate_proj', 'up_proj']
}
# Check if weights supports mark_consumed (ConsumableWeightsDict)
can_mark_consumed = hasattr(weights, 'mark_consumed')
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
@ -667,7 +673,13 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
# Mark consumed source weights (e.g., q_proj, k_proj, v_proj)
if can_mark_consumed:
for src_name in params_map[names[-1]]:
weights.mark_consumed('.'.join(names[:-1] +
[src_name]))
else:
original_name = name
name = name.replace('gate', 'gate.wg')
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
@ -678,6 +690,9 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
# Mark consumed weights
if can_mark_consumed:
weights.mark_consumed(original_name)
def forward(
self,
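Why the HunYuan changes capture original_name before the rename: the checkpoint lookup is done against the renamed 'gate.wg' path, but mark_consumed matches by prefix, so freeing the pre-rename 'gate' name also drops the 'gate.wg.*' keys that were actually read. A tiny illustration with a made-up key:

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

w = ConsumableWeightsDict({"model.layers.0.mlp.gate.wg.weight": 0})
# original_name, captured before the gate -> gate.wg rename
w.mark_consumed("model.layers.0.mlp.gate")
assert len(w) == 0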


@ -6,6 +6,8 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm._torch.distributed import AllReduceParams
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
ConsumableWeightsDict
from tensorrt_llm.functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
@ -341,7 +343,7 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
self._execution_stats = None
print("---debug model_config: ", model_config)
def load_weights(self, weights: Dict):
def load_weights(self, weights: ConsumableWeightsDict):
tp_size = self.model_config.mapping.tp_size
head_dim = self.config.hidden_size // self.config.num_attention_heads
@ -357,6 +359,10 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
'gate_up_proj': ['gate_proj', 'up_proj']
}
# Check if weights supports mark_consumed (ConsumableWeightsDict)
can_mark_consumed = hasattr(weights, 'mark_consumed')
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
@ -394,7 +400,13 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
}
module_weights.append(fw)
module.load_weights(weights=module_weights)
# Mark consumed source weights (e.g., q_proj, k_proj, v_proj)
if can_mark_consumed:
for src_name in params_map[names[-1]]:
weights.mark_consumed('.'.join(names[:-1] +
[src_name]))
else:
original_name = name
name = name.replace('gate', 'gate.wg')
module_weights = filter_weights(name, weights)
if isinstance(module, CutlassFusedMoE) or isinstance(
@ -418,6 +430,9 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
# Mark consumed weights
if can_mark_consumed:
weights.mark_consumed(original_name)
def forward(
self,


@ -748,12 +748,18 @@ def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict):
pattern_mapping = {
r'(.*?)out_proj(.*)': r'\1o_proj\2'
}
weights: A dictionary of weights
weights: A dictionary of weights (or ConsumableWeightsDict)
Returns:
A dictionary of weights with renamed keys
A dictionary of weights with renamed keys (preserves ConsumableWeightsDict if input was one)
"""
import re
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
ConsumableWeightsDict
# Check if input is a ConsumableWeightsDict to preserve the type
is_consumable = isinstance(weights, ConsumableWeightsDict)
# Create a new dictionary to store the renamed weights
renamed_weights = {}
@ -776,6 +782,9 @@ def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict):
if key not in matched_keys:
renamed_weights[key] = weights[key]
# Preserve ConsumableWeightsDict type if that's what was passed in
if is_consumable:
return ConsumableWeightsDict(renamed_weights)
return renamed_weights
@ -913,6 +922,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
module_weights.append(fw)
module.load_weights(weights=module_weights,
allow_partial_loading=allow_partial_loading)
# Mark consumed source weights (e.g., q_proj, k_proj, v_proj for qkv_proj)
if hasattr(weights, 'mark_consumed'):
for src_name in params_map[names[-1]]:
weights.mark_consumed('.'.join(names[:-1] + [src_name]))
else:
module_weights = filter_weights(name, weights)
@ -934,6 +947,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
if n in module_weights:
p.data.copy_(module_weights[n][:])
# Mark consumed weights
if hasattr(weights, 'mark_consumed'):
weights.mark_consumed(name)
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
"False") in ["True", "true", "1", "yes", "y"]:
for name, module in tqdm(list(
@ -969,7 +986,7 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
weights: Dict,
weights,
weight_mapper: "BaseWeightMapper",
skip_modules: List[str] = [],
params_map: Optional[Dict[str, str]] = None,
@ -1008,6 +1025,12 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
module, module_name, module_names_breakdown, weights)
module.load_weights(weights=module_weights,
allow_partial_loading=allow_partial_loading)
# Mark consumed source weights (e.g., q_proj, k_proj, v_proj for qkv_proj)
if hasattr(weights, 'mark_consumed'):
for src_name in weight_mapper._mapping.get(module_name, []):
prefix = '.'.join(module_names_breakdown + [src_name])
weights.mark_consumed(prefix)
else:
module_weights = weight_mapper.filter_weights(name, weights)
# Note: module_weights may be empty after filtering (e.g., in streaming weight updates)
@ -1039,6 +1062,10 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
p,
allow_partial_loading=allow_partial_loading)
# Mark consumed weights
if hasattr(weights, 'mark_consumed'):
weights.mark_consumed(name)
if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
"False") in ["True", "true", "1", "yes", "y"]:
for name, module in tqdm(list(
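Finally, the snapshot semantics called out in the ConsumableWeightsDict docstring are what make this safe to interleave with consumption: keys()/items() hand back copies, so a loading loop (or another thread) can delete entries mid-iteration without a "dictionary changed size during iteration" error. A small sketch, with illustrative key names:

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

w = ConsumableWeightsDict({f"model.layers.{i}.weight": i for i in range(4)})
for key in w.keys():                         # snapshot list, not a live view
    _ = w.get(key)                           # stand-in for copying into the model
    w.mark_consumed(key.rsplit(".", 1)[0])   # free "model.layers.{i}" as we go
assert len(w) == 0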