Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: Fix broken vanilla moe since FusedMoE refactor. (#4897)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>

parent 1fca654bfd
commit 6b3242654e
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
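This first hunk only registers the new backend choice in the example script's CLI. As a quick, self-contained illustration (a stand-in parser, not the example script itself), the new value is accepted like any other argparse choice:

# Standalone sketch (stand-in parser, not the TensorRT-LLM example script):
# 'VANILLA' is simply a new member of the choices list.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM', 'VANILLA'])

args = parser.parse_args(['--moe_backend', 'VANILLA'])
assert args.moe_backend == 'VANILLA'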
@@ -486,12 +486,14 @@ class Deepseekv3MoE(nn.Module):

         router_logits = self.gate(hidden_states)

-        routed_output = self.experts(hidden_states_fp4 or hidden_states,
-                                     router_logits,
-                                     cutlass_min_latency_mode,
-                                     output_dtype=hidden_states.dtype,
-                                     all_rank_num_tokens=all_rank_num_tokens,
-                                     use_dp_padding=use_dp_padding)
+        routed_output = self.experts(
+            hidden_states_fp4 or hidden_states,
+            router_logits,
+            cutlass_min_latency_mode=cutlass_min_latency_mode,
+            output_dtype=hidden_states.dtype,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+        )

         return routed_output

@@ -286,11 +286,13 @@ class Llama4MoE(nn.Module):
                 (0, 0, 0,
                  max_num_token_across_dp_ranks - hidden_states.shape[0]))
         router_logits = self.router(hidden_states)
-        routed_output = self.experts(hidden_states,
-                                     router_logits,
-                                     cutlass_min_latency_mode,
-                                     all_rank_num_tokens=all_rank_num_tokens,
-                                     use_dp_padding=use_dp_padding)
+        routed_output = self.experts(
+            hidden_states,
+            router_logits,
+            cutlass_min_latency_mode=cutlass_min_latency_mode,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+        )
         return routed_output

     def forward(
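Both call sites make the same change: the backend-specific flag is now passed by keyword instead of by position. A minimal standalone sketch of why that matters (hypothetical classes, not the real FusedMoE backends): once backends stop agreeing on the order of their optional forward() parameters, a positional flag can silently land in the wrong slot.

# Hypothetical backends (not the TensorRT-LLM classes) whose optional
# forward() parameters are ordered differently.
class CutlassLikeMoE:
    def forward(self, hidden_states, router_logits,
                cutlass_min_latency_mode=False, output_dtype=None):
        return hidden_states


class VanillaLikeMoE:
    def forward(self, hidden_states, router_logits,
                output_dtype=None, cutlass_min_latency_mode=False):
        return hidden_states


x, logits = [1.0], [0.5]
# Positional: on the vanilla-like backend the flag silently becomes output_dtype.
VanillaLikeMoE().forward(x, logits, False)
# Keyword: unambiguous on every backend, which is what the diff switches to.
VanillaLikeMoE().forward(x, logits, cutlass_min_latency_mode=False)
CutlassLikeMoE().forward(x, logits, cutlass_min_latency_mode=False)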
@@ -55,7 +55,6 @@ def create_moe(
     enable_alltoall: bool = False,
     moe_load_balancer: Optional[MoeLoadBalancer] = None,
     layer_idx: Optional[int] = None,
-    pack_weights: bool = False,
 ) -> MoE:
     moe_cls = get_moe_cls(model_config, override_quant_config)

@@ -63,7 +62,6 @@ def create_moe(
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
         assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."
         assert moe_load_balancer is None, "moe_load_balancer is not supported in TRTLLMGenFusedMoE."
-        assert not pack_weights, "pack_weights is not supported in TRTLLMGenFusedMoE."

         return moe_cls(
             routing_method=routing_method,
@@ -77,8 +75,6 @@ def create_moe(
             layer_idx=layer_idx,
         )
     elif moe_cls == CutlassFusedMoE:
-        assert not pack_weights, "pack_weights is not supported in CutlassFusedMoE."
-
         return moe_cls(
             routing_method=routing_method,
             num_experts=num_experts,
@@ -107,13 +103,8 @@ def create_moe(
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
             aux_stream=aux_stream,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
             enable_alltoall=enable_alltoall,
-            moe_load_balancer=moe_load_balancer,
-            layer_idx=layer_idx,
-            pack_weights=pack_weights,
         )
     else:
         raise ValueError(f"Unsupported moe backend: {moe_cls}")

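The create_moe changes all point the same way: the factory stops forwarding options that a given backend does not accept, instead of passing them through and failing inside the constructor. A minimal standalone sketch of that dispatch-with-guards pattern (hypothetical classes, not the TensorRT-LLM factory):

# Hypothetical backends; only the guard-and-dispatch shape mirrors create_moe.
from typing import Optional, Type


class CutlassLike:
    def __init__(self, num_experts: int, load_balancer: Optional[object] = None):
        self.num_experts = num_experts
        self.load_balancer = load_balancer


class VanillaLike:
    # Accepts only the options it actually implements.
    def __init__(self, num_experts: int):
        self.num_experts = num_experts


def create_backend(cls: Type, num_experts: int,
                   load_balancer: Optional[object] = None):
    if cls is CutlassLike:
        return cls(num_experts=num_experts, load_balancer=load_balancer)
    elif cls is VanillaLike:
        # Guard unsupported options here rather than forwarding them and
        # hitting an unexpected-keyword TypeError inside the constructor.
        assert load_balancer is None, "load_balancer is not supported here."
        return cls(num_experts=num_experts)
    raise ValueError(f"Unsupported backend: {cls}")


moe = create_backend(VanillaLike, num_experts=8)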
@@ -1,5 +1,5 @@
-import copy
 import math
+from dataclasses import replace
 from typing import Dict, List, Optional

 import torch
@@ -27,11 +27,9 @@ class VanillaMoE(nn.ModuleList):
         dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         model_config: ModelConfig = ModelConfig(),
         aux_stream: Optional[torch.cuda.Stream] = None,
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
         enable_alltoall: bool = False,
         pack_weights: bool = False,
     ):
         from ...distributed import AllReduce
@@ -86,7 +84,9 @@ class VanillaMoE(nn.ModuleList):
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
         if self.use_dp:
             max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
+        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
+                                   if model_config.moe_max_num_tokens
+                                   is not None else max_num_tokens)

         self.enable_alltoall = False

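As a worked example of the sizing logic above, with assumed numbers rather than TensorRT-LLM defaults:

# Assumed numbers, not TensorRT-LLM defaults: attention DP scales the token
# budget by world size unless moe_max_num_tokens overrides it.
max_num_tokens = 8192
world_size = 4
use_dp = True
moe_max_num_tokens_override = None  # e.g. 16384 to cap MoE batching explicitly

if use_dp:
    max_num_tokens *= world_size  # 8192 * 4 = 32768

moe_max_num_tokens = (moe_max_num_tokens_override
                      if moe_max_num_tokens_override is not None
                      else max_num_tokens)
assert moe_max_num_tokens == 32768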
@@ -100,14 +100,16 @@ class VanillaMoE(nn.ModuleList):
     def create_experts(self, module_list: nn.ModuleList = None):
         if module_list is None:
             module_list = self
-        model_config = copy.copy(self.model_config)
-        model_config.mapping = Mapping(
-            world_size=self.mapping.moe_tp_size,
-            tp_size=self.mapping.moe_tp_size,
-            rank=self.mapping.moe_tp_rank,
+        model_config = replace(
+            self.model_config,
+            mapping=Mapping(
+                world_size=self.mapping.moe_tp_size,
+                tp_size=self.mapping.moe_tp_size,
+                rank=self.mapping.moe_tp_rank,
+            ),
+            quant_config=self.quant_config,
+            skip_create_weights_in_init=False,
         )
-        model_config.quant_config = self.quant_config
-        model_config.skip_create_weights_in_init = False
         for expert_idx in range(self.num_experts):
             if self.expert_start <= expert_idx < self.expert_end:
                 module_list[expert_idx] = GatedMLP(
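The create_experts change swaps a shallow copy plus attribute mutation for a single dataclasses.replace call. A standalone sketch of the difference (a simplified Config stand-in, not the real ModelConfig):

# Simplified stand-in config; dataclasses.replace builds a new instance with
# selected fields overridden in one expression, instead of shallow-copying
# and mutating attributes afterwards.
import copy
from dataclasses import dataclass, replace


@dataclass
class Config:
    world_size: int = 1
    tp_size: int = 1
    skip_create_weights_in_init: bool = True


base = Config(world_size=8, tp_size=8)

# Before: shallow copy followed by attribute mutation.
old_style = copy.copy(base)
old_style.skip_create_weights_in_init = False

# After: a single replace() call; the original instance is left untouched.
new_style = replace(base, skip_create_weights_in_init=False)

assert base.skip_create_weights_in_init is True
assert new_style.skip_create_weights_in_init is False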
@@ -102,7 +102,6 @@ def test_fused_moe(moe_cls, dtype, experts, RoutingMethodCls, mapping=None):
         torch.cuda.synchronize()
         torch.testing.assert_close(output, ref_output, rtol=0.2, atol=0.2)
         m //= 2
-    return True


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -121,7 +120,7 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
                          moe_tp_size=world_size // ep_size))] * world_size),
     )
     for r in results:
-        assert r is True
+        assert r is None


 @skip_pre_hopper
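The test now relies purely on assertions: the per-rank body no longer returns True (recent pytest versions warn when a test function returns a non-None value), so whatever gathers the per-rank results sees None and the caller's check changes accordingly. A standalone sketch of that convention (a plain process pool as a stand-in for the MPI pool used in the real test):

# Stand-in for the MPI pool executor used by the real test; the point is only
# that a body which asserts and returns nothing yields None per worker.
from concurrent.futures import ProcessPoolExecutor


def per_rank_check(rank: int) -> None:
    # Verification happens via assertions inside the body; nothing is returned.
    assert rank >= 0


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:
        results = list(pool.map(per_rank_check, range(2)))
    for r in results:
        assert r is None  # mirrors the updated assertion in the test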