mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-10126][feat] Increase topk upper limit to 22 for NVLinkOneSid… (#10229)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
parent
27976fce9c
commit
93ac0bc1dc
@ -48,6 +48,12 @@ namespace kernels::moe_comm
|
||||
#define SWITCH_TOP_K(top_k, TOP_K, ...) \
|
||||
switch (top_k) \
|
||||
{ \
|
||||
case 22: \
|
||||
{ \
|
||||
constexpr int TOP_K = 22; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 16: \
|
||||
{ \
|
||||
constexpr int TOP_K = 16; \
|
||||
@ -654,7 +660,69 @@ __device__ void vectorized_combine_impl(
|
||||
acc[k].load(recv_buffer + base_token + offset);
|
||||
}
|
||||
// Reduce acc[TOP_K] into acc[0]
|
||||
if constexpr (TOP_K == 16)
|
||||
if constexpr (TOP_K == 22)
|
||||
{
|
||||
T* a0 = reinterpret_cast<T*>(&acc[0]);
|
||||
T* a1 = reinterpret_cast<T*>(&acc[1]);
|
||||
T* a2 = reinterpret_cast<T*>(&acc[2]);
|
||||
T* a3 = reinterpret_cast<T*>(&acc[3]);
|
||||
T* a4 = reinterpret_cast<T*>(&acc[4]);
|
||||
T* a5 = reinterpret_cast<T*>(&acc[5]);
|
||||
T* a6 = reinterpret_cast<T*>(&acc[6]);
|
||||
T* a7 = reinterpret_cast<T*>(&acc[7]);
|
||||
T* a8 = reinterpret_cast<T*>(&acc[8]);
|
||||
T* a9 = reinterpret_cast<T*>(&acc[9]);
|
||||
T* a10 = reinterpret_cast<T*>(&acc[10]);
|
||||
T* a11 = reinterpret_cast<T*>(&acc[11]);
|
||||
T* a12 = reinterpret_cast<T*>(&acc[12]);
|
||||
T* a13 = reinterpret_cast<T*>(&acc[13]);
|
||||
T* a14 = reinterpret_cast<T*>(&acc[14]);
|
||||
T* a15 = reinterpret_cast<T*>(&acc[15]);
|
||||
T* a16 = reinterpret_cast<T*>(&acc[16]);
|
||||
T* a17 = reinterpret_cast<T*>(&acc[17]);
|
||||
T* a18 = reinterpret_cast<T*>(&acc[18]);
|
||||
T* a19 = reinterpret_cast<T*>(&acc[19]);
|
||||
T* a20 = reinterpret_cast<T*>(&acc[20]);
|
||||
T* a21 = reinterpret_cast<T*>(&acc[21]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < elems_per_vec; ++j)
|
||||
{
|
||||
a0[j] += a1[j];
|
||||
a2[j] += a3[j];
|
||||
a4[j] += a5[j];
|
||||
a6[j] += a7[j];
|
||||
a8[j] += a9[j];
|
||||
a10[j] += a11[j];
|
||||
a12[j] += a13[j];
|
||||
a14[j] += a15[j];
|
||||
a16[j] += a17[j];
|
||||
a18[j] += a19[j];
|
||||
a20[j] += a21[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < elems_per_vec; ++j)
|
||||
{
|
||||
a0[j] += a2[j];
|
||||
a4[j] += a6[j];
|
||||
a8[j] += a10[j];
|
||||
a12[j] += a14[j];
|
||||
a16[j] += a18[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < elems_per_vec; ++j)
|
||||
{
|
||||
a0[j] += a4[j];
|
||||
a8[j] += a12[j];
|
||||
a16[j] += a20[j];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < elems_per_vec; ++j)
|
||||
{
|
||||
a0[j] += a8[j];
|
||||
a0[j] += a16[j];
|
||||
}
|
||||
}
|
||||
else if constexpr (TOP_K == 16)
|
||||
{
|
||||
T* a0 = reinterpret_cast<T*>(&acc[0]);
|
||||
T* a1 = reinterpret_cast<T*>(&acc[1]);
|
||||
|
||||
@ -26,7 +26,7 @@ namespace kernels::moe_comm
|
||||
{
|
||||
|
||||
// Configuration constants
|
||||
static constexpr int kMaxTopK = 16; // Maximum top-k experts per token
|
||||
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
|
||||
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
|
||||
static constexpr int kMaxRanks = 64; // Maximum supported EP size
|
||||
|
||||
|
||||
@ -565,6 +565,7 @@ class TestMoEAlltoAll:
|
||||
(2, [100, 50], 2),
|
||||
(4, [32, 32, 32, 32], 4),
|
||||
(4, [32, 32, 32, 32], 10), # (top_k=10 is used by Qwen3-next)
|
||||
(4, [32, 32, 32, 32], 22),
|
||||
(4, [1, 1, 1, 1], 2),
|
||||
(8, [640, 640, 640, 640, 640, 640, 640, 640], 4),
|
||||
(4, [32, 0, 16, 0], 2),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user