diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 833036da528..57461a044f9 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -126,7 +126,9 @@ __launch_bounds__(TPB) __global__ { const int idx = thread_row_offset + ii; const float val = toFloat(input[idx]); - const float softmax_val = expf(val - float_max) * normalizing_factor; + float softmax_val = expf(val - float_max) * normalizing_factor; + // Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream. + if (isnan(softmax_val) || isinf(softmax_val)) softmax_val = 0.f; output[idx] = softmax_val; } } @@ -147,7 +149,9 @@ __launch_bounds__(TPB) __global__ { const int idx = thread_row_offset + ii; const float val = toFloat(input[idx]); - const float sigmoid_val = 1.0f / (1.0f + __expf(-val)); + float sigmoid_val = 1.0f / (1.0f + __expf(-val)); + // Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream. + if (isnan(sigmoid_val) || isinf(sigmoid_val)) sigmoid_val = 0.f; output[idx] = sigmoid_val; } } @@ -442,6 +446,19 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } + // Fix: clamp NaN/Inf values to 0 to prevent duplicate expert IDs. + // NaN gating (from degenerate hidden states in CUDA graph padding) causes + // softmax to produce all-NaN, which makes the argmax loop always pick + // expert 0 for every top-k slot, producing duplicate expert IDs that + // crash FlashInfer's three-step MoE sort. + // With 0s, the argmax uses index tie-breaking to pick [0,1,2,...,k-1]. +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + if (isnan(row_chunk[ii]) || isinf(row_chunk[ii])) { + row_chunk[ii] = 0.f; + } + } + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; // If bias is not null, use biased value for selection diff --git a/tests/kernels/moe/test_fused_topk.py b/tests/kernels/moe/test_fused_topk.py index 5384d8964b5..a0e3580ee5a 100644 --- a/tests/kernels/moe/test_fused_topk.py +++ b/tests/kernels/moe/test_fused_topk.py @@ -135,3 +135,70 @@ def test_fused_topk_bias( topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2 ) torch.testing.assert_close(topk_ids_ref.to(torch.int32), topk_ids, atol=0, rtol=0) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +@pytest.mark.parametrize("num_experts", [6, 8, 16]) +@pytest.mark.parametrize("topk", [3, 4]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_fused_topk_nan_inf_clamp( + num_experts: int, + topk: int, + scoring_func: str, + bad_value: float, + dtype: torch.dtype, +): + """Regression test for the NaN/Inf clamp in topk_softmax_kernels.cu. + + Degenerate hidden states (e.g., from CUDA graph padding) can produce + NaN/Inf gating logits. Without the clamp, softmax/sigmoid outputs are + NaN and the argmax loop picks expert 0 for every top-k slot (since + "NaN > NaN" is false per IEEE 754), yielding duplicate expert IDs that + crash downstream MoE sort kernels. The fix clamps NaN/Inf to 0 before + argmax so index tie-breaking selects unique experts [0, 1, ..., k-1]. + """ + torch.manual_seed(0) + num_tokens = 4 + hidden_size = 1024 + hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda") + + # Row 0: all normal. Rows 1-3: fully poisoned with NaN or Inf. + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + gating_output[1:, :] = bad_value + + topk_weights, topk_ids, _ = fused_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=False, + scoring_func=scoring_func, + ) + + # Normal row must still match the torch reference. + ref_weights, ref_ids = torch_topk( + gating_output=gating_output[:1], + topk=topk, + renormalize=False, + scoring_func=scoring_func, + ) + torch.testing.assert_close( + ref_weights.to(torch.float32), topk_weights[:1], atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(ref_ids.to(torch.int32), topk_ids[:1], atol=0, rtol=0) + + # Poisoned rows: IDs must be unique (no duplicates) and weights must be + # finite (no NaN/Inf propagation into downstream MoE kernels). + for row in range(1, num_tokens): + row_ids = topk_ids[row] + assert row_ids.unique().numel() == topk, ( + f"Row {row} has duplicate expert IDs {row_ids.tolist()} " + f"(bad_value={bad_value}, scoring_func={scoring_func})" + ) + assert torch.isfinite(topk_weights[row]).all(), ( + f"Row {row} has non-finite weights {topk_weights[row].tolist()} " + f"(bad_value={bad_value}, scoring_func={scoring_func})" + )