mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
fix: clamp NaN/Inf in topk_softmax to prevent duplicate expert IDs (#39391)
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
This commit is contained in:
@@ -126,7 +126,9 @@ __launch_bounds__(TPB) __global__
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = toFloat(input[idx]);
|
||||
const float softmax_val = expf(val - float_max) * normalizing_factor;
|
||||
float softmax_val = expf(val - float_max) * normalizing_factor;
|
||||
// Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream.
|
||||
if (isnan(softmax_val) || isinf(softmax_val)) softmax_val = 0.f;
|
||||
output[idx] = softmax_val;
|
||||
}
|
||||
}
|
||||
@@ -147,7 +149,9 @@ __launch_bounds__(TPB) __global__
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = toFloat(input[idx]);
|
||||
const float sigmoid_val = 1.0f / (1.0f + __expf(-val));
|
||||
float sigmoid_val = 1.0f / (1.0f + __expf(-val));
|
||||
// Clamp NaN/Inf to 0 to prevent duplicate expert IDs downstream.
|
||||
if (isnan(sigmoid_val) || isinf(sigmoid_val)) sigmoid_val = 0.f;
|
||||
output[idx] = sigmoid_val;
|
||||
}
|
||||
}
|
||||
@@ -442,6 +446,19 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
}
|
||||
}
|
||||
|
||||
// Fix: clamp NaN/Inf values to 0 to prevent duplicate expert IDs.
|
||||
// NaN gating (from degenerate hidden states in CUDA graph padding) causes
|
||||
// softmax to produce all-NaN, which makes the argmax loop always pick
|
||||
// expert 0 for every top-k slot, producing duplicate expert IDs that
|
||||
// crash FlashInfer's three-step MoE sort.
|
||||
// With 0s, the argmax uses index tie-breaking to pick [0,1,2,...,k-1].
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
if (isnan(row_chunk[ii]) || isinf(row_chunk[ii])) {
|
||||
row_chunk[ii] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
// If bias is not null, use biased value for selection
|
||||
|
||||
@@ -135,3 +135,70 @@ def test_fused_topk_bias(
|
||||
topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(topk_ids_ref.to(torch.int32), topk_ids, atol=0, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
@pytest.mark.parametrize("num_experts", [6, 8, 16])
|
||||
@pytest.mark.parametrize("topk", [3, 4])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
|
||||
def test_fused_topk_nan_inf_clamp(
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
scoring_func: str,
|
||||
bad_value: float,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Regression test for the NaN/Inf clamp in topk_softmax_kernels.cu.
|
||||
|
||||
Degenerate hidden states (e.g., from CUDA graph padding) can produce
|
||||
NaN/Inf gating logits. Without the clamp, softmax/sigmoid outputs are
|
||||
NaN and the argmax loop picks expert 0 for every top-k slot (since
|
||||
"NaN > NaN" is false per IEEE 754), yielding duplicate expert IDs that
|
||||
crash downstream MoE sort kernels. The fix clamps NaN/Inf to 0 before
|
||||
argmax so index tie-breaking selects unique experts [0, 1, ..., k-1].
|
||||
"""
|
||||
torch.manual_seed(0)
|
||||
num_tokens = 4
|
||||
hidden_size = 1024
|
||||
hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda")
|
||||
|
||||
# Row 0: all normal. Rows 1-3: fully poisoned with NaN or Inf.
|
||||
gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
|
||||
gating_output[1:, :] = bad_value
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=False,
|
||||
scoring_func=scoring_func,
|
||||
)
|
||||
|
||||
# Normal row must still match the torch reference.
|
||||
ref_weights, ref_ids = torch_topk(
|
||||
gating_output=gating_output[:1],
|
||||
topk=topk,
|
||||
renormalize=False,
|
||||
scoring_func=scoring_func,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
ref_weights.to(torch.float32), topk_weights[:1], atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(ref_ids.to(torch.int32), topk_ids[:1], atol=0, rtol=0)
|
||||
|
||||
# Poisoned rows: IDs must be unique (no duplicates) and weights must be
|
||||
# finite (no NaN/Inf propagation into downstream MoE kernels).
|
||||
for row in range(1, num_tokens):
|
||||
row_ids = topk_ids[row]
|
||||
assert row_ids.unique().numel() == topk, (
|
||||
f"Row {row} has duplicate expert IDs {row_ids.tolist()} "
|
||||
f"(bad_value={bad_value}, scoring_func={scoring_func})"
|
||||
)
|
||||
assert torch.isfinite(topk_weights[row]).all(), (
|
||||
f"Row {row} has non-finite weights {topk_weights[row].tolist()} "
|
||||
f"(bad_value={bad_value}, scoring_func={scoring_func})"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user