mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00

[None][fix] Reduce host memory usage during model loading (#11119)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>

parent e52eb82780
commit d778b26062
@@ -1,14 +1,99 @@
import threading
from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Dict, Iterator, Tuple, Union

from tensorrt_llm.mapping import Mapping


class ConsumableWeightsDict:
    """
    Wrapper around a weights dictionary that allows marking keys as consumed
    to free memory during model loading.

    This reduces peak memory usage by deleting weight tensors from the dictionary
    after they have been copied to the model, rather than keeping all weights
    in memory until loading completes.

    Thread-safe: uses a lock to protect concurrent access. Iteration methods
    (keys, values, items, __iter__) return snapshot copies to allow safe
    concurrent iteration while other threads may modify the dictionary.
    """

    def __init__(self, weights: Dict[str, Any]):
        self._weights = weights
        self._lock = threading.Lock()

    def __getitem__(self, key: str) -> Any:
        return self._weights[key]

    def __setitem__(self, key: str, value: Any) -> None:
        with self._lock:
            self._weights[key] = value

    def __delitem__(self, key: str) -> None:
        with self._lock:
            del self._weights[key]

    def __contains__(self, key: str) -> bool:
        return key in self._weights

    def __len__(self) -> int:
        return len(self._weights)

    def __iter__(self) -> Iterator[str]:
        # Return iterator over a snapshot copy of keys to allow concurrent modification
        with self._lock:
            return iter(list(self._weights.keys()))

    def keys(self):
        # Return a snapshot copy of keys to allow concurrent modification
        with self._lock:
            return list(self._weights.keys())

    def values(self):
        # Return a snapshot copy of values to allow concurrent modification
        with self._lock:
            return list(self._weights.values())

    def items(self) -> Iterator[Tuple[str, Any]]:
        # Return a snapshot copy of items to allow concurrent modification
        with self._lock:
            return list(self._weights.items())

    def get(self, key: str, default: Any = None) -> Any:
        return self._weights.get(key, default)

    def update(self, other: Dict[str, Any]) -> None:
        with self._lock:
            self._weights.update(other)

    def mark_consumed(self, prefix: str) -> int:
        """
        Delete all keys starting with the given prefix to free memory.

        Args:
            prefix: The prefix to match. Keys starting with "{prefix}." will be deleted.

        Returns:
            The number of keys deleted.

        Thread-safe: uses a lock to prevent concurrent modification issues.
        """
        with self._lock:
            keys_to_delete = [
                k for k in self._weights.keys() if k.startswith(prefix + ".")
            ]
            for key in keys_to_delete:
                del self._weights[key]
            return len(keys_to_delete)


class BaseWeightLoader(ABC):

    @abstractmethod
-    def load_weights(self, checkpoint_dir: str,
-                     mapping: Mapping) -> dict[str, Any]:
+    def load_weights(
+            self, checkpoint_dir: str,
+            mapping: Mapping) -> Union[Dict[str, Any], ConsumableWeightsDict]:
        """
        Loads weights from a checkpoint directory.

@@ -17,7 +102,8 @@ class BaseWeightLoader(ABC):
            mapping: A mapping object containing the distributed configuration.

        Returns:
-            A dictionary where keys are tensor names and values are the tensors.
+            A dictionary (or ConsumableWeightsDict) where keys are tensor names
+            and values are the tensors.
        """

    def cleanup(self) -> None:
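A minimal usage sketch of the new wrapper (illustration only, not part of the patch; the tensor names and shapes are invented). It shows the prefix semantics of mark_consumed: every key that starts with "{prefix}." is dropped, which is what lets the loaders below release a module's checkpoint tensors from host memory as soon as they have been copied into the model.

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict

# Illustrative checkpoint fragment; real keys come from the safetensors files.
weights = ConsumableWeightsDict({
    "model.layers.0.self_attn.q_proj.weight": torch.empty(1024, 1024),
    "model.layers.0.self_attn.k_proj.weight": torch.empty(1024, 1024),
    "model.layers.0.mlp.gate_proj.weight": torch.empty(4096, 1024),
})

# ... copy the attention weights into the model, then release them:
freed = weights.mark_consumed("model.layers.0.self_attn")
assert freed == 2                 # q_proj and k_proj were deleted
assert len(weights) == 1          # only the MLP tensor is still held in host memory
assert "model.layers.0.mlp.gate_proj.weight" in weights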
@@ -74,13 +74,19 @@ class BaseWeightMapper(ABC):
                pattern_mapping = {
                    r'(.*?)out_proj(.*)': r'\1o_proj\2'
                }
-            weights: A dictionary of weights
+            weights: A dictionary of weights (or ConsumableWeightsDict)

        Returns:
-            A dictionary of weights with renamed keys
+            A dictionary of weights with renamed keys (preserves ConsumableWeightsDict if input was one)
        """
        import re

        from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
            ConsumableWeightsDict

        # Check if input is a ConsumableWeightsDict to preserve the type
        is_consumable = isinstance(weights, ConsumableWeightsDict)

        # Create a new dictionary to store the renamed weights
        renamed_weights = {}

@@ -103,6 +109,9 @@ class BaseWeightMapper(ABC):
            if key not in matched_keys:
                renamed_weights[key] = weights[key]

        # Preserve ConsumableWeightsDict type if that's what was passed in
        if is_consumable:
            return ConsumableWeightsDict(renamed_weights)
        return renamed_weights

    def preprocess_weights(self, weights: dict) -> dict:
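For reference, the renaming these hunks wrap is plain re.sub substitution driven by pattern_mapping; the only new behaviour is re-wrapping the result so the caller keeps a ConsumableWeightsDict. A standalone sketch of that substitution (hypothetical keys, not taken from the patch):

import re

# Hypothetical pattern mapping and keys, mirroring the docstring example above.
pattern_mapping = {r'(.*?)out_proj(.*)': r'\1o_proj\2'}
keys = ["model.layers.0.attn.out_proj.weight", "model.embed_tokens.weight"]

renamed = {}
for key in keys:
    new_key = key
    for pattern, repl in pattern_mapping.items():
        if re.match(pattern, key):
            new_key = re.sub(pattern, repl, key)
    renamed[new_key] = key  # in the real code the value is the weight tensor

# Yields "model.layers.0.attn.o_proj.weight"; the embedding key is left untouched.
print(sorted(renamed))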
@@ -9,8 +9,8 @@ import safetensors
import torch
import tqdm

-from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
-    BaseWeightLoader
+from tensorrt_llm._torch.models.checkpoints.base_weight_loader import (
+    BaseWeightLoader, ConsumableWeightsDict)
from tensorrt_llm._torch.models.modeling_utils import (
    register_checkpoint_weight_loader, run_concurrently)
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,

@@ -70,7 +70,7 @@ class HfWeightLoader(BaseWeightLoader):
            raise RuntimeError(f"No weight files found in {checkpoint_dir}.")

    def _load_weights_in_parallel(self, weight_files: List[str], load_func,
-                                  description: str) -> dict[str, Any]:
+                                  description: str) -> ConsumableWeightsDict:
        """
        Load weight files in parallel using the specified loading function.

@@ -80,7 +80,7 @@ class HfWeightLoader(BaseWeightLoader):
            description: Description for the progress bar

        Returns:
-            Dictionary containing all loaded weights
+            ConsumableWeightsDict containing all loaded weights
        """
        weights = {}
        pbar = tqdm.tqdm(total=len(weight_files), desc=description)

@@ -91,7 +91,7 @@ class HfWeightLoader(BaseWeightLoader):
                         reduce_func=weights.update,
                         pbar=pbar)

-        return weights
+        return ConsumableWeightsDict(weights)

    @staticmethod
    def _load_safetensors_file(file):
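Note that the loader still accumulates shards into a plain dict through reduce_func=weights.update and only wraps the result at the very end, so the concurrent loading path itself is unchanged. A rough standalone sketch of that shape (ThreadPoolExecutor stands in for run_concurrently, whose internals are not shown in this diff; the file names and shard loader are made up):

from concurrent.futures import ThreadPoolExecutor

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict


def fake_load_shard(path: str) -> dict:
    # Stand-in for a per-file loader: returns {tensor_name: tensor}.
    return {f"{path}.weight": torch.empty(8, 8)}


weight_files = ["shard0", "shard1", "shard2"]
weights = {}
with ThreadPoolExecutor(max_workers=4) as pool:
    for shard in pool.map(fake_load_shard, weight_files):
        weights.update(shard)  # same role as reduce_func=weights.update

# Wrap once at the end; callers can now mark_consumed() as tensors are copied in.
weights = ConsumableWeightsDict(weights)
print(len(weights))  # 3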
@@ -49,6 +49,11 @@ class MistralWeightMapper(HfWeightMapper):
    # Adapted from:
    # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657
    def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dict:
        from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict

        # Check if input is a ConsumableWeightsDict to preserve the type
        is_consumable = isinstance(weights, ConsumableWeightsDict)

        renamed_weights = {}

        for key in list(weights.keys()):

@@ -68,6 +73,9 @@ class MistralWeightMapper(HfWeightMapper):

            renamed_weights[new_key] = weights[key]

        # Preserve ConsumableWeightsDict type if that's what was passed in
        if is_consumable:
            return ConsumableWeightsDict(renamed_weights)
        return renamed_weights

    def _permute_qk(self, module: nn.Module, new_name: str, weights: dict):
@@ -39,6 +39,8 @@ from transformers import PretrainedConfig

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.mapping import Mapping

@@ -147,7 +149,9 @@ class DeepseekV3WeightLoader:
        self.model_config = model.model_config
        self.is_draft_model = is_draft_model

-    def load_weights(self, weights: Dict, skip_modules: List[str] = []):
+    def load_weights(self,
+                     weights: ConsumableWeightsDict,
+                     skip_modules: List[str] = []):

        def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2,
                                             new_scale_2, device):

@@ -324,6 +328,9 @@ class DeepseekV3WeightLoader:
        params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
        all_named_modules = dict(self.model.named_modules())

        # Check if weights supports mark_consumed (ConsumableWeightsDict)
        can_mark_consumed = hasattr(weights, 'mark_consumed')

        for name, module in tqdm(all_named_modules.items(),
                                 desc="Loading weights"):
            if len(module._parameters) <= 0 or name.startswith("draft_model"):
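The loader probes for mark_consumed with hasattr rather than an isinstance check, so a plain dict keeps working and simply skips the frees. A toy sketch of that duck-typing pattern (illustration only; the helper and module names are invented):

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict


def copy_and_release(weights, module_prefix: str, dest: torch.Tensor) -> None:
    # Copy one tensor into the model, then drop it from host memory if possible.
    dest.copy_(weights[f"{module_prefix}.weight"])
    if hasattr(weights, 'mark_consumed'):
        weights.mark_consumed(module_prefix)  # this branch is skipped for plain dicts


checkpoint = {"layers.0.mlp.weight": torch.randn(4, 4)}
dest = torch.empty(4, 4)

copy_and_release(dict(checkpoint), "layers.0.mlp", dest)  # plain dict: nothing freed
copy_and_release(ConsumableWeightsDict(dict(checkpoint)), "layers.0.mlp", dest)  # wrapper: key deleted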
@@ -399,6 +406,9 @@ class DeepseekV3WeightLoader:
                            -1, v_b_proj_scale.shape[-1]).cuda(),
                    ).view(*attn_module.v_b_proj_dequant.shape).to(
                        attn_module.v_b_proj_dequant.dtype))
                # Mark consumed kv_b_proj weights
                if can_mark_consumed:
                    weights.mark_consumed(name)
            elif names[-1] == "kv_a_proj_with_mqa":
                nvfp4_fused_a = self.model_config.get_quant_config(
                ).layer_quant_mode.has_nvfp4() and weights[

@@ -522,6 +532,13 @@ class DeepseekV3WeightLoader:
                # For DeepseekV32: kv_a_proj_with_mqa is oversized
                # to include indexer k weights, which is filled in post_load_weights.
                module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
                # Mark consumed kv_a_proj_with_mqa and q_a_proj weights
                if can_mark_consumed:
                    parent_prefix = '.'.join(names[:-1])
                    weights.mark_consumed(
                        f"{parent_prefix}.kv_a_proj_with_mqa")
                    if not is_lite:
                        weights.mark_consumed(f"{parent_prefix}.q_a_proj")
            elif names[-1] in params_map:
                module_weights = []
                for new_name in params_map[names[-1]]:

@@ -529,6 +546,11 @@ class DeepseekV3WeightLoader:
                        filter_weights('.'.join(names[:-1] + [new_name]),
                                       weights))
                module.load_weights(weights=module_weights)
                # Mark consumed source weights (e.g., gate_proj, up_proj)
                if can_mark_consumed:
                    for src_name in params_map[names[-1]]:
                        weights.mark_consumed('.'.join(names[:-1] +
                                                       [src_name]))
            elif names[-1] == "experts":
                module_weights = filter_weights(name, weights)
                module_weights = rename_moe_weight(module_weights, {

@@ -537,6 +559,9 @@ class DeepseekV3WeightLoader:
                    "gate_proj": "w1",
                })
                module.load_weights(weights=[module_weights])
                # Mark consumed experts weights
                if can_mark_consumed:
                    weights.mark_consumed(name)
            elif names[-1] == "backend" and isinstance(module, MoE):
                # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
                # Currently saved MoE weights don't include 'backend' in their names.

@@ -552,6 +577,9 @@ class DeepseekV3WeightLoader:
                    "gate_proj": "w1",
                })
                module.load_weights(weights=[module_weights])
                # Mark consumed MoE weights using parent name
                if can_mark_consumed:
                    weights.mark_consumed(parent_name)
            elif names[-1] == "self_attn":
                continue
            elif names[-1] == "next_layer_layernorm":

@@ -563,6 +591,9 @@ class DeepseekV3WeightLoader:
            else:
                for n, p in module.named_parameters():
                    p.data.copy_(module_weights[n][:])
                # Mark consumed weights
                if can_mark_consumed:
                    weights.mark_consumed(name)


class DeepseekV3MTPHead(nn.Module):

@@ -1805,7 +1836,7 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
            return_context_logits=return_context_logits,
            **kwargs)

-    def load_weights(self, weights: Dict):
+    def load_weights(self, weights: ConsumableWeightsDict):
        weight_loader = DeepseekV3WeightLoader(self)
        weight_loader.load_weights(weights)
@@ -9,6 +9,7 @@ from tqdm import tqdm
from transformers import PretrainedConfig

from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import ConsumableWeightsDict
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.models.modeling_utils import QuantConfig

@@ -53,7 +54,7 @@ class Glm4WeightLoader:
        self.model_config = model.model_config
        self.is_draft_model = is_draft_model

-    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+    def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bool = False):
        def rename_moe_weight(weights: Dict, rename_rules: Dict):
            result = {}
            for key, value in weights.items():

@@ -81,6 +82,9 @@ class Glm4WeightLoader:
            else self.config.num_attention_heads
        )

        # Check if weights supports mark_consumed (ConsumableWeightsDict)
        can_mark_consumed = hasattr(weights, "mark_consumed")

        for name, module in tqdm(all_named_modules.items(), desc="Loading weights"):
            if len(module._parameters) <= 0 or name.startswith("draft_model"):
                continue

@@ -116,6 +120,10 @@ class Glm4WeightLoader:
                    }
                    module_weights.append(fw)
                module.load_weights(weights=module_weights)
                # Mark consumed source weights (e.g., q_proj, k_proj, v_proj)
                if can_mark_consumed:
                    for src_name in params_map[names[-1]]:
                        weights.mark_consumed(".".join(names[:-1] + [src_name]))
            elif names[-1] == "experts":
                module_weights = filter_weights(name, weights)
                module_weights = rename_moe_weight(

@@ -129,6 +137,9 @@ class Glm4WeightLoader:
                module.load_weights(
                    weights=[module_weights], allow_partial_loading=allow_partial_loading
                )
                # Mark consumed experts weights
                if can_mark_consumed:
                    weights.mark_consumed(name)
            elif names[-1] == "backend" and isinstance(module, MoE):
                # Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
                # Currently saved MoE weights don't include 'backend' in their names.

@@ -149,6 +160,9 @@ class Glm4WeightLoader:
                module.load_weights(
                    weights=[module_weights], allow_partial_loading=allow_partial_loading
                )
                # Mark consumed MoE weights using parent name
                if can_mark_consumed:
                    weights.mark_consumed(parent_name)
            elif names[-1] == "self_attn":
                continue
            elif names[-1] == "next_layer_layernorm":

@@ -173,6 +187,9 @@ class Glm4WeightLoader:
                    assert n in module_weights
                    if n in module_weights:
                        p.data.copy_(module_weights[n][:])
                # Mark consumed weights
                if can_mark_consumed:
                    weights.mark_consumed(name)


class Glm4Attention(QKNormRoPEAttention):

@@ -1030,7 +1047,7 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
            **kwargs,
        )

-    def load_weights(self, weights: Dict, allow_partial_loading: bool = False):
+    def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bool = False):
        weight_loader = Glm4WeightLoader(self)
        weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
@@ -6,6 +6,8 @@ from tqdm import tqdm
from transformers import PretrainedConfig

from tensorrt_llm._torch.distributed import AllReduceParams
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict
from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata

@@ -623,7 +625,7 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
            vocab_size=model_config.pretrained_config.vocab_size)
        self._execution_stats = None

-    def load_weights(self, weights: Dict):
+    def load_weights(self, weights: ConsumableWeightsDict):
        tp_size = self.model_config.mapping.tp_size
        head_dim = getattr(
            self.config, "head_dim",

@@ -641,6 +643,10 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
            'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
            'gate_up_proj': ['gate_proj', 'up_proj']
        }

        # Check if weights supports mark_consumed (ConsumableWeightsDict)
        can_mark_consumed = hasattr(weights, 'mark_consumed')

        for name, module in tqdm(list(self.named_modules()),
                                 desc="Loading weights"):
            if len(module._parameters) > 0:

@@ -667,7 +673,13 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
                        }
                        module_weights.append(fw)
                    module.load_weights(weights=module_weights)
                    # Mark consumed source weights (e.g., q_proj, k_proj, v_proj)
                    if can_mark_consumed:
                        for src_name in params_map[names[-1]]:
                            weights.mark_consumed('.'.join(names[:-1] +
                                                           [src_name]))
                else:
                    original_name = name
                    name = name.replace('gate', 'gate.wg')
                    module_weights = filter_weights(name, weights)
                    if hasattr(module, 'load_weights'):

@@ -678,6 +690,9 @@ class HunYuanDenseV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
                        for n, p in module._parameters.items():
                            if p is not None:
                                p.data.copy_(module_weights[n][:])
                    # Mark consumed weights
                    if can_mark_consumed:
                        weights.mark_consumed(original_name)

    def forward(
        self,
@@ -6,6 +6,8 @@ from tqdm import tqdm
from transformers import PretrainedConfig

from tensorrt_llm._torch.distributed import AllReduceParams
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict
from tensorrt_llm.functional import PositionEmbeddingType

from ..attention_backend import AttentionMetadata

@@ -341,7 +343,7 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
        self._execution_stats = None
        print("---debug model_config: ", model_config)

-    def load_weights(self, weights: Dict):
+    def load_weights(self, weights: ConsumableWeightsDict):
        tp_size = self.model_config.mapping.tp_size
        head_dim = self.config.hidden_size // self.config.num_attention_heads

@@ -357,6 +359,10 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
            'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
            'gate_up_proj': ['gate_proj', 'up_proj']
        }

        # Check if weights supports mark_consumed (ConsumableWeightsDict)
        can_mark_consumed = hasattr(weights, 'mark_consumed')

        for name, module in tqdm(list(self.named_modules()),
                                 desc="Loading weights"):
            if len(module._parameters) > 0:

@@ -394,7 +400,13 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
                        }
                        module_weights.append(fw)
                    module.load_weights(weights=module_weights)
                    # Mark consumed source weights (e.g., q_proj, k_proj, v_proj)
                    if can_mark_consumed:
                        for src_name in params_map[names[-1]]:
                            weights.mark_consumed('.'.join(names[:-1] +
                                                           [src_name]))
                else:
                    original_name = name
                    name = name.replace('gate', 'gate.wg')
                    module_weights = filter_weights(name, weights)
                    if isinstance(module, CutlassFusedMoE) or isinstance(

@@ -418,6 +430,9 @@ class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
                        for n, p in module._parameters.items():
                            if p is not None:
                                p.data.copy_(module_weights[n][:])
                    # Mark consumed weights
                    if can_mark_consumed:
                        weights.mark_consumed(original_name)

    def forward(
        self,
@@ -748,12 +748,18 @@ def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict):
            pattern_mapping = {
                r'(.*?)out_proj(.*)': r'\1o_proj\2'
            }
-        weights: A dictionary of weights
+        weights: A dictionary of weights (or ConsumableWeightsDict)
    Returns:
-        A dictionary of weights with renamed keys
+        A dictionary of weights with renamed keys (preserves ConsumableWeightsDict if input was one)
    """
    import re

    from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
        ConsumableWeightsDict

    # Check if input is a ConsumableWeightsDict to preserve the type
    is_consumable = isinstance(weights, ConsumableWeightsDict)

    # Create a new dictionary to store the renamed weights
    renamed_weights = {}

@@ -776,6 +782,9 @@ def rename_weights_with_regex(pattern_mapping: Dict[str, str], weights: Dict):
        if key not in matched_keys:
            renamed_weights[key] = weights[key]

    # Preserve ConsumableWeightsDict type if that's what was passed in
    if is_consumable:
        return ConsumableWeightsDict(renamed_weights)
    return renamed_weights


@@ -913,6 +922,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                    module_weights.append(fw)
                module.load_weights(weights=module_weights,
                                    allow_partial_loading=allow_partial_loading)
                # Mark consumed source weights (e.g., q_proj, k_proj, v_proj for qkv_proj)
                if hasattr(weights, 'mark_consumed'):
                    for src_name in params_map[names[-1]]:
                        weights.mark_consumed('.'.join(names[:-1] + [src_name]))

            else:
                module_weights = filter_weights(name, weights)

@@ -934,6 +947,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                        if n in module_weights:
                            p.data.copy_(module_weights[n][:])

                # Mark consumed weights
                if hasattr(weights, 'mark_consumed'):
                    weights.mark_consumed(name)

    if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                      "False") in ["True", "true", "1", "yes", "y"]:
        for name, module in tqdm(list(

@@ -969,7 +986,7 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],


def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
-                          weights: Dict,
+                          weights,
                          weight_mapper: "BaseWeightMapper",
                          skip_modules: List[str] = [],
                          params_map: Optional[Dict[str, str]] = None,

@@ -1008,6 +1025,12 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
                    module, module_name, module_names_breakdown, weights)
                module.load_weights(weights=module_weights,
                                    allow_partial_loading=allow_partial_loading)

                # Mark consumed source weights (e.g., q_proj, k_proj, v_proj for qkv_proj)
                if hasattr(weights, 'mark_consumed'):
                    for src_name in weight_mapper._mapping.get(module_name, []):
                        prefix = '.'.join(module_names_breakdown + [src_name])
                        weights.mark_consumed(prefix)
            else:
                module_weights = weight_mapper.filter_weights(name, weights)
                # Note: module_weights may be empty after filtering (e.g., in streaming weight updates)

@@ -1039,6 +1062,10 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
                            p,
                            allow_partial_loading=allow_partial_loading)

                # Mark consumed weights
                if hasattr(weights, 'mark_consumed'):
                    weights.mark_consumed(name)

    if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
                      "False") in ["True", "true", "1", "yes", "y"]:
        for name, module in tqdm(list(
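Schematically, every loader touched by this patch follows the same per-module pattern: filter the checkpoint entries for one module, copy them into the parameters, then mark that module's prefix as consumed so the host copies are freed immediately. A condensed sketch of that loop (illustration only; it simplifies the real helpers and is not the literal code from any one file):

import torch
import torch.nn as nn

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    ConsumableWeightsDict


def load_module_weights(model: nn.Module, weights) -> None:
    # Condensed version of the per-module loading loop used by the loaders above.
    for name, module in model.named_modules():
        params = dict(module.named_parameters(recurse=False))
        if not params:
            continue
        for n, p in params.items():
            key = f"{name}.{n}"
            if key in weights:
                p.data.copy_(weights[key])
        # Free the host copies as soon as this module is done.
        if hasattr(weights, 'mark_consumed'):
            weights.mark_consumed(name)


model = nn.Sequential(nn.Linear(4, 4))
ckpt = ConsumableWeightsDict({"0.weight": torch.randn(4, 4), "0.bias": torch.randn(4)})
load_module_weights(model, ckpt)
print(len(ckpt))  # 0 -- both tensors were released right after being copied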