diff --git a/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py index 2c829e027a..2cd78e3f9d 100644 --- a/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/base_weight_loader.py @@ -1,14 +1,99 @@ +import threading from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict, Iterator, Tuple, Union from tensorrt_llm.mapping import Mapping +class ConsumableWeightsDict: + """ + Wrapper around a weights dictionary that allows marking keys as consumed + to free memory during model loading. + + This reduces peak memory usage by deleting weight tensors from the dictionary + after they have been copied to the model, rather than keeping all weights + in memory until loading completes. + + Thread-safe: uses a lock to protect concurrent access. Iteration methods + (keys, values, items, __iter__) return snapshot copies to allow safe + concurrent iteration while other threads may modify the dictionary. + """ + + def __init__(self, weights: Dict[str, Any]): + self._weights = weights + self._lock = threading.Lock() + + def __getitem__(self, key: str) -> Any: + return self._weights[key] + + def __setitem__(self, key: str, value: Any) -> None: + with self._lock: + self._weights[key] = value + + def __delitem__(self, key: str) -> None: + with self._lock: + del self._weights[key] + + def __contains__(self, key: str) -> bool: + return key in self._weights + + def __len__(self) -> int: + return len(self._weights) + + def __iter__(self) -> Iterator[str]: + # Return iterator over a snapshot copy of keys to allow concurrent modification + with self._lock: + return iter(list(self._weights.keys())) + + def keys(self): + # Return a snapshot copy of keys to allow concurrent modification + with self._lock: + return list(self._weights.keys()) + + def values(self): + # Return a snapshot copy of values to allow concurrent modification + with self._lock: + return list(self._weights.values()) + + def items(self) -> Iterator[Tuple[str, Any]]: + # Return a snapshot copy of items to allow concurrent modification + with self._lock: + return list(self._weights.items()) + + def get(self, key: str, default: Any = None) -> Any: + return self._weights.get(key, default) + + def update(self, other: Dict[str, Any]) -> None: + with self._lock: + self._weights.update(other) + + def mark_consumed(self, prefix: str) -> int: + """ + Delete all keys starting with the given prefix to free memory. + + Args: + prefix: The prefix to match. Keys starting with "{prefix}." will be deleted. + + Returns: + The number of keys deleted. + + Thread-safe: uses a lock to prevent concurrent modification issues. + """ + with self._lock: + keys_to_delete = [ + k for k in self._weights.keys() if k.startswith(prefix + ".") + ] + for key in keys_to_delete: + del self._weights[key] + return len(keys_to_delete) + + class BaseWeightLoader(ABC): @abstractmethod - def load_weights(self, checkpoint_dir: str, - mapping: Mapping) -> dict[str, Any]: + def load_weights( + self, checkpoint_dir: str, + mapping: Mapping) -> Union[Dict[str, Any], ConsumableWeightsDict]: """ Loads weights from a checkpoint directory. @@ -17,7 +102,8 @@ class BaseWeightLoader(ABC): mapping: A mapping object containing the distributed configuration. Returns: - A dictionary where keys are tensor names and values are the tensors. + A dictionary (or ConsumableWeightsDict) where keys are tensor names + and values are the tensors. 
""" def cleanup(self) -> None: diff --git a/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py index 790be65eed..3ac76d6809 100644 --- a/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py @@ -74,13 +74,19 @@ class BaseWeightMapper(ABC): pattern_mapping = { r'(.*?)out_proj(.*)': r'\1o_proj\2' } - weights: A dictionary of weights + weights: A dictionary of weights (or ConsumableWeightsDict) Returns: - A dictionary of weights with renamed keys + A dictionary of weights with renamed keys (preserves ConsumableWeightsDict if input was one) """ import re + from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ + ConsumableWeightsDict + + # Check if input is a ConsumableWeightsDict to preserve the type + is_consumable = isinstance(weights, ConsumableWeightsDict) + # Create a new dictionary to store the renamed weights renamed_weights = {} @@ -103,6 +109,9 @@ class BaseWeightMapper(ABC): if key not in matched_keys: renamed_weights[key] = weights[key] + # Preserve ConsumableWeightsDict type if that's what was passed in + if is_consumable: + return ConsumableWeightsDict(renamed_weights) return renamed_weights def preprocess_weights(self, weights: dict) -> dict: diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index 3b1c3af172..f47e77a816 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -9,8 +9,8 @@ import safetensors import torch import tqdm -from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ - BaseWeightLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ( + BaseWeightLoader, ConsumableWeightsDict) from tensorrt_llm._torch.models.modeling_utils import ( register_checkpoint_weight_loader, run_concurrently) from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank, @@ -70,7 +70,7 @@ class HfWeightLoader(BaseWeightLoader): raise RuntimeError(f"No weight files found in {checkpoint_dir}.") def _load_weights_in_parallel(self, weight_files: List[str], load_func, - description: str) -> dict[str, Any]: + description: str) -> ConsumableWeightsDict: """ Load weight files in parallel using the specified loading function. 
@@ -80,7 +80,7 @@ class HfWeightLoader(BaseWeightLoader): description: Description for the progress bar Returns: - Dictionary containing all loaded weights + ConsumableWeightsDict containing all loaded weights """ weights = {} pbar = tqdm.tqdm(total=len(weight_files), desc=description) @@ -91,7 +91,7 @@ class HfWeightLoader(BaseWeightLoader): reduce_func=weights.update, pbar=pbar) - return weights + return ConsumableWeightsDict(weights) @staticmethod def _load_safetensors_file(file): diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py index 583ebf1bf1..035b539639 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py @@ -49,6 +49,11 @@ class MistralWeightMapper(HfWeightMapper): # Adapted from: # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657 def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dict: + from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict + + # Check if input is a ConsumableWeightsDict to preserve the type + is_consumable = isinstance(weights, ConsumableWeightsDict) + renamed_weights = {} for key in list(weights.keys()): @@ -68,6 +73,9 @@ class MistralWeightMapper(HfWeightMapper): renamed_weights[new_key] = weights[key] + # Preserve ConsumableWeightsDict type if that's what was passed in + if is_consumable: + return ConsumableWeightsDict(renamed_weights) return renamed_weights def _permute_qk(self, module: nn.Module, new_name: str, weights: dict): diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 430208fb57..1171bb23f6 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -39,6 +39,8 @@ from transformers import PretrainedConfig import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils from tensorrt_llm._ipc_utils import can_access_peer +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ + ConsumableWeightsDict from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.mapping import Mapping @@ -147,7 +149,9 @@ class DeepseekV3WeightLoader: self.model_config = model.model_config self.is_draft_model = is_draft_model - def load_weights(self, weights: Dict, skip_modules: List[str] = []): + def load_weights(self, + weights: ConsumableWeightsDict, + skip_modules: List[str] = []): def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2, new_scale_2, device): @@ -324,6 +328,9 @@ class DeepseekV3WeightLoader: params_map = {'gate_up_proj': ['gate_proj', 'up_proj']} all_named_modules = dict(self.model.named_modules()) + # Check if weights supports mark_consumed (ConsumableWeightsDict) + can_mark_consumed = hasattr(weights, 'mark_consumed') + for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): if len(module._parameters) <= 0 or name.startswith("draft_model"): @@ -399,6 +406,9 @@ class DeepseekV3WeightLoader: -1, v_b_proj_scale.shape[-1]).cuda(), ).view(*attn_module.v_b_proj_dequant.shape).to( attn_module.v_b_proj_dequant.dtype)) + # Mark consumed kv_b_proj weights + if can_mark_consumed: + weights.mark_consumed(name) elif names[-1] == "kv_a_proj_with_mqa": nvfp4_fused_a = self.model_config.get_quant_config( 
).layer_quant_mode.has_nvfp4() and weights[ @@ -522,6 +532,13 @@ class DeepseekV3WeightLoader: # For DeepseekV32: kv_a_proj_with_mqa is oversized # to include indexer k weights, which is filled in post_load_weights. module.weight.data[0:fused_a.shape[0]].copy_(fused_a) + # Mark consumed kv_a_proj_with_mqa and q_a_proj weights + if can_mark_consumed: + parent_prefix = '.'.join(names[:-1]) + weights.mark_consumed( + f"{parent_prefix}.kv_a_proj_with_mqa") + if not is_lite: + weights.mark_consumed(f"{parent_prefix}.q_a_proj") elif names[-1] in params_map: module_weights = [] for new_name in params_map[names[-1]]: @@ -529,6 +546,11 @@ class DeepseekV3WeightLoader: filter_weights('.'.join(names[:-1] + [new_name]), weights)) module.load_weights(weights=module_weights) + # Mark consumed source weights (e.g., gate_proj, up_proj) + if can_mark_consumed: + for src_name in params_map[names[-1]]: + weights.mark_consumed('.'.join(names[:-1] + + [src_name])) elif names[-1] == "experts": module_weights = filter_weights(name, weights) module_weights = rename_moe_weight(module_weights, { @@ -537,6 +559,9 @@ class DeepseekV3WeightLoader: "gate_proj": "w1", }) module.load_weights(weights=[module_weights]) + # Mark consumed experts weights + if can_mark_consumed: + weights.mark_consumed(name) elif names[-1] == "backend" and isinstance(module, MoE): # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) # Currently saved MoE weights don't include 'backend' in their names. @@ -552,6 +577,9 @@ class DeepseekV3WeightLoader: "gate_proj": "w1", }) module.load_weights(weights=[module_weights]) + # Mark consumed MoE weights using parent name + if can_mark_consumed: + weights.mark_consumed(parent_name) elif names[-1] == "self_attn": continue elif names[-1] == "next_layer_layernorm": @@ -563,6 +591,9 @@ class DeepseekV3WeightLoader: else: for n, p in module.named_parameters(): p.data.copy_(module_weights[n][:]) + # Mark consumed weights + if can_mark_consumed: + weights.mark_consumed(name) class DeepseekV3MTPHead(nn.Module): @@ -1805,7 +1836,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, return_context_logits=return_context_logits, **kwargs) - def load_weights(self, weights: Dict): + def load_weights(self, weights: ConsumableWeightsDict): weight_loader = DeepseekV3WeightLoader(self) weight_loader.load_weights(weights) diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py index 53d45d3edc..3698c273af 100644 --- a/tensorrt_llm/_torch/models/modeling_glm.py +++ b/tensorrt_llm/_torch/models/modeling_glm.py @@ -9,6 +9,7 @@ from tqdm import tqdm from transformers import PretrainedConfig from tensorrt_llm._ipc_utils import can_access_peer +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.models.modeling_utils import QuantConfig @@ -53,7 +54,7 @@ class Glm4WeightLoader: self.model_config = model.model_config self.is_draft_model = is_draft_model - def load_weights(self, weights: Dict, allow_partial_loading: bool = False): + def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bool = False): def rename_moe_weight(weights: Dict, rename_rules: Dict): result = {} for key, value in weights.items(): @@ -81,6 +82,9 @@ class Glm4WeightLoader: else self.config.num_attention_heads ) + # Check if weights supports mark_consumed (ConsumableWeightsDict) + can_mark_consumed 
= hasattr(weights, "mark_consumed") + for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): if len(module._parameters) <= 0 or name.startswith("draft_model"): continue @@ -116,6 +120,10 @@ class Glm4WeightLoader: } module_weights.append(fw) module.load_weights(weights=module_weights) + # Mark consumed source weights (e.g., q_proj, k_proj, v_proj) + if can_mark_consumed: + for src_name in params_map[names[-1]]: + weights.mark_consumed(".".join(names[:-1] + [src_name])) elif names[-1] == "experts": module_weights = filter_weights(name, weights) module_weights = rename_moe_weight( @@ -129,6 +137,9 @@ class Glm4WeightLoader: module.load_weights( weights=[module_weights], allow_partial_loading=allow_partial_loading ) + # Mark consumed experts weights + if can_mark_consumed: + weights.mark_consumed(name) elif names[-1] == "backend" and isinstance(module, MoE): # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE) # Currently saved MoE weights don't include 'backend' in their names. @@ -149,6 +160,9 @@ class Glm4WeightLoader: module.load_weights( weights=[module_weights], allow_partial_loading=allow_partial_loading ) + # Mark consumed MoE weights using parent name + if can_mark_consumed: + weights.mark_consumed(parent_name) elif names[-1] == "self_attn": continue elif names[-1] == "next_layer_layernorm": @@ -173,6 +187,9 @@ class Glm4WeightLoader: assert n in module_weights if n in module_weights: p.data.copy_(module_weights[n][:]) + # Mark consumed weights + if can_mark_consumed: + weights.mark_consumed(name) class Glm4Attention(QKNormRoPEAttention): @@ -1030,7 +1047,7 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig **kwargs, ) - def load_weights(self, weights: Dict, allow_partial_loading: bool = False): + def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bool = False): weight_loader = Glm4WeightLoader(self) weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading) diff --git a/tensorrt_llm/_torch/models/modeling_hunyuan_dense.py b/tensorrt_llm/_torch/models/modeling_hunyuan_dense.py index 2863a3a15e..30d95141f2 100644 --- a/tensorrt_llm/_torch/models/modeling_hunyuan_dense.py +++ b/tensorrt_llm/_torch/models/modeling_hunyuan_dense.py @@ -6,6 +6,8 @@ from tqdm import tqdm from transformers import PretrainedConfig from tensorrt_llm._torch.distributed import AllReduceParams +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ + ConsumableWeightsDict from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata @@ -623,7 +625,7 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, vocab_size=model_config.pretrained_config.vocab_size) self._execution_stats = None - def load_weights(self, weights: Dict): + def load_weights(self, weights: ConsumableWeightsDict): tp_size = self.model_config.mapping.tp_size head_dim = getattr( self.config, "head_dim", @@ -641,6 +643,10 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'], 'gate_up_proj': ['gate_proj', 'up_proj'] } + + # Check if weights supports mark_consumed (ConsumableWeightsDict) + can_mark_consumed = hasattr(weights, 'mark_consumed') + for name, module in tqdm(list(self.named_modules()), desc="Loading weights"): if len(module._parameters) > 0: @@ -667,7 +673,13 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, } module_weights.append(fw) 
module.load_weights(weights=module_weights) + # Mark consumed source weights (e.g., q_proj, k_proj, v_proj) + if can_mark_consumed: + for src_name in params_map[names[-1]]: + weights.mark_consumed('.'.join(names[:-1] + + [src_name])) else: + original_name = name name = name.replace('gate', 'gate.wg') module_weights = filter_weights(name, weights) if hasattr(module, 'load_weights'): @@ -678,6 +690,9 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, for n, p in module._parameters.items(): if p is not None: p.data.copy_(module_weights[n][:]) + # Mark consumed weights + if can_mark_consumed: + weights.mark_consumed(original_name) def forward( self, diff --git a/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py index 89e43d869b..541bfa30e7 100644 --- a/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py +++ b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py @@ -6,6 +6,8 @@ from tqdm import tqdm from transformers import PretrainedConfig from tensorrt_llm._torch.distributed import AllReduceParams +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ + ConsumableWeightsDict from tensorrt_llm.functional import PositionEmbeddingType from ..attention_backend import AttentionMetadata @@ -341,7 +343,7 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, self._execution_stats = None print("---debug model_config: ", model_config) - def load_weights(self, weights: Dict): + def load_weights(self, weights: ConsumableWeightsDict): tp_size = self.model_config.mapping.tp_size head_dim = self.config.hidden_size // self.config.num_attention_heads @@ -357,6 +359,10 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, 'qkv_proj': ['q_proj', 'k_proj', 'v_proj'], 'gate_up_proj': ['gate_proj', 'up_proj'] } + + # Check if weights supports mark_consumed (ConsumableWeightsDict) + can_mark_consumed = hasattr(weights, 'mark_consumed') + for name, module in tqdm(list(self.named_modules()), desc="Loading weights"): if len(module._parameters) > 0: @@ -394,7 +400,13 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, } module_weights.append(fw) module.load_weights(weights=module_weights) + # Mark consumed source weights (e.g., q_proj, k_proj, v_proj) + if can_mark_consumed: + for src_name in params_map[names[-1]]: + weights.mark_consumed('.'.join(names[:-1] + + [src_name])) else: + original_name = name name = name.replace('gate', 'gate.wg') module_weights = filter_weights(name, weights) if isinstance(module, CutlassFusedMoE) or isinstance( @@ -418,6 +430,9 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel, for n, p in module._parameters.items(): if p is not None: p.data.copy_(module_weights[n][:]) + # Mark consumed weights + if can_mark_consumed: + weights.mark_consumed(original_name) def forward( self, diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index f577922fcb..f8c9f3af40 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -748,12 +748,18 @@ def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict): pattern_mapping = { r'(.*?)out_proj(.*)': r'\1o_proj\2' } - weights: A dictionary of weights + weights: A dictionary of weights (or ConsumableWeightsDict) Returns: - A dictionary of weights with renamed keys + A dictionary of weights with renamed keys (preserves ConsumableWeightsDict if input was one) """ import re + 
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \ + ConsumableWeightsDict + + # Check if input is a ConsumableWeightsDict to preserve the type + is_consumable = isinstance(weights, ConsumableWeightsDict) + # Create a new dictionary to store the renamed weights renamed_weights = {} @@ -776,6 +782,9 @@ def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict): if key not in matched_keys: renamed_weights[key] = weights[key] + # Preserve ConsumableWeightsDict type if that's what was passed in + if is_consumable: + return ConsumableWeightsDict(renamed_weights) return renamed_weights @@ -913,6 +922,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM], module_weights.append(fw) module.load_weights(weights=module_weights, allow_partial_loading=allow_partial_loading) + # Mark consumed source weights (e.g., q_proj, k_proj, v_proj for qkv_proj) + if hasattr(weights, 'mark_consumed'): + for src_name in params_map[names[-1]]: + weights.mark_consumed('.'.join(names[:-1] + [src_name])) else: module_weights = filter_weights(name, weights) @@ -934,6 +947,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM], if n in module_weights: p.data.copy_(module_weights[n][:]) + # Mark consumed weights + if hasattr(weights, 'mark_consumed'): + weights.mark_consumed(name) + if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL", "False") in ["True", "true", "1", "yes", "y"]: for name, module in tqdm(list( @@ -969,7 +986,7 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM], def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM], - weights: Dict, + weights, weight_mapper: "BaseWeightMapper", skip_modules: List[str] = [], params_map: Optional[Dict[str, str]] = None, @@ -1008,6 +1025,12 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM], module, module_name, module_names_breakdown, weights) module.load_weights(weights=module_weights, allow_partial_loading=allow_partial_loading) + + # Mark consumed source weights (e.g., q_proj, k_proj, v_proj for qkv_proj) + if hasattr(weights, 'mark_consumed'): + for src_name in weight_mapper._mapping.get(module_name, []): + prefix = '.'.join(module_names_breakdown + [src_name]) + weights.mark_consumed(prefix) else: module_weights = weight_mapper.filter_weights(name, weights) # Note: module_weights may be empty after filtering (e.g., in streaming weight updates) @@ -1039,6 +1062,10 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM], p, allow_partial_loading=allow_partial_loading) + # Mark consumed weights + if hasattr(weights, 'mark_consumed'): + weights.mark_consumed(name) + if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL", "False") in ["True", "true", "1", "yes", "y"]: for name, module in tqdm(list(
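A minimal usage sketch of the new `ConsumableWeightsDict` added in `base_weight_loader.py` above. The checkpoint key names and tensor shapes here are illustrative placeholders, not taken from a real model:

```python
import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

# Wrap a plain dict of checkpoint tensors (key names are made up for the example).
weights = ConsumableWeightsDict({
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.k_proj.weight": torch.zeros(8, 8),
    "model.layers.0.mlp.gate_proj.weight": torch.zeros(8, 8),
})

# Ordinary dict-style reads still work while a module is being loaded.
q = weights["model.layers.0.self_attn.q_proj.weight"]

# After the q/k tensors have been copied into the model, drop them to reduce
# peak host memory. mark_consumed() deletes every key starting with
# "<prefix>." and returns the number of keys removed.
removed = weights.mark_consumed("model.layers.0.self_attn.q_proj")
removed += weights.mark_consumed("model.layers.0.self_attn.k_proj")
assert removed == 2
assert len(weights) == 1  # only the mlp weight remains
```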
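And a reduced sketch of the loader-side pattern this diff threads through the per-model `load_weights` implementations: the `hasattr(weights, "mark_consumed")` guard keeps plain `dict` inputs working unchanged, while `ConsumableWeightsDict` inputs are freed incrementally as each module is populated. The `load_module_weights` helper and its key layout are hypothetical simplifications of the real loaders:

```python
import torch
from tqdm import tqdm


def load_module_weights(model: torch.nn.Module, weights) -> None:
    """Copy checkpoint tensors into `model`, dropping consumed entries when
    the container supports it (duck-typed via `mark_consumed`)."""
    can_mark_consumed = hasattr(weights, "mark_consumed")

    for name, module in tqdm(list(model.named_modules()),
                             desc="Loading weights"):
        # Mirror the diff: only modules that hold parameters directly are loaded.
        if len(module._parameters) == 0:
            continue
        for param_name, param in module.named_parameters(recurse=False):
            key = f"{name}.{param_name}" if name else param_name
            if key in weights:
                param.data.copy_(weights[key])
        # Once this module is populated, its checkpoint tensors are no longer
        # needed; free them to reduce peak host memory. This assumes no other
        # module's weights live under the same "<name>." prefix.
        if can_mark_consumed:
            weights.mark_consumed(name)
```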