[TRTLLM-10329][feat] Fix weight loading for Nemotron 3 models on DGX Spark (#11405)

Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Pamela Peng 2026-02-13 15:29:41 -05:00 committed by GitHub
parent 052fe2f7f6
commit 19a3031ecb
2 changed files with 21 additions and 2 deletions

View File

@@ -80,7 +80,9 @@ class NemotronHHfWeightMapper(HfWeightMapper):
         elif "A" in key:
             w = split(weights[name], tp_size, tp_rank)
             w = w.to(torch.float32)
-            w = -torch.exp(w)
+            # Avoid extra temporaries: one fp32 cast, then in-place exp/neg.
+            w.exp_()
+            w.neg_()
             new_weights[key] = w
         elif "D" in key:
             w = split(weights[name], tp_size, tp_rank)
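For reference, the rewritten mapper code computes the same value as the old w = -torch.exp(w) line; the in-place variant just skips the two temporaries that exp() and the unary minus would otherwise allocate. A minimal standalone check (the shape and dtype below are made up for illustration):

import torch

w = torch.randn(128, dtype=torch.bfloat16)  # hypothetical A_log shard

# Old path: exp() allocates one temporary, unary minus allocates another.
ref = -torch.exp(w.to(torch.float32))

# New path: a single fp32 buffer mutated in place.
out = w.to(torch.float32)
out.exp_()
out.neg_()

assert torch.equal(ref, out)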

View File

@@ -23,7 +23,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from tensorrt_llm._utils import get_sm_version, is_sm_100f
+from tensorrt_llm._utils import get_sm_version, is_device_integrated, is_sm_100f
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
@@ -38,6 +38,7 @@ from ...utils import (replace_parameter_and_save_metadata, swizzle_sf,
                       unswizzle_sf)
 from ..linear import TensorParallelMode, load_weight_shard
 from .interface import MoEWeightLoadingMode
+from .moe_load_balancer import advise_tensor_pageout

 # The declarations aligns with moe_kernels.h
 # pack inputs into int64, e.g. 4 x bf16 input values
@@ -306,6 +307,20 @@ class FusedMoEMethodBase(ABC):
             w3_w1_kargs["allow_partial_loading"] = allow_partial_loading
         if "allow_partial_loading" in w2_args:
             w2_kargs["allow_partial_loading"] = allow_partial_loading
+
+        def maybe_pageout_mmapped_cpu_weights(
+                weight_tensors: List[object]) -> None:
+            # Integrated GPU systems share physical memory with CPU. After we
+            # finish copying from mmapped CPU weights, proactively advising the
+            # kernel to drop those pages reduces shared-memory pressure.
+            if not is_device_integrated():
+                return
+            for weight in weight_tensors:
+                if (isinstance(weight, torch.Tensor)
+                        and weight.device.type == "cpu"
+                        and weight.is_contiguous()):
+                    advise_tensor_pageout(weight)
+
         # Multithread weight load is superseded by prefetch_files() in model_engine.py
         # Also, threading adds overhead in order to protect shuffle index cache with critical section.
         for local_slot_id, expert_id in enumerate(load_expert_ids):
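The advise_tensor_pageout() helper imported from moe_load_balancer above is not shown in this diff. As a rough sketch of what a pageout advise could look like on Linux, assuming it boils down to madvise(2) with MADV_PAGEOUT (the constant 21, the function name, and the ctypes plumbing below are illustrative guesses, not the actual implementation):

import ctypes
import ctypes.util
import mmap

import torch

MADV_PAGEOUT = 21  # Linux-specific madvise(2) advice value


def advise_pageout_sketch(t: torch.Tensor) -> None:
    # madvise() requires a page-aligned start; round down and widen the range.
    libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
    addr = t.data_ptr()
    length = t.numel() * t.element_size()
    start = addr & ~(mmap.PAGESIZE - 1)
    libc.madvise(ctypes.c_void_p(start),
                 ctypes.c_size_t(length + (addr - start)),
                 ctypes.c_int(MADV_PAGEOUT))

On DGX Spark-class hardware, where GPU and CPU draw from one physical memory pool, reclaiming these pages right after the device copy keeps the weights from being resident twice.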
@@ -361,6 +376,7 @@ class FusedMoEMethodBase(ABC):
                 if weight is not None
             ]
             module._add_raw_shared_weights_for_unmap(unmap_weights)
+            maybe_pageout_mmapped_cpu_weights(unmap_weights)

             if module.bias:
                 self.load_expert_w3_w1_weight(
@@ -375,6 +391,7 @@ class FusedMoEMethodBase(ABC):
                 if weight is not None
             ]
             module._add_raw_shared_weights_for_unmap(unmap_weights)
+            maybe_pageout_mmapped_cpu_weights(unmap_weights)

     def load_weights(self,
                      module: torch.nn.Module,
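Both call sites follow the same pattern: once the expert shards have been copied to the device, the raw mmapped CPU tensors are registered for later unmapping and, on integrated-memory systems such as DGX Spark, immediately advised out. The eligibility test inside the helper is easy to probe in isolation (the predicate name below is made up; it mirrors the guard in maybe_pageout_mmapped_cpu_weights):

import torch


def eligible_for_pageout(t: object) -> bool:
    # Only plain, contiguous CPU tensors map to one reclaimable address range.
    return (isinstance(t, torch.Tensor)
            and t.device.type == "cpu"
            and t.is_contiguous())


print(eligible_for_pageout(torch.empty(4, 4)))      # True: contiguous CPU tensor
print(eligible_for_pageout(torch.empty(4, 4).t()))  # False: non-contiguous view
print(eligible_for_pageout("not a tensor"))         # False: not a tensor at all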