[None][feat] Support multi-GPU execution for nemotron-v3-nano and super (#10118)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Wanli Jiang 2025-12-26 11:23:14 +08:00 committed by GitHub
parent 819d03fa88
commit 14554ab3f3
6 changed files with 122 additions and 40 deletions


@ -15,6 +15,30 @@ class NemotronHHfWeightMapper(HfWeightMapper):
tp_rank = self.config.mapping.tp_rank
d_inner = config.mamba_head_dim * config.mamba_num_heads
def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
# Special handling for Mamba2 mixer in_proj.weights and scales.
in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
w, [
d_inner, d_inner, n_groups * d_state, n_groups * d_state,
nheads
],
dim=0)
w = []
for rank in range(tp_size):
in_proj_z_rank = split(in_proj_z, tp_size, rank)
in_proj_x_rank = split(in_proj_x, tp_size, rank)
in_proj_b_rank = split(in_proj_b, tp_size, rank)
in_proj_c_rank = split(in_proj_c, tp_size, rank)
in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
y = torch.concat([
in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
in_proj_c_rank, in_proj_dt_rank
])
w.append(y)
w = torch.concat(w).contiguous()
return w
is_nvfp4 = self.config.quant_config.quant_algo == "NVFP4"
n_groups = config.n_groups
d_state = config.ssm_state_size
nheads = config.mamba_num_heads
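
For illustration, a minimal standalone sketch of the reordering `_split_mamba2_mixer_in_proj` performs: the fused in_proj tensor is split into its z/x/B/C/dt segments, each segment is sharded per tensor-parallel rank, and the shards are re-fused so that every rank's contiguous slice keeps the z/x/B/C/dt layout. The even-chunk `split` below is an assumption standing in for the tensor-parallel split utility the mapper actually uses, and the sizes are toy values.

import torch

def split(t: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
    # assumed even chunking along dim 0, standing in for the real TP split helper
    return torch.chunk(t, tp_size, dim=0)[rank]

d_inner, n_groups, d_state, nheads, tp_size = 8, 2, 4, 4, 2
rows = 2 * d_inner + 2 * n_groups * d_state + nheads
w = torch.arange(rows, dtype=torch.float32).unsqueeze(-1)

z, x, b, c, dt = torch.split(
    w, [d_inner, d_inner, n_groups * d_state, n_groups * d_state, nheads], dim=0)
per_rank = [
    torch.concat([split(seg, tp_size, r) for seg in (z, x, b, c, dt)])
    for r in range(tp_size)
]
reordered = torch.concat(per_rank).contiguous()
# rank r later loads rows [r * rows // tp_size : (r + 1) * rows // tp_size],
# i.e. its own z/x/B/C/dt slices already in the fused order
assert reordered.shape == w.shape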
@ -36,7 +60,12 @@ class NemotronHHfWeightMapper(HfWeightMapper):
if ("mixer.in_proj" in key
or "mixer.out_proj" in key) and "_scale" in key:
new_weights[key] = weights[name]
# Special handling for nvfp4 Mamba2 mixer in_proj.weight_scale.
if is_nvfp4 and "in_proj.weight_scale_2" not in key and "in_proj.weight_scale" in key:
new_weights[key] = _split_mamba2_mixer_in_proj(
weights[name])
else:
new_weights[key] = weights[name]
elif "A" in key:
w = split(weights[name], tp_size, tp_rank)
w = w.to(torch.float32)
@ -51,29 +80,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):
w = w.to(torch.float32)
new_weights[key] = w
elif "mixer.in_proj" in key:
w = weights[name]
in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
w, [
d_inner, d_inner, n_groups * d_state,
n_groups * d_state, nheads
],
dim=0)
w = []
for rank in range(tp_size):
in_proj_z_rank = split(in_proj_z, tp_size, rank)
in_proj_x_rank = split(in_proj_x, tp_size, rank)
in_proj_b_rank = split(in_proj_b, tp_size, rank)
in_proj_c_rank = split(in_proj_c, tp_size, rank)
in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
y = torch.concat([
in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
in_proj_c_rank, in_proj_dt_rank
])
w.append(y)
w = torch.concat(w).contiguous()
new_weights[key] = w
new_weights[key] = _split_mamba2_mixer_in_proj(weights[name])
elif "conv1d" in key:
w = weights[name]
# removing dim(1) because we are using Linear to store conv1d weights
@ -110,19 +117,21 @@ class NemotronHHfWeightMapper(HfWeightMapper):
elif "weight_scale" in key:
# NVFP4 case.
if weights[name].shape:
new_weights[w3_key] = weights[
name][:weights[name].shape[0] // 2]
new_weights[w1_key] = weights[name][
weights[name].shape[0] // 2:]
# w3 weight (gate_proj) scale should be empty for the Nemotron-H MoE model.
# Use [:0] to keep the same input dimension as the other weights.
# The w3 weight_scale shape should be [0, input_dim].
new_weights[w3_key] = weights[name][:0]
new_weights[w1_key] = weights[name]
# FP8 case.
else:
new_weights[w3_key] = weights[name]
new_weights[w1_key] = weights[name]
else:
new_weights[w3_key] = weights[name][:weights[name].
shape[0] // 2]
new_weights[w1_key] = weights[name][weights[name].
shape[0] // 2:]
# w3 weight (gate_proj) should be empty for the Nemotron-H MoE model.
# Use [:0] to keep the same input dimension as the other weights.
# The w3 weight shape should be [0, input_dim].
new_weights[w3_key] = weights[name][:0]
new_weights[w1_key] = weights[name]
elif "down_proj" in key:
key = key.replace("down_proj", "w2")
new_weights[key] = weights[name]
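
A tiny illustration (with made-up shapes) of the `[:0]` convention described in the comments above: the checkpoint tensor becomes w1 in full, while w3 is kept as a zero-row tensor with the same trailing dimension, so fused w3/w1 handling downstream needs no special case. This is consistent with the ReLU²-style experts (`ActivationType.Relu2`) used by NemotronHMOE further down, which have no separate gate projection.

import torch

up_proj = torch.randn(128, 64)   # made-up per-expert checkpoint tensor
w3 = up_proj[:0]                 # empty gate slice, shape (0, 64)
w1 = up_proj                     # the whole tensor is the up projection
assert w3.shape == (0, 64)
# fused loading can still concatenate w3 and w1 without a special case
assert torch.equal(torch.cat([w3, w1], dim=0), up_proj)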


@ -69,12 +69,17 @@ class HfWeightMapper(BaseWeightMapper):
num_kv_heads = kv_shape * 2 // self._head_dim
else:
num_kv_heads = kv_shape // self._head_dim
duplicated_keys = ["weight", "bias"]
if module.quant_config.quant_mode.has_nvfp4():
duplicated_keys.append("weight_scale")
processed_weights = {
k:
self._duplicate_kv(weight=v[:],
num_kv_heads=num_kv_heads,
tensor_parallel_size=self._tp_size)
if k in ["weight", "bias"] else v
if k in duplicated_keys else v
for k, v in weights.items()
}
return processed_weights
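
A rough sketch of why `weight_scale` joins the duplicated keys under NVFP4: `_duplicate_kv` (not shown in this diff) replicates KV heads when there are fewer of them than tensor-parallel ranks, and NVFP4 block scales share the weight's output-channel dimension, so they must be replicated the same way to keep per-rank shapes consistent. The `duplicate_kv` below is a hypothetical stand-in, not the real helper.

import torch

def duplicate_kv(t: torch.Tensor, num_kv_heads: int, tp_size: int,
                 head_dim: int) -> torch.Tensor:
    # hypothetical stand-in: repeat each KV head so every TP rank owns a full head
    repeats = max(tp_size // num_kv_heads, 1)
    heads = t.view(num_kv_heads, head_dim, -1)
    return heads.repeat_interleave(repeats, dim=0).reshape(
        num_kv_heads * repeats * head_dim, -1)

head_dim, num_kv_heads, tp_size = 4, 2, 8
k_weight = torch.randn(num_kv_heads * head_dim, 16)
k_scale = torch.randn(num_kv_heads * head_dim, 1)   # per-row scale, shape assumed
assert (duplicate_kv(k_weight, num_kv_heads, tp_size, head_dim).shape[0]
        == duplicate_kv(k_scale, num_kv_heads, tp_size, head_dim).shape[0])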


@ -26,6 +26,7 @@ from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.utils import ActivationType, relu2
from ..attention_backend import AttentionMetadata
from ..distributed import AllReduce
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
@ -124,7 +125,7 @@ class NemotronHMOE(nn.Module):
from .modeling_deepseekv3 import DeepseekV3Gate
self.activation_type = ActivationType.Relu2
self.reduce_results = True
self.reduce_results = False
config = model_config.pretrained_config
self.hidden_dim = config.hidden_size
@ -144,6 +145,7 @@ class NemotronHMOE(nn.Module):
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.routed_scaling_factor = config.routed_scaling_factor
self.mapping = model_config.mapping
# Setup shared expert MLP.
if config.n_shared_experts is None or config.n_shared_experts == 0:
@ -160,6 +162,7 @@ class NemotronHMOE(nn.Module):
dtype=config.torch_dtype,
config=model_config,
layer_idx=self.layer_idx,
reduce_output=False,
)
# Setup MoE gate.
self.gate = DeepseekV3Gate(
@ -190,6 +193,12 @@ class NemotronHMOE(nn.Module):
activation_type=self.activation_type,
)
# AllReduce for combining shared and routed expert outputs in multi-GPU settings.
self.allreduce = AllReduce(
mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
)
# Setup latent projection layers.
if self.use_latent_moe:
self.fc1_latent_proj = Linear(
@ -223,6 +232,7 @@ class NemotronHMOE(nn.Module):
assert hidden_states.shape[-1] == self.hidden_dim
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
def _compute_shared_output():
if self.shared_experts is not None:
@ -239,7 +249,6 @@ class NemotronHMOE(nn.Module):
routed_hidden_states = self.fc1_latent_proj(
routed_hidden_states)
all_rank_num_tokens = attn_metadata.all_rank_num_tokens
final_hidden_states = self.experts(
routed_hidden_states,
router_logits,
@ -258,6 +267,10 @@ class NemotronHMOE(nn.Module):
final_hidden_states = shared_output + routed_output
# Perform all-reduce after combining outputs for multi-GPU support.
if not self.enable_attention_dp and self.mapping.tp_size > 1:
final_hidden_states = self.allreduce(final_hidden_states)
return final_hidden_states.view(orig_shape)
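
The pattern here appears to be: the shared-expert MLP (`reduce_output=False`) and the routed experts (`reduce_results = False`) both return partial per-rank sums, which are added first and then reduced once by `self.allreduce`, instead of each branch issuing its own all-reduce. A minimal sketch of the same idea, using plain torch.distributed as a stand-in for the AllReduce module:

import torch
import torch.distributed as dist

def moe_combine(hidden, shared_mlp, routed_experts, tp_size, enable_attention_dp):
    shared_out = shared_mlp(hidden)       # partial sum: down_proj reduce skipped
    routed_out = routed_experts(hidden)   # partial sum: expert reduce skipped
    out = shared_out + routed_out
    # one all-reduce over the combined result in the TP > 1, non-attention-DP case
    if not enable_attention_dp and tp_size > 1:
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out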


@ -475,12 +475,21 @@ class FusedMoEMethodBase(ABC):
TensorParallelMode.COLUMN,
device=device) if w3_weight is not None else None
dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0)
src_w3_size_shard = w3_weight_shard.shape[
0] if w3_weight_shard is not None else 0
src_w1_size_shard = w1_weight_shard.shape[
0] if w1_weight_shard is not None else 0
if w1_weight is not None:
dst_w1_weight = dst_w3_w1_weight.narrow(dim=0,
start=src_w3_size_shard,
length=src_w1_size_shard)
dst_w1_weight.copy_(w1_weight_shard.contiguous().view(
dst_w3_w1_weight.dtype),
non_blocking=True)
if w3_weight is not None:
dst_w3_weight = dst_w3_w1_weight.narrow(dim=0,
start=0,
length=src_w3_size_shard)
dst_w3_weight.copy_(w3_weight_shard.contiguous().view(
dst_w3_w1_weight.dtype),
non_blocking=True)
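
Unlike `chunk(2, dim=0)`, which assumes w3 and w1 each occupy exactly half of the fused destination, `narrow` slices by the actual shard sizes, so a zero-row w3 (as produced by the Nemotron-H mapper above) still loads cleanly. A shape-only illustration with made-up sizes:

import torch

dst_w3_w1 = torch.empty(6, 4)     # fused destination
w3_shard = torch.randn(0, 4)      # empty gate shard
w1_shard = torch.randn(6, 4)
dst_w3_w1.narrow(0, 0, w3_shard.shape[0]).copy_(w3_shard)                    # no-op
dst_w3_w1.narrow(0, w3_shard.shape[0], w1_shard.shape[0]).copy_(w1_shard)    # all 6 rows
# dst_w3_w1.chunk(2, dim=0) would instead force a 3-row / 3-row split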
@ -701,6 +710,37 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
module.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale)
module.fc31_input_dequant.data.copy_(max_fc31_input_scale)
def post_load_weights(self, module):
super().post_load_weights(module)
# Padding weights to meet FP8 GEMM alignment requirements.
def _maybe_padding_weights(tensor: torch.Tensor, row_alignment: int,
col_alignment: int):
row_pad_size = (row_alignment - tensor.size(1)) % row_alignment
col_pad_size = (col_alignment - tensor.size(2)) % col_alignment
is_padded = row_pad_size != 0 or col_pad_size != 0
if is_padded:
return F.pad(tensor, (0, col_pad_size, 0, row_pad_size),
mode='constant',
value=0), is_padded
return tensor, is_padded
if getattr(module, "moe_backend", None) == "CUTLASS":
cutlass_fp8_row_alignment, cutlass_fp8_col_alignment = 32, 16
padded_w3_w1_weight, is_padded_w3_w1_weight = _maybe_padding_weights(
module.w3_w1_weight, cutlass_fp8_row_alignment,
cutlass_fp8_col_alignment)
# Use `row_alignment` for `w2_weight.shape[2]` so that it matches `w3_w1_weight.shape[1]`.
padded_w2_weight, is_padded_w2_weight = _maybe_padding_weights(
module.w2_weight, cutlass_fp8_row_alignment,
cutlass_fp8_row_alignment)
if is_padded_w3_w1_weight:
module.w3_w1_weight = nn.Parameter(padded_w3_w1_weight,
requires_grad=False)
if is_padded_w2_weight:
module.w2_weight = nn.Parameter(padded_w2_weight,
requires_grad=False)
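
The modulo arithmetic in `_maybe_padding_weights` rounds each of the two trailing dimensions up to the next multiple of its alignment, zero-padding on the right/bottom only. A worked example with the CUTLASS FP8 alignments used above (tensor shape made up):

import torch
import torch.nn.functional as F

t = torch.zeros(8, 60, 30)                      # (num_experts, rows, cols)
row_align, col_align = 32, 16
row_pad = (row_align - t.size(1)) % row_align   # (32 - 60) % 32 == 4
col_pad = (col_align - t.size(2)) % col_align   # (16 - 30) % 16 == 2
# F.pad orders pads from the last dimension inward: (left, right, top, bottom)
padded = F.pad(t, (0, col_pad, 0, row_pad), mode='constant', value=0)
assert padded.shape == (8, 64, 32)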
class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
@ -2079,10 +2119,12 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
def create_weights(self, module: torch.nn.Module):
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8
self.block_scales_vec_size = torch.iinfo(
self.block_scales_dtype).bits // 8
super().create_weights(module, self.weight_dtype, weight_vec_size,
self.block_scales_dtype, block_scales_vec_size)
self.block_scales_dtype,
self.block_scales_vec_size)
def load_expert_w3_w1_weight_scale_nvfp4(
self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
@ -2131,6 +2173,16 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
module.tp_rank,
TensorParallelMode.ROW,
device=device)
# Padding w2_weight_scale (dtype=float8_e4m3fn) to match the shape of dst_w2_weight_scale (dtype=float32)
src_w2_scale_size = w2_weight_scale.shape[1]
adjusted_dst_w2_scale_size = dst_w2_weight_scale.shape[
1] * self.block_scales_vec_size
assert adjusted_dst_w2_scale_size >= src_w2_scale_size, "adjusted_dst_w2_scale_size must be greater than or equal to src_w2_scale_size"
if adjusted_dst_w2_scale_size > src_w2_scale_size:
w2_weight_scale = torch.nn.functional.pad(
w2_weight_scale,
(0, adjusted_dst_w2_scale_size - src_w2_scale_size), "constant",
0).contiguous()
cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
cast_w2_weight_scale = self._maybe_padding_shape(
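
The padding above compensates for a dtype mismatch: the source scales are 1-byte float8_e4m3fn values while the destination is viewed in 4-byte elements, so `block_scales_vec_size` (presumably 4 here) source elements pack into each destination element, and the source row may need zero-padding before the reinterpreting view. A byte-packing sketch with uint8 standing in for float8_e4m3fn and made-up sizes:

import torch
import torch.nn.functional as F

block_scales_vec_size = 4                  # 32-bit destination / 8-bit source
src = torch.zeros(10, dtype=torch.uint8)   # stand-in for float8_e4m3fn scales
dst_cols = 3                               # destination columns, viewed as float32
adjusted = dst_cols * block_scales_vec_size
if adjusted > src.shape[0]:
    src = F.pad(src, (0, adjusted - src.shape[0]), "constant", 0).contiguous()
packed = src.view(torch.float32)           # 12 bytes reinterpreted as 3 float32 values
assert packed.shape[0] == dst_cols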


@ -158,6 +158,7 @@ class Mamba2Mixer(nn.Module):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_metadata: Mamba2Metadata,
**kwargs,
) -> torch.Tensor:
# calculate split size


@ -19,7 +19,8 @@ class MLP(nn.Module):
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
config: Optional[ModelConfig] = None,
layer_idx: Optional[int] = None):
layer_idx: Optional[int] = None,
reduce_output: bool = True):
super().__init__()
self.layer_idx = layer_idx
@ -60,7 +61,8 @@ class MLP(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
reduce_output=reduce_output)
def forward(
self,
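
Threading `reduce_output` from MLP down to the row-parallel down_proj Linear is what allows NemotronHMOE above to construct its shared experts with `reduce_output=False` and defer the all-reduce. A hypothetical sketch of a row-parallel linear with an optional output reduce (not the actual Linear implementation):

import torch
import torch.distributed as dist
import torch.nn.functional as F

class RowParallelLinearSketch(torch.nn.Module):
    def __init__(self, in_features_per_rank: int, out_features: int,
                 reduce_output: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features_per_rank))
        self.reduce_output = reduce_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = F.linear(x, self.weight)   # partial result from this rank's input shard
        if self.reduce_output and dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(y)         # sum partial results across TP ranks
        return y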