[None][feat] Add routing support for the new model for both cutlass and trtllm moe backend (#9792)

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
ChristinaZ 2025-12-16 11:59:08 +08:00 committed by GitHub
parent 4ce35eacf1
commit dff77efa2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 180 additions and 91 deletions

View File

@ -32,11 +32,14 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
static constexpr int WARP_SIZE = 32;
static constexpr int NumNemotronExperts = 512;
static constexpr int NumKimiK2Experts = 384;
static constexpr int NumDeepseekExperts = 256;
static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
static constexpr int MaxNumExpertsUnit = 128;
static constexpr int NumTopGroupScores = 2;
static constexpr int MaxNumTopExperts = 8;
static constexpr int DefaultMaxNumTopExperts = 8;
static constexpr int MaxSupportedTopExperts = 22;
static constexpr int MaxNumTopGroups = 4;
static __device__ inline float sigmoid_accurate(float x)
@ -44,7 +47,8 @@ static __device__ inline float sigmoid_accurate(float x)
return 0.5f * tanhf(0.5f * x) + 0.5f;
}
template <typename InputT, typename BiasT, typename OutputT, typename IdxT, int MaxNumExperts, bool UseGroups>
template <typename InputT, typename BiasT, typename OutputT, typename IdxT, int MaxNumExperts, bool UseGroups,
int MaxNumTopExperts = DefaultMaxNumTopExperts>
__global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, BiasT* routingBias,
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, int64_t const topk,
int64_t const numExperts, int64_t const numExpertsPerGroup, double const routedScalingFactor)
@ -132,7 +136,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
/* minValue */ invalidScoreFloat);
// get the final group score and write it to shared
if (laneIdx == 0)
if (warp.thread_rank() == 0)
{
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
smemGroupScores[warpIdx] = groupScore;
@ -151,9 +155,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
/* minValue */ invalidScoreFloat);
// final expert selection: get relevant indexes and scores from shared
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
{ // bound of numGroup
@ -161,12 +163,11 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;
expertScoreGroup[ii]
= groupIdx < numGroup && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
= (ii < topkGroup) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
}
tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup,
/* minValue */ invalidScoreFloat, topk);
tensorrt_llm::kernels::reduce_topk::reduceTopK(
warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat, topk);
}
}
else if constexpr (MaxNumExperts > MaxNumExpertsUnit)
@ -197,11 +198,16 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx];
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
}
else if (laneIdx >= topk && laneIdx < MaxNumTopExperts)
{
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat;
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1;
}
}
__syncthreads();
if (warpIdx == 0)
{
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1;
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
float intermidiateScore[NumInterTopKPerThread];
int32_t intermidiateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE)
@ -268,11 +274,11 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
{
// Check if we can use the optimized deepseek_v3_topk_kernel
bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts);
bool const is_single_group = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount);
int64_t const experts_per_group = num_experts / n_group;
bool const is_multi_group = (n_group != 1) && (num_experts <= NumDeepseekExperts)
&& (experts_per_group <= WARP_SIZE) && (experts_per_group * topk_group <= MaxNumExpertsUnit);
bool const is_multi_group = (n_group > 1) && (num_experts <= NumDeepseekExperts) && (experts_per_group <= WARP_SIZE)
&& (experts_per_group * topk_group <= MaxNumExpertsUnit);
if (is_single_group || is_multi_group)
{
@ -281,7 +287,20 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
int num_threads = NumDeepseekExperts;
if (is_single_group)
{
if (num_experts > MaxNumExpertsUnit)
// Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only.
if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts)
{
kernel_instance = &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, NumNemotronExperts, false,
MaxSupportedTopExperts>;
num_threads = NumNemotronExperts;
}
else if (num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount)
{
kernel_instance
= &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, MaxSupportedExpertCount, false>;
num_threads = MaxSupportedExpertCount;
}
else if (num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts)
{
kernel_instance = &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, NumKimiK2Experts, false>;
num_threads = NumKimiK2Experts;

View File

@ -182,37 +182,37 @@ namespace moe::dev
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
}
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \
data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput, numExperts) \
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, numThreads, \
smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, numThreads, \
smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \
kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \
kernel, numBlocks, numThreads, smemSize, stream); \
} \
else \
{ \

View File

@ -23,11 +23,13 @@ namespace routingDeepSeek
{
////////////////////////////////////////////////////////////////////////////////////////////////////
static constexpr int NumNemotronExperts = 512;
static constexpr int NumKimiK2Experts = 384;
static constexpr int NumDeepseekExperts = 256;
static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
static constexpr int NumTopGroupScores = 2;
static constexpr int MaxNumTopExperts = 8;
static constexpr int DefaultMaxNumTopExperts = 8;
static constexpr int MaxSupportedTopExperts = 22;
static constexpr int MaxNumTopGroups = 4;
static constexpr int MaxNumGroups = 8;
@ -125,8 +127,8 @@ __global__ void routingMainKernel(KernelParams params)
int32_t topGroupIdx[MaxNumTopGroups];
float expertScoreGroup[MaxNumTopGroups];
int32_t expertIdxGroup[MaxNumTopGroups];
float topScores[MaxNumTopExperts]; // bound of params.mTopK
int32_t topExperts[MaxNumTopExperts];
float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK
int32_t topExperts[KernelParams::MaxNumTopExperts];
if constexpr (KernelParams::UseGroups)
{
@ -152,7 +154,6 @@ __global__ void routingMainKernel(KernelParams params)
topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
/* minValue */ invalidScoreFloat);
// final expert selection: get relevant indexes and scores from shared
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
{ // bound of params.mNumLimitedGroups
@ -164,7 +165,8 @@ __global__ void routingMainKernel(KernelParams params)
// groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup
// => expertIdxGroup[ii] < params.mNumExperts <= NumThreads,
// so the access is safe here
expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected
expertScoreGroup[ii]
= (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected
? smemScoreBias[expertIdxGroup[ii]]
: invalidScoreFloat;
}
@ -177,7 +179,7 @@ __global__ void routingMainKernel(KernelParams params)
{
// without groups, each thread just takes `MaxNumTopGroups` experts
int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1;
int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts;
int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts;
__shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK];
__shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK];
if (warpIdx < NumExpertWarps)
@ -196,14 +198,20 @@ __global__ void routingMainKernel(KernelParams params)
if (laneIdx < params.mTopK)
{
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx];
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx];
smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
}
else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts)
{
smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat;
smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx]
= MaxSupportedExpertCount - 1;
}
}
__syncthreads();
if (warpIdx == 0)
{
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1;
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1;
float intermidiateScore[NumInterTopKPerThread];
int32_t intermidiateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize)
@ -295,7 +303,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Ke
cudaGridDependencySynchronize();
}
routingPermutation<KernelParams, OutputT, KernelParams::MaxNumExperts, KernelParams::MaxNumExperts / WarpSize,
MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(params, nullptr, warpIdx, clusterBlockRank);
KernelParams::MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(params, nullptr, warpIdx, clusterBlockRank);
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params)
@ -558,6 +566,10 @@ int constexpr getMaxNumExperts(int32_t numExperts)
{
return NumKimiK2Experts;
}
else if (numExperts <= NumNemotronExperts)
{
return NumNemotronExperts;
}
else
{
TLLM_LOG_ERROR("Unsupported numExperts");
@ -571,17 +583,30 @@ int constexpr getMaxNumExperts(int32_t numExperts)
if (data.mNumExperts <= topk::MaxNumExpertsUnit) \
{ \
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit); \
stream, extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit, DefaultMaxNumTopExperts); \
} \
else if (data.mNumExperts <= NumDeepseekExperts) \
{ \
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag1, forceFloatInput, NumDeepseekExperts); \
stream, extraFlag1, forceFloatInput, NumDeepseekExperts, DefaultMaxNumTopExperts); \
} \
else if (data.mNumExperts <= NumKimiK2Experts) \
{ \
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag1, forceFloatInput, NumKimiK2Experts); \
stream, extraFlag1, forceFloatInput, NumKimiK2Experts, DefaultMaxNumTopExperts); \
} \
else if (data.mNumExperts <= NumNemotronExperts) \
{ \
if (data.mTopK <= DefaultMaxNumTopExperts) \
{ \
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, \
smemSize, stream, extraFlag1, forceFloatInput, NumNemotronExperts, DefaultMaxNumTopExperts); \
} \
else if (data.mTopK <= MaxSupportedTopExperts) \
{ \
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, \
smemSize, stream, extraFlag1, forceFloatInput, NumNemotronExperts, MaxSupportedTopExperts); \
} \
} \
else \
{ \
@ -603,25 +628,6 @@ void run(Data& data, void* stream)
(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize,
"If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required");
TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d",
MaxNumTopGroups, data.mNumLimitedGroups);
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxNumTopExperts, data.mTopK);
TLLM_CHECK_WITH_INFO(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK);
TLLM_CHECK_WITH_INFO(data.mTopK * data.mNumLimitedGroups <= WarpSize,
"Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK,
data.mNumLimitedGroups);
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects %d to be at most #experts %d",
MaxNumTopExperts, data.mNumExperts);
TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumKimiK2Experts, "Routing kernel expects #experts %d <= #threads %d",
data.mNumExperts, NumKimiK2Experts);
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
"Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
data.mNumExpertGroups);
// Note: Routing-specific constraints (experts per group, topK limits) are checked later
// only when routing is actually needed (data.mPtrTopKIds == nullptr)
TLLM_CHECK_WITH_INFO(
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
int const numBlocks = data.mNumTokens;
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
@ -655,9 +661,18 @@ void run(Data& data, void* stream)
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
if (data.mPtrTopKIds == nullptr)
{
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxSupportedTopExperts,
"Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts);
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount,
"Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount);
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxSupportedTopExperts, data.mTopK);
// Routing needs to be executed - validate routing kernel constraints
if (data.mNumExpertGroups > 1)
{
// Note: Routing-specific constraints (experts per group, topK limits) are checked when routing is actually
// needed (data.mPtrTopKIds == nullptr)
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups,
"Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups);
TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
@ -667,14 +682,17 @@ void run(Data& data, void* stream)
"Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts "
"per group",
WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups);
}
else
{
TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= max topk %d",
data.mTopK, topk::MaxNumTopK);
TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups,
"Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups);
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
"Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
data.mNumExpertGroups);
TLLM_CHECK_WITH_INFO(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.",
data.mNumExperts);
}
int const numThreadsMain = data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts;
int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts));
LAUNCH_ROUTING_DEEPSEEK(data,
/*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain,
/*smemSize=*/0, // No dynamic smem

View File

@ -189,13 +189,15 @@ struct Data : public DataBase
bool mUseRoutingSoftmax;
};
template <typename InputT_, typename OutputT_, int MaxNumExperts_, bool UseGroups_, bool isPow2_, bool UsePdl_>
template <typename InputT_, typename OutputT_, int MaxNumExperts_, int MaxNumTopExperts_, bool UseGroups_, bool isPow2_,
bool UsePdl_>
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, MaxNumExperts_, isPow2_, UsePdl_>
{
using InputT = InputT_;
using OutputT = OutputT_;
static constexpr bool UseGroups = UseGroups_;
static constexpr int MaxNumTopExperts = MaxNumTopExperts_;
PackedScoreIdx<OutputT>* mPtrTopKPacked = nullptr;

View File

@ -35,7 +35,7 @@ namespace cg = cooperative_groups;
static constexpr int WarpSize = 32;
static constexpr int MaxNumExpertsUnit = 128;
static constexpr int MaxNumTopK = 10;
static constexpr int MaxSupportedTopExperts = 22;
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -25,7 +25,7 @@ static constexpr int NumExpertsLimit = 512;
static constexpr int NumThreads = 1024;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int MaxNumTopExperts = 10;
static constexpr int MaxSupportedTopExperts = 10;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
@ -34,8 +34,8 @@ static constexpr int BlockKernelMaxNumTokens = 4;
template <typename DataType, typename InputType, int VecSize, bool DoSoftmaxBeforeTopK>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts],
int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK,
DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxSupportedTopExperts],
int32_t (&warpTopKExpertIdx)[MaxSupportedTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK,
InputType const* ptrScores, bool const normTopkProb, bool const applySoftmaxAfterTopK = true)
{
DataType minScore = DataType{-INFINITY};
@ -149,8 +149,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo
BaseType score[VecSize];
int32_t idx[VecSize];
BaseType warpTopKScore[MaxNumTopExperts];
int32_t warpTopKExpertIdx[MaxNumTopExperts];
BaseType warpTopKScore[MaxSupportedTopExperts];
int32_t warpTopKExpertIdx[MaxSupportedTopExperts];
BaseType minScore = BaseType{-INFINITY};
if (validToken)
@ -306,7 +306,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize;
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts];
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxSupportedTopExperts];
uint32_t const clusterBlockRank = blockIdx.x;
@ -332,8 +332,8 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
BaseType score[VecSize];
int32_t idx[VecSize];
BaseType warpTopKScore[MaxNumTopExperts];
int32_t warpTopKExpertIdx[MaxNumTopExperts];
BaseType warpTopKScore[MaxSupportedTopExperts];
int32_t warpTopKExpertIdx[MaxSupportedTopExperts];
BaseType minScore = BaseType{-INFINITY};
if (validToken)
@ -356,12 +356,12 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
if (params.mPtrScores != nullptr)
{
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxNumTopExperts,
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxSupportedTopExperts,
/*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
}
else
{
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxNumTopExperts,
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxSupportedTopExperts,
/*LoadExpertIdxFromGlobal=*/true>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
}
}
@ -417,8 +417,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHis
// over all warps/tokens
BaseType allScores[VecSize];
int32_t allExpertIdx[VecSize];
BaseType warpTopKScore[MaxNumTopExperts];
int32_t warpTopKExpertIdx[MaxNumTopExperts];
BaseType warpTopKScore[MaxSupportedTopExperts];
int32_t warpTopKExpertIdx[MaxSupportedTopExperts];
for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
{
auto scoreOffset = tokenIdx * params.mNumExperts;
@ -486,8 +486,8 @@ void run(Data const& data, void* stream)
TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
&& data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
"Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxNumTopExperts, data.mTopK);
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxSupportedTopExperts, data.mTopK);
TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumExpertsLimit,
"Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, NumExpertsLimit);
// static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");

