[Bugfix] Update TrtLLM MoE routing methods (#44347)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Wei Zhao
2026-06-03 05:56:43 -04:00
committed by GitHub
parent 0e2b13103b
commit ace95c9cf8
6 changed files with 24 additions and 24 deletions
+17 -12
View File
@@ -102,23 +102,26 @@ def _quant_flags_to_group_shape(
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = (0,)
# Renormalize: TopK -> Softmax/Sigmoid
# Renormalize: TopK -> Softmax
Renormalize = (1,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = (2,)
# Llama4: Top1 -> Sigmoid
Llama4 = (3,)
# RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
# RenormalizeNaive: Softmax -> TopK -> Renormalize
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)
# SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K)
SigmoidRenorm = (6,)
# MiniMax2: Sigmoid + Bias -> TopK -> ScaledSumNormalize
# (routeScale=1.0, epsilon=1e-20)
MiniMax2 = (7,)
# Sigmoid: Sigmoid -> TopK (no renormalization)
Sigmoid = (8,)
# Unspecified
Unspecified = (8,)
Unspecified = (9,)
# other routing types (not passed to FlashInfer kernels)
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
DeepseekV4 = (100,)
@@ -132,6 +135,7 @@ def get_routing_method_type(
renormalize: bool,
num_expert_group: int | None,
has_e_score_bias: bool,
routed_scaling_factor: float | None = 1.0,
) -> RoutingMethodType:
if scoring_func == "sqrtsoftplus":
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
@@ -142,20 +146,21 @@ def get_routing_method_type(
return RoutingMethodType.Unspecified
if has_e_score_bias:
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
return RoutingMethodType.DeepSeekV3
elif scoring_func == "sigmoid":
return RoutingMethodType.MiniMax2
if scoring_func == "sigmoid":
if not renormalize:
return RoutingMethodType.Unspecified
if (num_expert_group or 0) > 0:
return RoutingMethodType.DeepSeekV3
if routed_scaling_factor in (None, 1.0):
return RoutingMethodType.MiniMax2
return RoutingMethodType.Unspecified
else:
return RoutingMethodType.Unspecified
if scoring_func == "sigmoid":
if top_k == 1:
return RoutingMethodType.Llama4
elif renormalize:
if renormalize:
return RoutingMethodType.SigmoidRenorm
else:
return RoutingMethodType.Unspecified
return RoutingMethodType.Sigmoid
if scoring_func == "softmax":
if renormalize:
@@ -257,7 +257,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
return router_logits_dtype != torch.float32
return router_logits_dtype in [torch.bfloat16, torch.float32]
@staticmethod
def _supports_routing_method(
@@ -279,6 +279,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.SigmoidRenorm,
RoutingMethodType.Sigmoid,
RoutingMethodType.MiniMax2,
RoutingMethodType.Simulated,
]
@@ -290,6 +291,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.SigmoidRenorm,
RoutingMethodType.Sigmoid,
RoutingMethodType.MiniMax2,
RoutingMethodType.Simulated,
]
@@ -402,11 +404,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
else:
assert not apply_router_weight_on_input
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
@@ -360,7 +360,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
return router_logits_dtype != torch.float32
return router_logits_dtype in [torch.bfloat16, torch.float32]
def apply(
self,
@@ -393,11 +393,6 @@ class TrtLlmNvFp4ExpertsMonolithic(
and self.routing_method_type != RoutingMethodType.Llama4
)
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
output1_scale_gate_scalar = self.quant_config.g1_alphas
# Invoke kernel.
@@ -326,6 +326,7 @@ class FusedTopKBiasRouter(BaseRouter):
renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=True,
routed_scaling_factor=self.routed_scaling_factor,
)
def _compute_routing(
@@ -283,6 +283,7 @@ class GroupedTopKRouter(BaseRouter):
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
has_e_score_bias=self.e_score_correction_bias is not None,
routed_scaling_factor=self.routed_scaling_factor,
)
def _compute_routing(
@@ -63,6 +63,7 @@ class ZeroExpertRouter(BaseRouter):
renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=True,
routed_scaling_factor=self.routed_scaling_factor,
)
def _compute_routing(