Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-10329][feat] Fix weight loading for Nemotron 3 models on DGX Spark (#11405)
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
parent 052fe2f7f6
commit 19a3031ecb
@@ -80,7 +80,9 @@ class NemotronHHfWeightMapper(HfWeightMapper):
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
                 w = w.to(torch.float32)
-                w = -torch.exp(w)
+                # Avoid extra temporaries: one fp32 cast, then in-place exp/neg.
+                w.exp_()
+                w.neg_()
                 new_weights[key] = w
             elif "D" in key:
                 w = split(weights[name], tp_size, tp_rank)
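For context on the change above: the removed form w = -torch.exp(w) allocates a fresh tensor for the exponential and another for the negation, while the in-place exp_() / neg_() pair reuses the fp32 buffer produced by the cast, so no extra temporaries are held while weights are loading. A minimal standalone sketch (illustrative only, not part of the commit; the shape and tensor names are made up):

import torch

# Stand-in for a tensor-parallel shard of an "A" weight; shape is arbitrary.
w = torch.randn(8, 16, dtype=torch.bfloat16)

# Out-of-place form: torch.exp allocates a new tensor and the unary minus
# allocates another, so two temporaries coexist with the cast result.
out_of_place = -torch.exp(w.to(torch.float32))

# In-place form: the cast produces one fp32 buffer, which exp_() and neg_()
# then mutate without any further allocation.
in_place = w.to(torch.float32)
in_place.exp_()
in_place.neg_()

assert torch.allclose(out_of_place, in_place)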
@@ -23,7 +23,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version, is_device_integrated, is_sm_100f
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
@@ -38,6 +38,7 @@ from ...utils import (replace_parameter_and_save_metadata, swizzle_sf,
                       unswizzle_sf)
 from ..linear import TensorParallelMode, load_weight_shard
 from .interface import MoEWeightLoadingMode
+from .moe_load_balancer import advise_tensor_pageout
 
 # The declarations aligns with moe_kernels.h
 # pack inputs into int64, e.g. 4 x bf16 input values
@@ -306,6 +307,20 @@ class FusedMoEMethodBase(ABC):
             w3_w1_kargs["allow_partial_loading"] = allow_partial_loading
         if "allow_partial_loading" in w2_args:
             w2_kargs["allow_partial_loading"] = allow_partial_loading
+
+        def maybe_pageout_mmapped_cpu_weights(
+                weight_tensors: List[object]) -> None:
+            # Integrated GPU systems share physical memory with CPU. After we
+            # finish copying from mmapped CPU weights, proactively advising the
+            # kernel to drop those pages reduces shared-memory pressure.
+            if not is_device_integrated():
+                return
+            for weight in weight_tensors:
+                if (isinstance(weight, torch.Tensor)
+                        and weight.device.type == "cpu"
+                        and weight.is_contiguous()):
+                    advise_tensor_pageout(weight)
+
         # Multithread weight load is superseded by prefetch_files() in model_engine.py
         # Also, threading adds overhead in order to protect shuffle index cache with critical section.
         for local_slot_id, expert_id in enumerate(load_expert_ids):
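The helper above delegates the actual page-release hint to advise_tensor_pageout, imported earlier from moe_load_balancer; its implementation is not part of this diff. As a rough illustration of the kind of hint involved, here is a standalone sketch that issues madvise(MADV_DONTNEED) over the pages backing a contiguous CPU tensor. The function name pageout_cpu_tensor, the use of ctypes, and the alignment handling are assumptions for illustration, not the library's actual code.

import ctypes
import ctypes.util
import mmap

import torch

# madvise lives in libc; MADV_DONTNEED asks the kernel to drop resident pages
# (for a file-backed mapping they are simply refetched from disk if touched).
_libc = ctypes.CDLL(ctypes.util.find_library("c") or "libc.so.6", use_errno=True)


def pageout_cpu_tensor(t: torch.Tensor) -> None:
    # Only a contiguous CPU tensor maps onto a single address range we can advise.
    if t.device.type != "cpu" or not t.is_contiguous():
        return
    start = t.data_ptr()
    end = start + t.numel() * t.element_size()
    page = mmap.PAGESIZE
    # Advise only whole pages that lie fully inside the tensor's buffer, so
    # bytes shared with a neighbouring allocation are never dropped.
    aligned_start = -(-start // page) * page  # round up to a page boundary
    aligned_end = (end // page) * page  # round down to a page boundary
    if aligned_end > aligned_start:
        _libc.madvise(ctypes.c_void_p(aligned_start),
                      ctypes.c_size_t(aligned_end - aligned_start),
                      ctypes.c_int(mmap.MADV_DONTNEED))

A hint like this is only safe after the data has already been copied to its destination (or remains backed by the checkpoint file on disk), which is exactly when the diff calls the helper.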
@@ -361,6 +376,7 @@ class FusedMoEMethodBase(ABC):
                 if weight is not None
             ]
             module._add_raw_shared_weights_for_unmap(unmap_weights)
+            maybe_pageout_mmapped_cpu_weights(unmap_weights)
 
             if module.bias:
                 self.load_expert_w3_w1_weight(
@@ -375,6 +391,7 @@ class FusedMoEMethodBase(ABC):
                 if weight is not None
             ]
             module._add_raw_shared_weights_for_unmap(unmap_weights)
+            maybe_pageout_mmapped_cpu_weights(unmap_weights)
 
     def load_weights(self,
                      module: torch.nn.Module,