[TRTLLM-10126][feat] Increase topk upper limit to 22 for NVLinkOneSided AlltoAll (#10229)

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
Guoming Zhang 2025-12-27 22:48:10 +08:00 committed by GitHub
parent 27976fce9c
commit 93ac0bc1dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 71 additions and 2 deletions

View File

@ -48,6 +48,12 @@ namespace kernels::moe_comm
#define SWITCH_TOP_K(top_k, TOP_K, ...) \
switch (top_k) \
{ \
case 22: \
{ \
constexpr int TOP_K = 22; \
__VA_ARGS__; \
break; \
} \
case 16: \
{ \
constexpr int TOP_K = 16; \
@ -654,7 +660,69 @@ __device__ void vectorized_combine_impl(
acc[k].load(recv_buffer + base_token + offset);
}
// Reduce acc[0..TOP_K-1] into acc[0] (element-wise sum of the TOP_K loaded vectors)
if constexpr (TOP_K == 16)
if constexpr (TOP_K == 22)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
T* a2 = reinterpret_cast<T*>(&acc[2]);
T* a3 = reinterpret_cast<T*>(&acc[3]);
T* a4 = reinterpret_cast<T*>(&acc[4]);
T* a5 = reinterpret_cast<T*>(&acc[5]);
T* a6 = reinterpret_cast<T*>(&acc[6]);
T* a7 = reinterpret_cast<T*>(&acc[7]);
T* a8 = reinterpret_cast<T*>(&acc[8]);
T* a9 = reinterpret_cast<T*>(&acc[9]);
T* a10 = reinterpret_cast<T*>(&acc[10]);
T* a11 = reinterpret_cast<T*>(&acc[11]);
T* a12 = reinterpret_cast<T*>(&acc[12]);
T* a13 = reinterpret_cast<T*>(&acc[13]);
T* a14 = reinterpret_cast<T*>(&acc[14]);
T* a15 = reinterpret_cast<T*>(&acc[15]);
T* a16 = reinterpret_cast<T*>(&acc[16]);
T* a17 = reinterpret_cast<T*>(&acc[17]);
T* a18 = reinterpret_cast<T*>(&acc[18]);
T* a19 = reinterpret_cast<T*>(&acc[19]);
T* a20 = reinterpret_cast<T*>(&acc[20]);
T* a21 = reinterpret_cast<T*>(&acc[21]);
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a1[j];
a2[j] += a3[j];
a4[j] += a5[j];
a6[j] += a7[j];
a8[j] += a9[j];
a10[j] += a11[j];
a12[j] += a13[j];
a14[j] += a15[j];
a16[j] += a17[j];
a18[j] += a19[j];
a20[j] += a21[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a2[j];
a4[j] += a6[j];
a8[j] += a10[j];
a12[j] += a14[j];
a16[j] += a18[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a4[j];
a8[j] += a12[j];
a16[j] += a20[j];
}
#pragma unroll
for (int j = 0; j < elems_per_vec; ++j)
{
a0[j] += a8[j];
a0[j] += a16[j];
}
}
else if constexpr (TOP_K == 16)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);

View File

@ -26,7 +26,7 @@ namespace kernels::moe_comm
{
// Configuration constants
static constexpr int kMaxTopK = 16; // Maximum top-k experts per token
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
static constexpr int kMaxRanks = 64; // Maximum supported EP size

View File

@ -565,6 +565,7 @@ class TestMoEAlltoAll:
(2, [100, 50], 2),
(4, [32, 32, 32, 32], 4),
(4, [32, 32, 32, 32], 10), # (top_k=10 is used by Qwen3-next)
(4, [32, 32, 32, 32], 22),
(4, [1, 1, 1, 1], 2),
(8, [640, 640, 640, 640, 640, 640, 640, 640], 4),
(4, [32, 0, 16, 0], 2),