View File

@ -70,7 +70,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
{
if (routingMethodType == RoutingMethodType::DeepSeekV3)
{
TLLM_CHECK_WITH_INFO(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
TLLM_CHECK_WITH_INFO(topK <= 22, "For DeepSeek routing method, must have topK <= 22");
TLLM_CHECK_WITH_INFO(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
moe::dev::routing::routingDeepSeek::Data routingData;
routingData.mDtypeExpW = btg::Dtype::Bfloat16;

View File

@ -106,7 +106,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
}
if (n_group.has_value() && n_group.value() != 0)
if (n_group.has_value() && n_group.value() > 1)
{
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
"Routing kernel with groups implies DeepSeekV3 routing method.");

View File

@ -104,7 +104,7 @@ at::Tensor run_fp8_block_scale_moe(at::optional<at::Tensor> const& routing_logit
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
}
if (n_group.has_value() && n_group.value() != 0)
if (n_group.has_value() && n_group.value() > 1)
{
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
"Routing kernel with groups implies DeepSeekV3 routing method.");

View File

@ -107,7 +107,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional<torch::Tensor> con
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
}
if (n_group.has_value() && n_group.value() != 0)
if (n_group.has_value() && n_group.value() > 1)
{
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
"Routing kernel with groups implies DeepSeekV3 routing method.");

View File

@ -114,7 +114,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
}
if (n_group.has_value() && n_group.value() != 0)
if (n_group.has_value() && n_group.value() > 1)
{
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
"Routing kernel with groups implies DeepSeekV3 routing method.");

View File

@ -244,6 +244,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384)
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization512)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024
/*numExperts=*/512, /*topK=*/22,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10
@ -310,6 +321,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384)
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization512)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030,
/*numExperts=*/512, /*topK=*/22,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300,
@ -332,6 +354,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384)
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization512)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300,
/*numExperts=*/512, /*topK=*/22,
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,

