fix: clamp NaN/Inf in topk_softmax to prevent duplicate expert IDs (#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
This commit is contained in:
Jhao-Ting Chen
2026-04-21 04:04:41 -07:00
committed by GitHub
parent 3975eb6de6
commit 28c222157b
2 changed files with 86 additions and 2 deletions
+19 -2
View File
@@ -126,7 +126,9 @@ __launch_bounds__(TPB) __global__
{
const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]);
const float softmax_val = expf(val - float_max) * normalizing_factor;
float softmax_val = expf(val - float_max) * normalizing_factor;
// Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream.
if (isnan(softmax_val) || isinf(softmax_val)) softmax_val = 0.f;
output[idx] = softmax_val;
}
}
@@ -147,7 +149,9 @@ __launch_bounds__(TPB) __global__
{
const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]);
const float sigmoid_val = 1.0f / (1.0f + __expf(-val));
float sigmoid_val = 1.0f / (1.0f + __expf(-val));
// Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream.
if (isnan(sigmoid_val) || isinf(sigmoid_val)) sigmoid_val = 0.f;
output[idx] = sigmoid_val;
}
}
@@ -442,6 +446,19 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
// Fix: clamp NaN/Inf values to 0 to prevent duplicate expert IDs.
// NaN gating (from degenerate hidden states in CUDA graph padding) causes
// softmax to produce all-NaN, which makes the argmax loop always pick
// expert 0 for every top-k slot, producing duplicate expert IDs that
// crash FlashInfer's three-step MoE sort.
// With 0s, the argmax uses index tie-breaking to pick [0,1,2,...,k-1].
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
if (isnan(row_chunk[ii]) || isinf(row_chunk[ii])) {
row_chunk[ii] = 0.f;
}
}
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
// If bias is not null, use biased value for selection
+67
View File
@@ -135,3 +135,70 @@ def test_fused_topk_bias(
topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(topk_ids_ref.to(torch.int32), topk_ids, atol=0, rtol=0)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize("num_experts", [6, 8, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_nan_inf_clamp(
num_experts: int,
topk: int,
scoring_func: str,
bad_value: float,
dtype: torch.dtype,
):
"""Regression test for the NaN/Inf clamp in topk_softmax_kernels.cu.
Degenerate hidden states (e.g., from CUDA graph padding) can produce
NaN/Inf gating logits. Without the clamp, softmax/sigmoid outputs are
NaN and the argmax loop picks expert 0 for every top-k slot (since
"NaN > NaN" is false per IEEE 754), yielding duplicate expert IDs that
crash downstream MoE sort kernels. The fix clamps NaN/Inf to 0 before
argmax so index tie-breaking selects unique experts [0, 1, ..., k-1].
"""
torch.manual_seed(0)
num_tokens = 4
hidden_size = 1024
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
# Row 0: all normal. Rows 1-3: fully poisoned with NaN or Inf.
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
gating_output[1:, :] = bad_value
topk_weights, topk_ids, _ = fused_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=False,
scoring_func=scoring_func,
)
# Normal row must still match the torch reference.
ref_weights, ref_ids = torch_topk(
gating_output=gating_output[:1],
topk=topk,
renormalize=False,
scoring_func=scoring_func,
)
torch.testing.assert_close(
ref_weights.to(torch.float32), topk_weights[:1], atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(ref_ids.to(torch.int32), topk_ids[:1], atol=0, rtol=0)
# Poisoned rows: IDs must be unique (no duplicates) and weights must be
# finite (no NaN/Inf propagation into downstream MoE kernels).
for row in range(1, num_tokens):
row_ids = topk_ids[row]
assert row_ids.unique().numel() == topk, (
f"Row {row} has duplicate expert IDs {row_ids.tolist()} "
f"(bad_value={bad_value}, scoring_func={scoring_func})"
)
assert torch.isfinite(topk_weights[row]).all(), (
f"Row {row} has non-finite weights {topk_weights[row].tolist()} "
f"(bad_value={bad_value}, scoring_func={scoring_func})"
)