fix: Fix broken vanilla moe since FusedMoE refactor. (#4897)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2025-06-05 03:56:41 +08:00 committed by GitHub
parent 1fca654bfd
commit 6b3242654e
6 changed files with 30 additions and 34 deletions


@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
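As a quick illustration of the hunk above, here is a minimal, self-contained sketch (a plain argparse parser standing in for the real add_llm_args helper, illustrative only) showing that 'VANILLA' is now an accepted --moe_backend value:

import argparse

# Stand-in parser mirroring the updated --moe_backend argument.
parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM', 'VANILLA'])

# Rejected with SystemExit before this change; accepted now.
args = parser.parse_args(['--moe_backend', 'VANILLA'])
assert args.moe_backend == 'VANILLA'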


@@ -486,12 +486,14 @@ class Deepseekv3MoE(nn.Module):
         router_logits = self.gate(hidden_states)
-        routed_output = self.experts(hidden_states_fp4 or hidden_states,
-                                     router_logits,
-                                     cutlass_min_latency_mode,
-                                     output_dtype=hidden_states.dtype,
-                                     all_rank_num_tokens=all_rank_num_tokens,
-                                     use_dp_padding=use_dp_padding)
+        routed_output = self.experts(
+            hidden_states_fp4 or hidden_states,
+            router_logits,
+            cutlass_min_latency_mode=cutlass_min_latency_mode,
+            output_dtype=hidden_states.dtype,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+        )
         return routed_output


@@ -286,11 +286,13 @@ class Llama4MoE(nn.Module):
                 (0, 0, 0,
                  max_num_token_across_dp_ranks - hidden_states.shape[0]))
         router_logits = self.router(hidden_states)
-        routed_output = self.experts(hidden_states,
-                                     router_logits,
-                                     cutlass_min_latency_mode,
-                                     all_rank_num_tokens=all_rank_num_tokens,
-                                     use_dp_padding=use_dp_padding)
+        routed_output = self.experts(
+            hidden_states,
+            router_logits,
+            cutlass_min_latency_mode=cutlass_min_latency_mode,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+        )
         return routed_output

     def forward(


@@ -55,7 +55,6 @@ def create_moe(
     enable_alltoall: bool = False,
     moe_load_balancer: Optional[MoeLoadBalancer] = None,
     layer_idx: Optional[int] = None,
-    pack_weights: bool = False,
 ) -> MoE:
     moe_cls = get_moe_cls(model_config, override_quant_config)
@@ -63,7 +62,6 @@
         assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE."
         assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE."
         assert moe_load_balancer is None, "moe_load_balancer is not supported in TRTLLMGenFusedMoE."
-        assert not pack_weights, "pack_weights is not supported in TRTLLMGenFusedMoE."
         return moe_cls(
             routing_method=routing_method,
@@ -77,8 +75,6 @@
             layer_idx=layer_idx,
         )
     elif moe_cls == CutlassFusedMoE:
-        assert not pack_weights, "pack_weights is not supported in CutlassFusedMoE."
         return moe_cls(
             routing_method=routing_method,
             num_experts=num_experts,
@@ -107,13 +103,8 @@
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
-            aux_stream=aux_stream,
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
             enable_alltoall=enable_alltoall,
-            moe_load_balancer=moe_load_balancer,
-            layer_idx=layer_idx,
-            pack_weights=pack_weights,
         )
     else:
         raise ValueError(f"Unsupported moe backend: {moe_cls}")


@@ -1,5 +1,5 @@
-import copy
 import math
+from dataclasses import replace
 from typing import Dict, List, Optional

 import torch
@@ -27,11 +27,9 @@ class VanillaMoE(nn.ModuleList):
         dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         model_config: ModelConfig = ModelConfig(),
-        aux_stream: Optional[torch.cuda.Stream] = None,
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
         enable_alltoall: bool = False,
-        pack_weights: bool = False,
     ):
         from ...distributed import AllReduce
@@ -86,7 +84,9 @@ class VanillaMoE(nn.ModuleList):
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
         if self.use_dp:
             max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
+        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
+                                   if model_config.moe_max_num_tokens
+                                   is not None else max_num_tokens)

         self.enable_alltoall = False
@@ -100,14 +100,16 @@ class VanillaMoE(nn.ModuleList):
     def create_experts(self, module_list: nn.ModuleList = None):
         if module_list is None:
             module_list = self
-        model_config = copy.copy(self.model_config)
-        model_config.mapping = Mapping(
-            world_size=self.mapping.moe_tp_size,
-            tp_size=self.mapping.moe_tp_size,
-            rank=self.mapping.moe_tp_rank,
-        )
-        model_config.quant_config = self.quant_config
-        model_config.skip_create_weights_in_init = False
+        model_config = replace(
+            self.model_config,
+            mapping=Mapping(
+                world_size=self.mapping.moe_tp_size,
+                tp_size=self.mapping.moe_tp_size,
+                rank=self.mapping.moe_tp_rank,
+            ),
+            quant_config=self.quant_config,
+            skip_create_weights_in_init=False,
+        )
         for expert_idx in range(self.num_experts):
             if self.expert_start <= expert_idx < self.expert_end:
                 module_list[expert_idx] = GatedMLP(
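A side note on the create_experts change above: dataclasses.replace builds a fresh config object with selected fields overridden in a single expression, instead of shallow-copying and then mutating attributes (it requires the config type to be a dataclass). A minimal sketch on a toy stand-in class, not the real ModelConfig:

from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class ToyConfig:
    # Toy stand-in for the per-expert config fields touched above.
    mapping: str = 'tp1'
    quant_config: Optional[str] = None
    skip_create_weights_in_init: bool = True


base = ToyConfig()
# replace() returns a new instance with the given fields overridden;
# the original object is left untouched.
per_expert = replace(base,
                     mapping='tp4',
                     quant_config='fp8',
                     skip_create_weights_in_init=False)

assert base.skip_create_weights_in_init is True
assert per_expert.skip_create_weights_in_init is False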


@@ -102,7 +102,6 @@ def test_fused_moe(moe_cls, dtype, experts, RoutingMethodCls, mapping=None):
         torch.cuda.synchronize()
         torch.testing.assert_close(output, ref_output, rtol=0.2, atol=0.2)
         m //= 2
-    return True


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -121,7 +120,7 @@ def test_fused_moe_multi_gpu(moe_cls, ep_size):
                 moe_tp_size=world_size // ep_size))] * world_size),
     )
     for r in results:
-        assert r is True
+        assert r is None


 @skip_pre_hopper