[Cohere] fix RoutingMethodType (#44021)

Signed-off-by: Terrencezzj <terrence@cohere.ai>
This commit is contained in:
Terrence Zhao
2026-06-05 19:25:53 -04:00
committed by GitHub
parent f6a708ab2b
commit a50e675b0d
3 changed files with 7 additions and 3 deletions
@@ -80,6 +80,8 @@ class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.SigmoidRenorm,
RoutingMethodType.Sigmoid,
]
@staticmethod
@@ -350,9 +350,9 @@ class TrtLlmNvFp4ExpertsMonolithic(
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
RoutingMethodType.SigmoidRenorm,
RoutingMethodType.Sigmoid,
RoutingMethodType.MiniMax2,
RoutingMethodType.Simulated,
RoutingMethodType.SigmoidRenorm,
]
@staticmethod
@@ -38,9 +38,11 @@ class CustomRoutingRouter(BaseRouter):
# NOTE: FLASHINFER_TRTLLM supports the Llama4 router.
if self.custom_routing_function == Llama4MoE.custom_routing_function:
return RoutingMethodType.Llama4
# Cohere MoE uses a sigmoid -> top-k -> renormalize routing function.
# Cohere MoE uses sigmoid -> top-k, optionally followed by renormalize.
if self.custom_routing_function == token_choice_with_bias:
if self.renormalize:
return RoutingMethodType.SigmoidRenorm
return RoutingMethodType.Sigmoid
return RoutingMethodType.Custom
def _compute_routing(