mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] Update TrtLLM MoE routing methods (#44347)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -102,23 +102,26 @@ def _quant_flags_to_group_shape(
|
||||
class RoutingMethodType(IntEnum):
|
||||
# Default: Softmax -> TopK
|
||||
Default = (0,)
|
||||
# Renormalize: TopK -> Softmax/Sigmoid
|
||||
# Renormalize: TopK -> Softmax
|
||||
Renormalize = (1,)
|
||||
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
|
||||
# -> Top8 experts from the Top4 groups
|
||||
DeepSeekV3 = (2,)
|
||||
# Llama4: Top1 -> Sigmoid
|
||||
Llama4 = (3,)
|
||||
# RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
|
||||
# RenormalizeNaive: Softmax -> TopK -> Renormalize
|
||||
RenormalizeNaive = (4,)
|
||||
# TopK: TopK (no softmax)
|
||||
TopK = (5,)
|
||||
# SigmoidRenorm: Sigmoid -> TopK -> Renormalize (divide by sum of top-K)
|
||||
SigmoidRenorm = (6,)
|
||||
# MiniMax2: Sigmoid + Bias -> TopK -> ScaledSumNormalize
|
||||
# (routeScale=1.0, epsilon=1e-20)
|
||||
MiniMax2 = (7,)
|
||||
# Sigmoid: Sigmoid -> TopK (no renormalization)
|
||||
Sigmoid = (8,)
|
||||
# Unspecified
|
||||
Unspecified = (8,)
|
||||
Unspecified = (9,)
|
||||
# other routing types (not passed to FlashInfer kernels)
|
||||
# Deepseek V4 -> sqrtsoftplus + Bias + Normalize
|
||||
DeepseekV4 = (100,)
|
||||
@@ -132,6 +135,7 @@ def get_routing_method_type(
|
||||
renormalize: bool,
|
||||
num_expert_group: int | None,
|
||||
has_e_score_bias: bool,
|
||||
routed_scaling_factor: float | None = 1.0,
|
||||
) -> RoutingMethodType:
|
||||
if scoring_func == "sqrtsoftplus":
|
||||
# DeepSeek V4 uses sqrtsoftplus routing with optional routing bias
|
||||
@@ -142,20 +146,21 @@ def get_routing_method_type(
|
||||
return RoutingMethodType.Unspecified
|
||||
|
||||
if has_e_score_bias:
|
||||
if (num_expert_group or 0) > 0 and scoring_func == "sigmoid":
|
||||
return RoutingMethodType.DeepSeekV3
|
||||
elif scoring_func == "sigmoid":
|
||||
return RoutingMethodType.MiniMax2
|
||||
if scoring_func == "sigmoid":
|
||||
if not renormalize:
|
||||
return RoutingMethodType.Unspecified
|
||||
if (num_expert_group or 0) > 0:
|
||||
return RoutingMethodType.DeepSeekV3
|
||||
if routed_scaling_factor in (None, 1.0):
|
||||
return RoutingMethodType.MiniMax2
|
||||
return RoutingMethodType.Unspecified
|
||||
else:
|
||||
return RoutingMethodType.Unspecified
|
||||
|
||||
if scoring_func == "sigmoid":
|
||||
if top_k == 1:
|
||||
return RoutingMethodType.Llama4
|
||||
elif renormalize:
|
||||
if renormalize:
|
||||
return RoutingMethodType.SigmoidRenorm
|
||||
else:
|
||||
return RoutingMethodType.Unspecified
|
||||
return RoutingMethodType.Sigmoid
|
||||
|
||||
if scoring_func == "softmax":
|
||||
if renormalize:
|
||||
|
||||
@@ -257,7 +257,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
return router_logits_dtype != torch.float32
|
||||
return router_logits_dtype in [torch.bfloat16, torch.float32]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
@@ -279,6 +279,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.SigmoidRenorm,
|
||||
RoutingMethodType.Sigmoid,
|
||||
RoutingMethodType.MiniMax2,
|
||||
RoutingMethodType.Simulated,
|
||||
]
|
||||
@@ -290,6 +291,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.SigmoidRenorm,
|
||||
RoutingMethodType.Sigmoid,
|
||||
RoutingMethodType.MiniMax2,
|
||||
RoutingMethodType.Simulated,
|
||||
]
|
||||
@@ -402,11 +404,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
else:
|
||||
assert not apply_router_weight_on_input
|
||||
|
||||
# Currently FI requires bfloat16 routing bias.
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2909
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
|
||||
|
||||
out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
|
||||
@@ -360,7 +360,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
return router_logits_dtype != torch.float32
|
||||
return router_logits_dtype in [torch.bfloat16, torch.float32]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -393,11 +393,6 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
||||
and self.routing_method_type != RoutingMethodType.Llama4
|
||||
)
|
||||
|
||||
# Currently FI requires bfloat16 routing bias.
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2909
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
|
||||
|
||||
output1_scale_gate_scalar = self.quant_config.g1_alphas
|
||||
|
||||
# Invoke kernel.
|
||||
|
||||
@@ -326,6 +326,7 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=None,
|
||||
has_e_score_bias=True,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
@@ -283,6 +283,7 @@ class GroupedTopKRouter(BaseRouter):
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=self.num_expert_group,
|
||||
has_e_score_bias=self.e_score_correction_bias is not None,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
@@ -63,6 +63,7 @@ class ZeroExpertRouter(BaseRouter):
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=None,
|
||||
has_e_score_bias=True,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
)
|
||||
|
||||
def _compute_routing(
|
||||
|
||||
Reference in New Issue
Block a user