Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Support multi-gpu running for nemotron-v3-nano and super (#10118)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
This commit is contained in:
parent: 819d03fa88
commit: 14554ab3f3
@@ -15,6 +15,30 @@ class NemotronHHfWeightMapper(HfWeightMapper):
         tp_rank = self.config.mapping.tp_rank
+        d_inner = config.mamba_head_dim * config.mamba_num_heads

+        def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
+            # Special handling for Mamba2 mixer in_proj weights and scales.
+            in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
+                w, [
+                    d_inner, d_inner, n_groups * d_state, n_groups * d_state,
+                    nheads
+                ],
+                dim=0)
+            w = []
+            for rank in range(tp_size):
+                in_proj_z_rank = split(in_proj_z, tp_size, rank)
+                in_proj_x_rank = split(in_proj_x, tp_size, rank)
+                in_proj_b_rank = split(in_proj_b, tp_size, rank)
+                in_proj_c_rank = split(in_proj_c, tp_size, rank)
+                in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
+                y = torch.concat([
+                    in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
+                    in_proj_c_rank, in_proj_dt_rank
+                ])
+                w.append(y)
+            w = torch.concat(w).contiguous()
+            return w
+
         is_nvfp4 = self.config.quant_config.quant_algo == "NVFP4"
         n_groups = config.n_groups
         d_state = config.ssm_state_size
         nheads = config.mamba_num_heads
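Mamba2's in_proj fuses the z, x, B, C and dt projections along the output dimension, so a tensor-parallel shard cannot be a naive contiguous chunk of the fused tensor: each component must be split per rank and the per-rank pieces re-concatenated, which is what the helper above does. A minimal, self-contained sketch of the same re-packing under those assumptions (the `shard_dim0` and `split_fused_in_proj` names are illustrative, not from the patch):

    import torch

    def shard_dim0(t: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
        # Even split along dim 0; stands in for the repo's `split` helper.
        return torch.chunk(t, tp_size, dim=0)[rank]

    def split_fused_in_proj(w: torch.Tensor, sizes: list, tp_size: int) -> torch.Tensor:
        # Re-pack a fused [z | x | B | C | dt] projection so each rank's
        # contiguous slice of the result contains its shard of every component.
        parts = torch.split(w, sizes, dim=0)
        per_rank = [
            torch.cat([shard_dim0(p, tp_size, r) for p in parts], dim=0)
            for r in range(tp_size)
        ]
        return torch.cat(per_rank, dim=0).contiguous()

    # Toy shapes: d_inner=4, n_groups*d_state=2, nheads=2, hidden=3, tp_size=2.
    w = torch.arange(14 * 3, dtype=torch.float32).reshape(14, 3)
    repacked = split_fused_in_proj(w, [4, 4, 2, 2, 2], tp_size=2)
    assert repacked.shape == w.shape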
@@ -36,7 +60,12 @@ class NemotronHHfWeightMapper(HfWeightMapper):

             if ("mixer.in_proj" in key
                     or "mixer.out_proj" in key) and "_scale" in key:
-                new_weights[key] = weights[name]
+                # Special handling for nvfp4 Mamba2 mixer in_proj.weight_scale.
+                if is_nvfp4 and "in_proj.weight_scale_2" not in key and "in_proj.weight_scale" in key:
+                    new_weights[key] = _split_mamba2_mixer_in_proj(
+                        weights[name])
+                else:
+                    new_weights[key] = weights[name]
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
                 w = w.to(torch.float32)
@@ -51,29 +80,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):
                 w = w.to(torch.float32)
                 new_weights[key] = w
             elif "mixer.in_proj" in key:
-                w = weights[name]
-                in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
-                    w, [
-                        d_inner, d_inner, n_groups * d_state,
-                        n_groups * d_state, nheads
-                    ],
-                    dim=0)
-
-                w = []
-                for rank in range(tp_size):
-                    in_proj_z_rank = split(in_proj_z, tp_size, rank)
-                    in_proj_x_rank = split(in_proj_x, tp_size, rank)
-                    in_proj_b_rank = split(in_proj_b, tp_size, rank)
-                    in_proj_c_rank = split(in_proj_c, tp_size, rank)
-                    in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
-                    y = torch.concat([
-                        in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
-                        in_proj_c_rank, in_proj_dt_rank
-                    ])
-                    w.append(y)
-
-                w = torch.concat(w).contiguous()
-                new_weights[key] = w
+                new_weights[key] = _split_mamba2_mixer_in_proj(weights[name])
             elif "conv1d" in key:
                 w = weights[name]
                 # removing dim(1) because we are using Linear to store conv1d weights
@@ -110,19 +117,21 @@ class NemotronHHfWeightMapper(HfWeightMapper):
                 elif "weight_scale" in key:
                     # NVFP4 case.
                     if weights[name].shape:
-                        new_weights[w3_key] = weights[
-                            name][:weights[name].shape[0] // 2]
-                        new_weights[w1_key] = weights[name][
-                            weights[name].shape[0] // 2:]
+                        # w3 weight (gate_proj) scale should be empty for Nemotron-H MoE model.
+                        # Use [:0] to keep the same input dimension as the other weights.
+                        # The w3 weight_scale shape should be [0, input_dim].
+                        new_weights[w3_key] = weights[name][:0]
+                        new_weights[w1_key] = weights[name]
                     # FP8 case.
                     else:
                         new_weights[w3_key] = weights[name]
                         new_weights[w1_key] = weights[name]
                 else:
-                    new_weights[w3_key] = weights[name][:weights[name].
-                                                        shape[0] // 2]
-                    new_weights[w1_key] = weights[name][weights[name].
-                                                        shape[0] // 2:]
+                    # w3 weight (gate_proj) should be empty for Nemotron-H MoE model.
+                    # Use [:0] to keep the same input dimension as the other weights.
+                    # The w3 weight shape should be [0, input_dim].
+                    new_weights[w3_key] = weights[name][:0]
+                    new_weights[w1_key] = weights[name]
             elif "down_proj" in key:
                 key = key.replace("down_proj", "w2")
                 new_weights[key] = weights[name]
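The `[:0]` slice above is what lets the rest of the fused-MoE loading path stay unchanged: it produces an empty w3 (gate_proj) tensor that still carries the correct trailing dimensions, so later concatenation and copy logic sees matching shapes. A tiny sketch of the idea (shapes and names are illustrative):

    import torch

    up_proj = torch.randn(16, 8)   # stand-in for a fused gate/up projection weight
    w3 = up_proj[:0]               # empty gate weights: shape (0, 8), input dim preserved
    w1 = up_proj                   # all rows belong to w1 for Nemotron-H MoE

    # Downstream fused-weight packing still works because trailing dims match:
    w3_w1 = torch.cat([w3, w1], dim=0)
    assert w3.shape == (0, 8) and w3_w1.shape == up_proj.shape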
@@ -69,12 +69,17 @@ class HfWeightMapper(BaseWeightMapper):
             num_kv_heads = kv_shape * 2 // self._head_dim
         else:
             num_kv_heads = kv_shape // self._head_dim

+        duplicated_keys = ["weight", "bias"]
+        if module.quant_config.quant_mode.has_nvfp4():
+            duplicated_keys.append("weight_scale")
+
         processed_weights = {
             k:
             self._duplicate_kv(weight=v[:],
                                num_kv_heads=num_kv_heads,
                                tensor_parallel_size=self._tp_size)
-            if k in ["weight", "bias"] else v
+            if k in duplicated_keys else v
             for k, v in weights.items()
         }
         return processed_weights
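When the number of KV heads is smaller than the TP degree, K/V rows are duplicated so every rank owns a copy; with NVFP4 the per-output-channel `weight_scale` rows have to be duplicated in the same pattern as the weight rows, which is why it joins `duplicated_keys` above. A rough, self-contained sketch of that duplication for a 2-D [out, in] tensor (the helper name and layout handling are assumptions; the repo's `_duplicate_kv` also deals with fused K/V layouts):

    import torch

    def duplicate_kv_heads(weight: torch.Tensor, num_kv_heads: int,
                           tensor_parallel_size: int) -> torch.Tensor:
        # Works on any 2-D [out_features, in_features] tensor whose rows are
        # grouped head-by-head (weight, or a per-channel weight_scale).
        if num_kv_heads >= tensor_parallel_size:
            return weight
        reps = tensor_parallel_size // num_kv_heads
        head_dim = weight.shape[0] // num_kv_heads
        heads = weight.view(num_kv_heads, head_dim, -1)
        # repeat_interleave keeps head order: h0 h0 ... h1 h1 ...
        return heads.repeat_interleave(reps, dim=0).reshape(-1, weight.shape[-1])

    w = torch.randn(2 * 64, 512)   # 2 KV heads, head_dim 64
    dup = duplicate_kv_heads(w, num_kv_heads=2, tensor_parallel_size=8)
    assert dup.shape == (8 * 64, 512)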
@@ -26,6 +26,7 @@ from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
 from tensorrt_llm._torch.utils import ActivationType, relu2

 from ..attention_backend import AttentionMetadata
+from ..distributed import AllReduce
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
@@ -124,7 +125,7 @@ class NemotronHMOE(nn.Module):
         from .modeling_deepseekv3 import DeepseekV3Gate

         self.activation_type = ActivationType.Relu2
-        self.reduce_results = True
+        self.reduce_results = False

         config = model_config.pretrained_config
         self.hidden_dim = config.hidden_size
@@ -144,6 +145,7 @@ class NemotronHMOE(nn.Module):
         self.top_k = config.num_experts_per_tok
         self.enable_attention_dp = model_config.mapping.enable_attention_dp
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.mapping = model_config.mapping

         # Setup shared expert MLP.
         if config.n_shared_experts is None or config.n_shared_experts == 0:
@@ -160,6 +162,7 @@ class NemotronHMOE(nn.Module):
             dtype=config.torch_dtype,
             config=model_config,
             layer_idx=self.layer_idx,
+            reduce_output=False,
         )
         # Setup MoE gate.
         self.gate = DeepseekV3Gate(
@@ -190,6 +193,12 @@ class NemotronHMOE(nn.Module):
             activation_type=self.activation_type,
         )

+        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+        self.allreduce = AllReduce(
+            mapping=model_config.mapping,
+            strategy=model_config.allreduce_strategy,
+        )
+
         # Setup latent projection layers.
         if self.use_latent_moe:
             self.fc1_latent_proj = Linear(
@@ -223,6 +232,7 @@ class NemotronHMOE(nn.Module):
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
+        all_rank_num_tokens = attn_metadata.all_rank_num_tokens

         def _compute_shared_output():
             if self.shared_experts is not None:
@@ -239,7 +249,6 @@ class NemotronHMOE(nn.Module):
             routed_hidden_states = self.fc1_latent_proj(
                 routed_hidden_states)

-        all_rank_num_tokens = attn_metadata.all_rank_num_tokens
         final_hidden_states = self.experts(
             routed_hidden_states,
             router_logits,
@@ -258,6 +267,10 @@ class NemotronHMOE(nn.Module):

         final_hidden_states = shared_output + routed_output

+        # Perform all-reduce after combining outputs for multi-GPU support.
+        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+            final_hidden_states = self.allreduce(final_hidden_states)
+
         return final_hidden_states.view(orig_shape)
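The NemotronHMOE hunks above implement one pattern: both the shared-expert MLP (`reduce_output=False`) and the routed experts (`reduce_results = False`) return per-rank partial sums, the partials are added, and a single `AllReduce` runs on the combined tensor, so each MoE layer pays for one collective instead of two. A minimal sketch of the same combine-then-reduce flow, using plain `torch.distributed` as a stand-in for TRT-LLM's `AllReduce` module and assuming an initialized process group:

    import torch
    import torch.distributed as dist

    def moe_combine(shared_partial: torch.Tensor, routed_partial: torch.Tensor,
                    tp_size: int, enable_attention_dp: bool = False) -> torch.Tensor:
        # Each TP rank holds partial sums from its shard of the shared MLP and of
        # the routed experts (both produced with their internal reduction disabled).
        out = shared_partial + routed_partial
        # One collective over the combined tensor instead of one per path.
        if not enable_attention_dp and tp_size > 1:
            dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out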
@@ -475,12 +475,21 @@ class FusedMoEMethodBase(ABC):
                                             TensorParallelMode.COLUMN,
                                             device=device) if w3_weight is not None else None

-        dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0)
+        src_w3_size_shard = w3_weight_shard.shape[
+            0] if w3_weight_shard is not None else 0
+        src_w1_size_shard = w1_weight_shard.shape[
+            0] if w1_weight_shard is not None else 0
+        if w1_weight is not None:
+            dst_w1_weight = dst_w3_w1_weight.narrow(dim=0,
+                                                    start=src_w3_size_shard,
+                                                    length=src_w1_size_shard)
+            dst_w1_weight.copy_(w1_weight_shard.contiguous().view(
+                dst_w3_w1_weight.dtype),
+                                non_blocking=True)
+        if w3_weight is not None:
+            dst_w3_weight = dst_w3_w1_weight.narrow(dim=0,
+                                                    start=0,
+                                                    length=src_w3_size_shard)
+            dst_w3_weight.copy_(w3_weight_shard.contiguous().view(
+                dst_w3_w1_weight.dtype),
+                                non_blocking=True)
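Replacing `chunk(2, dim=0)` with `narrow()` matters because Nemotron-H's w3 shard can have zero rows: an equal-halves chunk would place w1 at the wrong offset, while `narrow()` uses the actual per-source sizes. A toy illustration (shapes chosen arbitrarily):

    import torch

    hidden, inter = 4, 6
    w3 = torch.empty(0, hidden)    # Nemotron-H style: zero-row gate (w3) shard
    w1 = torch.randn(inter, hidden)
    dst = torch.zeros(w3.shape[0] + w1.shape[0], hidden)

    # chunk(2, dim=0) would assume equal halves and misplace w1;
    # narrow() uses the actual per-source row counts instead.
    dst.narrow(0, 0, w3.shape[0]).copy_(w3)
    dst.narrow(0, w3.shape[0], w1.shape[0]).copy_(w1)
    assert torch.equal(dst, w1)    # the empty w3 contributes no rows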
@@ -701,6 +710,37 @@ class FP8QDQFusedMoEMethod(FusedMoEMethodBase):
         module.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale)
         module.fc31_input_dequant.data.copy_(max_fc31_input_scale)

+    def post_load_weights(self, module):
+        super().post_load_weights(module)
+
+        # Padding weights to meet FP8 GEMM alignment requirements.
+        def _maybe_padding_weights(tensor: torch.Tensor, row_alignment: int,
+                                   col_alignment: int):
+            row_pad_size = (row_alignment - tensor.size(1)) % row_alignment
+            col_pad_size = (col_alignment - tensor.size(2)) % col_alignment
+            is_padded = row_pad_size != 0 or col_pad_size != 0
+            if is_padded:
+                return F.pad(tensor, (0, col_pad_size, 0, row_pad_size),
+                             mode='constant',
+                             value=0), is_padded
+            return tensor, is_padded
+
+        if getattr(module, "moe_backend", None) == "CUTLASS":
+            cutlass_fp8_row_alignment, cutlass_fp8_col_alignment = 32, 16
+            padded_w3_w1_weight, is_padded_w3_w1_weight = _maybe_padding_weights(
+                module.w3_w1_weight, cutlass_fp8_row_alignment,
+                cutlass_fp8_col_alignment)
+            # Use `row_alignment` for `w2_weight.shape[2]` to match the shape of `w3_w1_weight.shape[1]`.
+            padded_w2_weight, is_padded_w2_weight = _maybe_padding_weights(
+                module.w2_weight, cutlass_fp8_row_alignment,
+                cutlass_fp8_row_alignment)
+            if is_padded_w3_w1_weight:
+                module.w3_w1_weight = nn.Parameter(padded_w3_w1_weight,
+                                                   requires_grad=False)
+            if is_padded_w2_weight:
+                module.w2_weight = nn.Parameter(padded_w2_weight,
+                                                requires_grad=False)


 class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
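The new `post_load_weights` zero-pads the per-expert `[num_experts, rows, cols]` weights so their row and column counts hit the CUTLASS FP8 GEMM alignment; zero padding leaves the valid region of the GEMM output unchanged, and the in-diff comment explains that w2's columns are padded with the row alignment so its K dimension matches w3_w1's padded N dimension. A standalone sketch of the padding arithmetic (alignment values copied from the hunk, helper name illustrative):

    import torch
    import torch.nn.functional as F

    def pad_to_alignment(t: torch.Tensor, row_align: int, col_align: int) -> torch.Tensor:
        # Zero-pad a [num_experts, rows, cols] tensor so rows % row_align == 0 and
        # cols % col_align == 0; zeros do not change the valid GEMM outputs.
        row_pad = (row_align - t.size(1)) % row_align
        col_pad = (col_align - t.size(2)) % col_align
        # F.pad pads from the last dim backwards: (left, right, top, bottom).
        return F.pad(t, (0, col_pad, 0, row_pad)) if (row_pad or col_pad) else t

    w = torch.randn(8, 100, 50)                 # e.g. 8 experts, 100x50 per expert
    padded = pad_to_alignment(w, 32, 16)
    assert padded.shape == (8, 128, 64)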
@@ -2079,10 +2119,12 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):

     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
-        block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8
+        self.block_scales_vec_size = torch.iinfo(
+            self.block_scales_dtype).bits // 8

         super().create_weights(module, self.weight_dtype, weight_vec_size,
-                               self.block_scales_dtype, block_scales_vec_size)
+                               self.block_scales_dtype,
+                               self.block_scales_vec_size)

     def load_expert_w3_w1_weight_scale_nvfp4(
             self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
@@ -2131,6 +2173,16 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
                                              module.tp_rank,
                                              TensorParallelMode.ROW,
                                              device=device)
+        # Padding w2_weight_scale (dtype=float8_e4m3fn) to match the shape of dst_w2_weight_scale (dtype=float32).
+        src_w2_scale_size = w2_weight_scale.shape[1]
+        adjusted_dst_w2_scale_size = dst_w2_weight_scale.shape[
+            1] * self.block_scales_vec_size
+        assert adjusted_dst_w2_scale_size >= src_w2_scale_size, "adjusted_dst_w2_scale_size must be greater than or equal to src_w2_scale_size"
+        if adjusted_dst_w2_scale_size > src_w2_scale_size:
+            w2_weight_scale = torch.nn.functional.pad(
+                w2_weight_scale,
+                (0, adjusted_dst_w2_scale_size - src_w2_scale_size), "constant",
+                0).contiguous()

         cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
         cast_w2_weight_scale = self._maybe_padding_shape(
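The padding above exists because the destination block-scale buffer is stored in a packed wider dtype, and `Tensor.view(dtype)` from a 1-byte to a 4-byte element type requires the last dimension to be a multiple of the byte ratio; `block_scales_vec_size`, exposed on `self` in the earlier hunk, is exactly that ratio. A small sketch of the constraint, with `uint8` standing in for the fp8 scales:

    import torch

    # 1-byte scale entries to be reinterpreted as a packed 4-byte dtype:
    # the last dim must first become a multiple of 4.
    scales = torch.randint(0, 255, (8, 10), dtype=torch.uint8)
    vec_size = 4                                   # bytes per packed float32 element
    dst_cols = 3                                   # packed destination columns
    pad = dst_cols * vec_size - scales.shape[1]    # 12 - 10 = 2
    scales = torch.nn.functional.pad(scales, (0, pad)).contiguous()
    packed = scales.view(torch.float32)            # reinterpret, shape becomes (8, 3)
    assert packed.shape == (8, 3)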
@@ -158,6 +158,7 @@ class Mamba2Mixer(nn.Module):
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         mamba_metadata: Mamba2Metadata,
+        **kwargs,
     ) -> torch.Tensor:

         # calculate split size
@@ -19,7 +19,8 @@ class MLP(nn.Module):
                  activation: Callable[[torch.Tensor], torch.Tensor] = None,
                  dtype: Optional[torch.dtype] = None,
                  config: Optional[ModelConfig] = None,
-                 layer_idx: Optional[int] = None):
+                 layer_idx: Optional[int] = None,
+                 reduce_output: bool = True):

         super().__init__()
         self.layer_idx = layer_idx
@@ -60,7 +61,8 @@ class MLP(nn.Module):
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.down_lora,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            reduce_output=reduce_output)

     def forward(
         self,
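For context on how the new `reduce_output` flag is meant to be used: threading it through to the down-projection lets a caller (here the NemotronHMOE shared expert, constructed with `reduce_output=False`) keep the row-parallel output as a per-rank partial sum and defer the all-reduce until after the expert outputs are combined. A hypothetical, stripped-down row-parallel linear showing the same switch (not the repo's `Linear` class):

    import torch
    import torch.distributed as dist

    class RowParallelLinearSketch(torch.nn.Module):
        # When reduce_output=False the layer returns the local partial sum and
        # lets the caller decide when (or whether) to all-reduce.
        def __init__(self, in_features: int, out_features: int,
                     reduce_output: bool = True):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
            self.reduce_output = reduce_output

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = torch.nn.functional.linear(x, self.weight)  # partial sum on this rank
            if self.reduce_output and dist.is_initialized():
                dist.all_reduce(y)
            return y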