Fix dtype mismatch in topk_softplus_sqrt for DeepEP backends

DeepEP requires int64 topk indices, but DeepSeek-V4's hash MoE creates
input_ids and hash_indices_table as int32. The CUDA kernel dispatches on
topk_ids dtype and assumes all index tensors match, causing a crash:
"expected scalar type Long but found Int".

Cast input_ids and hash_indices_table to indices_type before calling the
kernel. The converted hash table is stored back on the router, so the
static table is cast only once.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Tyler Michael Smith
2026-05-11 15:24:14 -04:00
parent 6d7a3fab28
commit b4dbbc7102
@@ -278,6 +278,20 @@ class FusedTopKBiasRouter(BaseRouter):
         input_ids: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute routing using fused top-k with bias."""
+        # The topk kernel dispatches dtype based on topk_ids (set by
+        # indices_type) and assumes input_tokens/hash_indices_table match.
+        # Cast them here so backends like DeepEP that require int64 indices
+        # don't hit a dtype mismatch against the model's int32 buffers.
+        hash_table = self._hash_indices_table
+        if indices_type is not None:
+            if input_ids is not None:
+                input_ids = input_ids.to(dtype=indices_type)
+            if (hash_table is not None
+                    and hash_table.dtype != indices_type):
+                self._hash_indices_table = hash_table.to(
+                    dtype=indices_type)
+                hash_table = self._hash_indices_table
         topk_weights, topk_ids = fused_topk_bias(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -289,7 +303,7 @@ class FusedTopKBiasRouter(BaseRouter):
             renormalize=self.renormalize,
             indices_type=indices_type,
             input_tokens=input_ids,
-            hash_indices_table=self._hash_indices_table,
+            hash_indices_table=hash_table,
             routed_scaling_factor=self.routed_scaling_factor,
         )
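
The cast-and-cache behaviour described in the commit message can be tried in isolation with the minimal sketch below. `HashRouterSketch` and `align_index_dtypes` are hypothetical names introduced only for illustration; they are not part of the router in this diff, and the sketch leaves out the fused kernel call entirely. It only mirrors the dtype handling: per-call token ids are cast on every call, while the static hash table is converted once and stored back on the object.

```python
from __future__ import annotations

import torch


class HashRouterSketch:
    """Hypothetical stand-in for a router that owns a static hash table."""

    def __init__(self, hash_indices_table: torch.Tensor | None) -> None:
        self._hash_indices_table = hash_indices_table

    def align_index_dtypes(
        self,
        input_ids: torch.Tensor | None,
        indices_type: torch.dtype | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Return (input_ids, hash_table), cast to indices_type when it is set."""
        hash_table = self._hash_indices_table
        if indices_type is None:
            return input_ids, hash_table
        if input_ids is not None:
            # Tensor.to returns the same tensor when the dtype already matches.
            input_ids = input_ids.to(dtype=indices_type)
        if hash_table is not None and hash_table.dtype != indices_type:
            # Store the converted table so later calls skip the conversion.
            self._hash_indices_table = hash_table.to(dtype=indices_type)
            hash_table = self._hash_indices_table
        return input_ids, hash_table


# The model-side buffers start out as int32, as in the commit message.
router = HashRouterSketch(torch.arange(8, dtype=torch.int32).reshape(4, 2))
tokens = torch.tensor([0, 3, 1], dtype=torch.int32)

# A backend that needs int64 indices requests them through indices_type.
ids, table = router.align_index_dtypes(tokens, torch.int64)
_, table_again = router.align_index_dtypes(tokens, torch.int64)
```

In this sketch the second call returns the very same table object as the first, because the stored table already has the requested dtype. One consequence of overwriting the stored table is that a later call with a different `indices_type` would convert it again; whether that can happen in practice depends on how the backend is configured, which this sketch does not model.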