View File

@ -263,7 +263,8 @@ class Deepseekv3RoutingImpl:
)
self.is_fused = False
else:
if num_experts > 384 or self.top_k > 8:
# We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
if num_experts > 512 or (self.top_k > 8 and self.top_k != 22):
if (self.is_fused):
warnings.warn(
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
@ -292,7 +293,11 @@ class Deepseekv3RoutingImpl:
score_mask = group_mask.unsqueeze(-1).expand(
scores_shape[:-1] +
[n_group, scores_shape[-1] // n_group]).reshape(scores_shape)
scores_with_bias = scores_with_bias * score_mask
scores_with_bias = torch.where(
score_mask.bool(), scores_with_bias,
torch.tensor(float('-inf'),
dtype=scores_with_bias.dtype,
device=scores_with_bias.device))
_, topk_idx = torch.topk(scores_with_bias,
k=self.top_k,
dim=-1,

View File

@ -9,6 +9,7 @@ from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate
(256, 8, 4, 8),
(72, 1, 1, 6),
(384, 1, 1, 8),
(512, 1, 1, 22),
])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])

View File

@ -1008,6 +1008,17 @@ class TestMoeFp4:
"routing_method_type": RoutingMethodType.DeepSeekV3
},
id="RoutingDSv3"),
pytest.param(
{
"num_experts": 512,
"top_k": 22,
"n_groups": 1,
"top_k_groups": 1,
"routed_scaling": 2.5,
"has_routing_bias": True,
"routing_method_type": RoutingMethodType.DeepSeekV3
},
id="RoutingDS_SuperV3"),
pytest.param(
{
"num_experts": 72,
@ -1238,7 +1249,7 @@ class TestMoeFp4:
pytest.skip("https://nvbugs/5434352")
assert top_k <= num_experts
assert top_k <= 10
assert top_k <= 22
assert num_experts % 4 == 0
if use_topk_as_input: