mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Add routing support for the new model for both cutlass and trtllm moe backend (#9792)
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
parent
4ce35eacf1
commit
dff77efa2a
@ -32,11 +32,14 @@ TRTLLM_NAMESPACE_BEGIN
|
||||
namespace kernels
|
||||
{
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int NumNemotronExperts = 512;
|
||||
static constexpr int NumKimiK2Experts = 384;
|
||||
static constexpr int NumDeepseekExperts = 256;
|
||||
static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
|
||||
static constexpr int MaxNumExpertsUnit = 128;
|
||||
static constexpr int NumTopGroupScores = 2;
|
||||
static constexpr int MaxNumTopExperts = 8;
|
||||
static constexpr int DefaultMaxNumTopExperts = 8;
|
||||
static constexpr int MaxSupportedTopExperts = 22;
|
||||
static constexpr int MaxNumTopGroups = 4;
|
||||
|
||||
static __device__ inline float sigmoid_accurate(float x)
|
||||
@ -44,7 +47,8 @@ static __device__ inline float sigmoid_accurate(float x)
|
||||
return 0.5f * tanhf(0.5f * x) + 0.5f;
|
||||
}
|
||||
|
||||
template <typename InputT, typename BiasT, typename OutputT, typename IdxT, int MaxNumExperts, bool UseGroups>
|
||||
template <typename InputT, typename BiasT, typename OutputT, typename IdxT, int MaxNumExperts, bool UseGroups,
|
||||
int MaxNumTopExperts = DefaultMaxNumTopExperts>
|
||||
__global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, BiasT* routingBias,
|
||||
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, int64_t const topk,
|
||||
int64_t const numExperts, int64_t const numExpertsPerGroup, double const routedScalingFactor)
|
||||
@ -132,7 +136,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
|
||||
/* minValue */ invalidScoreFloat);
|
||||
|
||||
// get the final group score and write it to shared
|
||||
if (laneIdx == 0)
|
||||
if (warp.thread_rank() == 0)
|
||||
{
|
||||
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
|
||||
smemGroupScores[warpIdx] = groupScore;
|
||||
@ -151,9 +155,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
|
||||
|
||||
reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
|
||||
// final expert selection: get relevant indexes and scores from shared
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
|
||||
{ // bound of numGroup
|
||||
@ -161,12 +163,11 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
|
||||
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;
|
||||
|
||||
expertScoreGroup[ii]
|
||||
= groupIdx < numGroup && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
|
||||
= (ii < topkGroup) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
|
||||
}
|
||||
|
||||
tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
tensorrt_llm::kernels::reduce_topk::reduceTopK(
|
||||
warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat, topk);
|
||||
}
|
||||
}
|
||||
else if constexpr (MaxNumExperts > MaxNumExpertsUnit)
|
||||
@ -197,11 +198,16 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx];
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
|
||||
}
|
||||
else if (laneIdx >= topk && laneIdx < MaxNumTopExperts)
|
||||
{
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat;
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (warpIdx == 0)
|
||||
{
|
||||
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1;
|
||||
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
|
||||
float intermidiateScore[NumInterTopKPerThread];
|
||||
int32_t intermidiateExpert[NumInterTopKPerThread];
|
||||
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE)
|
||||
@ -268,11 +274,11 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
|
||||
{
|
||||
|
||||
// Check if we can use the optimized deepseek_v3_topk_kernel
|
||||
bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts);
|
||||
bool const is_single_group = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount);
|
||||
|
||||
int64_t const experts_per_group = num_experts / n_group;
|
||||
bool const is_multi_group = (n_group != 1) && (num_experts <= NumDeepseekExperts)
|
||||
&& (experts_per_group <= WARP_SIZE) && (experts_per_group * topk_group <= MaxNumExpertsUnit);
|
||||
bool const is_multi_group = (n_group > 1) && (num_experts <= NumDeepseekExperts) && (experts_per_group <= WARP_SIZE)
|
||||
&& (experts_per_group * topk_group <= MaxNumExpertsUnit);
|
||||
|
||||
if (is_single_group || is_multi_group)
|
||||
{
|
||||
@ -281,7 +287,20 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
|
||||
int num_threads = NumDeepseekExperts;
|
||||
if (is_single_group)
|
||||
{
|
||||
if (num_experts > MaxNumExpertsUnit)
|
||||
// Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only.
|
||||
if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts)
|
||||
{
|
||||
kernel_instance = &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, NumNemotronExperts, false,
|
||||
MaxSupportedTopExperts>;
|
||||
num_threads = NumNemotronExperts;
|
||||
}
|
||||
else if (num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount)
|
||||
{
|
||||
kernel_instance
|
||||
= &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, MaxSupportedExpertCount, false>;
|
||||
num_threads = MaxSupportedExpertCount;
|
||||
}
|
||||
else if (num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts)
|
||||
{
|
||||
kernel_instance = &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, NumKimiK2Experts, false>;
|
||||
num_threads = NumKimiK2Experts;
|
||||
|
||||
@ -182,37 +182,37 @@ namespace moe::dev
|
||||
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
|
||||
}
|
||||
|
||||
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \
|
||||
data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput, numExperts) \
|
||||
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
|
||||
stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \
|
||||
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \
|
||||
{ \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \
|
||||
{ \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \
|
||||
numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \
|
||||
{ \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \
|
||||
kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \
|
||||
{ \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \
|
||||
numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \
|
||||
kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
|
||||
@ -23,11 +23,13 @@ namespace routingDeepSeek
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static constexpr int NumNemotronExperts = 512;
|
||||
static constexpr int NumKimiK2Experts = 384;
|
||||
static constexpr int NumDeepseekExperts = 256;
|
||||
static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
|
||||
static constexpr int NumTopGroupScores = 2;
|
||||
static constexpr int MaxNumTopExperts = 8;
|
||||
static constexpr int DefaultMaxNumTopExperts = 8;
|
||||
static constexpr int MaxSupportedTopExperts = 22;
|
||||
static constexpr int MaxNumTopGroups = 4;
|
||||
static constexpr int MaxNumGroups = 8;
|
||||
|
||||
@ -125,8 +127,8 @@ __global__ void routingMainKernel(KernelParams params)
|
||||
int32_t topGroupIdx[MaxNumTopGroups];
|
||||
float expertScoreGroup[MaxNumTopGroups];
|
||||
int32_t expertIdxGroup[MaxNumTopGroups];
|
||||
float topScores[MaxNumTopExperts]; // bound of params.mTopK
|
||||
int32_t topExperts[MaxNumTopExperts];
|
||||
float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK
|
||||
int32_t topExperts[KernelParams::MaxNumTopExperts];
|
||||
|
||||
if constexpr (KernelParams::UseGroups)
|
||||
{
|
||||
@ -152,7 +154,6 @@ __global__ void routingMainKernel(KernelParams params)
|
||||
topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
// final expert selection: get relevant indexes and scores from shared
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
|
||||
{ // bound of params.mNumLimitedGroups
|
||||
@ -164,7 +165,8 @@ __global__ void routingMainKernel(KernelParams params)
|
||||
// groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup
|
||||
// => expertIdxGroup[ii] < params.mNumExperts <= NumThreads,
|
||||
// so the access is safe here
|
||||
expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected
|
||||
expertScoreGroup[ii]
|
||||
= (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected
|
||||
? smemScoreBias[expertIdxGroup[ii]]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
@ -177,7 +179,7 @@ __global__ void routingMainKernel(KernelParams params)
|
||||
{
|
||||
// without groups, each thread just takes `MaxNumTopGroups` experts
|
||||
int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1;
|
||||
int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts;
|
||||
int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts;
|
||||
__shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK];
|
||||
__shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK];
|
||||
if (warpIdx < NumExpertWarps)
|
||||
@ -196,14 +198,20 @@ __global__ void routingMainKernel(KernelParams params)
|
||||
|
||||
if (laneIdx < params.mTopK)
|
||||
{
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx];
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
|
||||
smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx];
|
||||
smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
|
||||
}
|
||||
else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts)
|
||||
{
|
||||
smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat;
|
||||
smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx]
|
||||
= MaxSupportedExpertCount - 1;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (warpIdx == 0)
|
||||
{
|
||||
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1;
|
||||
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1;
|
||||
float intermidiateScore[NumInterTopKPerThread];
|
||||
int32_t intermidiateExpert[NumInterTopKPerThread];
|
||||
for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize)
|
||||
@ -295,7 +303,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Ke
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
routingPermutation<KernelParams, OutputT, KernelParams::MaxNumExperts, KernelParams::MaxNumExperts / WarpSize,
|
||||
MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(params, nullptr, warpIdx, clusterBlockRank);
|
||||
KernelParams::MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(params, nullptr, warpIdx, clusterBlockRank);
|
||||
}
|
||||
#else
|
||||
__global__ void routingIndicesClusterKernel(KernelParams params)
|
||||
@ -558,6 +566,10 @@ int constexpr getMaxNumExperts(int32_t numExperts)
|
||||
{
|
||||
return NumKimiK2Experts;
|
||||
}
|
||||
else if (numExperts <= NumNemotronExperts)
|
||||
{
|
||||
return NumNemotronExperts;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unsupported numExperts");
|
||||
@ -571,17 +583,30 @@ int constexpr getMaxNumExperts(int32_t numExperts)
|
||||
if (data.mNumExperts <= topk::MaxNumExpertsUnit) \
|
||||
{ \
|
||||
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
|
||||
stream, extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit); \
|
||||
stream, extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit, DefaultMaxNumTopExperts); \
|
||||
} \
|
||||
else if (data.mNumExperts <= NumDeepseekExperts) \
|
||||
{ \
|
||||
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
|
||||
stream, extraFlag1, forceFloatInput, NumDeepseekExperts); \
|
||||
stream, extraFlag1, forceFloatInput, NumDeepseekExperts, DefaultMaxNumTopExperts); \
|
||||
} \
|
||||
else if (data.mNumExperts <= NumKimiK2Experts) \
|
||||
{ \
|
||||
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
|
||||
stream, extraFlag1, forceFloatInput, NumKimiK2Experts); \
|
||||
stream, extraFlag1, forceFloatInput, NumKimiK2Experts, DefaultMaxNumTopExperts); \
|
||||
} \
|
||||
else if (data.mNumExperts <= NumNemotronExperts) \
|
||||
{ \
|
||||
if (data.mTopK <= DefaultMaxNumTopExperts) \
|
||||
{ \
|
||||
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, \
|
||||
smemSize, stream, extraFlag1, forceFloatInput, NumNemotronExperts, DefaultMaxNumTopExperts); \
|
||||
} \
|
||||
else if (data.mTopK <= MaxSupportedTopExperts) \
|
||||
{ \
|
||||
LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, \
|
||||
smemSize, stream, extraFlag1, forceFloatInput, NumNemotronExperts, MaxSupportedTopExperts); \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
@ -603,25 +628,6 @@ void run(Data& data, void* stream)
|
||||
(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize,
|
||||
"If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required");
|
||||
TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
|
||||
TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d",
|
||||
MaxNumTopGroups, data.mNumLimitedGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
|
||||
MaxNumTopExperts, data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK * data.mNumLimitedGroups <= WarpSize,
|
||||
"Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK,
|
||||
data.mNumLimitedGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects %d to be at most #experts %d",
|
||||
MaxNumTopExperts, data.mNumExperts);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumKimiK2Experts, "Routing kernel expects #experts %d <= #threads %d",
|
||||
data.mNumExperts, NumKimiK2Experts);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
|
||||
"Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
|
||||
data.mNumExpertGroups);
|
||||
// Note: Routing-specific constraints (experts per group, topK limits) are checked later
|
||||
// only when routing is actually needed (data.mPtrTopKIds == nullptr)
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
|
||||
int const numBlocks = data.mNumTokens;
|
||||
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);
|
||||
|
||||
@ -655,9 +661,18 @@ void run(Data& data, void* stream)
|
||||
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
|
||||
if (data.mPtrTopKIds == nullptr)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxSupportedTopExperts,
|
||||
"Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount,
|
||||
"Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d",
|
||||
MaxSupportedTopExperts, data.mTopK);
|
||||
|
||||
// Routing needs to be executed - validate routing kernel constraints
|
||||
if (data.mNumExpertGroups > 1)
|
||||
{
|
||||
// Note: Routing-specific constraints (experts per group, topK limits) are checked when routing is actually
|
||||
// needed (data.mPtrTopKIds == nullptr)
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups,
|
||||
"Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
|
||||
@ -667,14 +682,17 @@ void run(Data& data, void* stream)
|
||||
"Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts "
|
||||
"per group",
|
||||
WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= max topk %d",
|
||||
data.mTopK, topk::MaxNumTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups,
|
||||
"Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
|
||||
"Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
|
||||
data.mNumExpertGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.",
|
||||
data.mNumExperts);
|
||||
}
|
||||
|
||||
int const numThreadsMain = data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts;
|
||||
int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts));
|
||||
LAUNCH_ROUTING_DEEPSEEK(data,
|
||||
/*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
|
||||
@ -189,13 +189,15 @@ struct Data : public DataBase
|
||||
bool mUseRoutingSoftmax;
|
||||
};
|
||||
|
||||
template <typename InputT_, typename OutputT_, int MaxNumExperts_, bool UseGroups_, bool isPow2_, bool UsePdl_>
|
||||
template <typename InputT_, typename OutputT_, int MaxNumExperts_, int MaxNumTopExperts_, bool UseGroups_, bool isPow2_,
|
||||
bool UsePdl_>
|
||||
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, MaxNumExperts_, isPow2_, UsePdl_>
|
||||
{
|
||||
using InputT = InputT_;
|
||||
using OutputT = OutputT_;
|
||||
|
||||
static constexpr bool UseGroups = UseGroups_;
|
||||
static constexpr int MaxNumTopExperts = MaxNumTopExperts_;
|
||||
|
||||
PackedScoreIdx<OutputT>* mPtrTopKPacked = nullptr;
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ namespace cg = cooperative_groups;
|
||||
|
||||
static constexpr int WarpSize = 32;
|
||||
static constexpr int MaxNumExpertsUnit = 128;
|
||||
static constexpr int MaxNumTopK = 10;
|
||||
static constexpr int MaxSupportedTopExperts = 22;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ static constexpr int NumExpertsLimit = 512;
|
||||
|
||||
static constexpr int NumThreads = 1024;
|
||||
static constexpr int NumWarps = NumThreads / WarpSize;
|
||||
static constexpr int MaxNumTopExperts = 10;
|
||||
static constexpr int MaxSupportedTopExperts = 10;
|
||||
|
||||
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
|
||||
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
|
||||
@ -34,8 +34,8 @@ static constexpr int BlockKernelMaxNumTokens = 4;
|
||||
|
||||
template <typename DataType, typename InputType, int VecSize, bool DoSoftmaxBeforeTopK>
|
||||
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
|
||||
DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts],
|
||||
int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK,
|
||||
DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxSupportedTopExperts],
|
||||
int32_t (&warpTopKExpertIdx)[MaxSupportedTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK,
|
||||
InputType const* ptrScores, bool const normTopkProb, bool const applySoftmaxAfterTopK = true)
|
||||
{
|
||||
DataType minScore = DataType{-INFINITY};
|
||||
@ -149,8 +149,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo
|
||||
BaseType score[VecSize];
|
||||
int32_t idx[VecSize];
|
||||
|
||||
BaseType warpTopKScore[MaxNumTopExperts];
|
||||
int32_t warpTopKExpertIdx[MaxNumTopExperts];
|
||||
BaseType warpTopKScore[MaxSupportedTopExperts];
|
||||
int32_t warpTopKExpertIdx[MaxSupportedTopExperts];
|
||||
|
||||
BaseType minScore = BaseType{-INFINITY};
|
||||
if (validToken)
|
||||
@ -306,7 +306,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
|
||||
|
||||
static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize;
|
||||
|
||||
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts];
|
||||
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxSupportedTopExperts];
|
||||
|
||||
uint32_t const clusterBlockRank = blockIdx.x;
|
||||
|
||||
@ -332,8 +332,8 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
|
||||
BaseType score[VecSize];
|
||||
int32_t idx[VecSize];
|
||||
|
||||
BaseType warpTopKScore[MaxNumTopExperts];
|
||||
int32_t warpTopKExpertIdx[MaxNumTopExperts];
|
||||
BaseType warpTopKScore[MaxSupportedTopExperts];
|
||||
int32_t warpTopKExpertIdx[MaxSupportedTopExperts];
|
||||
|
||||
BaseType minScore = BaseType{-INFINITY};
|
||||
if (validToken)
|
||||
@ -356,12 +356,12 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
|
||||
|
||||
if (params.mPtrScores != nullptr)
|
||||
{
|
||||
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxNumTopExperts,
|
||||
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxSupportedTopExperts,
|
||||
/*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
|
||||
}
|
||||
else
|
||||
{
|
||||
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxNumTopExperts,
|
||||
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxSupportedTopExperts,
|
||||
/*LoadExpertIdxFromGlobal=*/true>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
|
||||
}
|
||||
}
|
||||
@ -417,8 +417,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHis
|
||||
// over all warps/tokens
|
||||
BaseType allScores[VecSize];
|
||||
int32_t allExpertIdx[VecSize];
|
||||
BaseType warpTopKScore[MaxNumTopExperts];
|
||||
int32_t warpTopKExpertIdx[MaxNumTopExperts];
|
||||
BaseType warpTopKScore[MaxSupportedTopExperts];
|
||||
int32_t warpTopKExpertIdx[MaxSupportedTopExperts];
|
||||
for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
|
||||
{
|
||||
auto scoreOffset = tokenIdx * params.mNumExperts;
|
||||
@ -486,8 +486,8 @@ void run(Data const& data, void* stream)
|
||||
TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
|
||||
&& data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
|
||||
"Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
|
||||
MaxNumTopExperts, data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d",
|
||||
MaxSupportedTopExperts, data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumExpertsLimit,
|
||||
"Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, NumExpertsLimit);
|
||||
// static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
|
||||
|
||||
@ -70,7 +70,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
|
||||
{
|
||||
if (routingMethodType == RoutingMethodType::DeepSeekV3)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
|
||||
TLLM_CHECK_WITH_INFO(topK <= 22, "For DeepSeek routing method, must have topK <= 22");
|
||||
TLLM_CHECK_WITH_INFO(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4");
|
||||
moe::dev::routing::routingDeepSeek::Data routingData;
|
||||
routingData.mDtypeExpW = btg::Dtype::Bfloat16;
|
||||
|
||||
@ -106,7 +106,7 @@ std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch:
|
||||
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
|
||||
}
|
||||
|
||||
if (n_group.has_value() && n_group.value() != 0)
|
||||
if (n_group.has_value() && n_group.value() > 1)
|
||||
{
|
||||
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
|
||||
"Routing kernel with groups implies DeepSeekV3 routing method.");
|
||||
|
||||
@ -104,7 +104,7 @@ at::Tensor run_fp8_block_scale_moe(at::optional<at::Tensor> const& routing_logit
|
||||
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
|
||||
}
|
||||
|
||||
if (n_group.has_value() && n_group.value() != 0)
|
||||
if (n_group.has_value() && n_group.value() > 1)
|
||||
{
|
||||
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
|
||||
"Routing kernel with groups implies DeepSeekV3 routing method.");
|
||||
|
||||
@ -107,7 +107,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional<torch::Tensor> con
|
||||
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
|
||||
}
|
||||
|
||||
if (n_group.has_value() && n_group.value() != 0)
|
||||
if (n_group.has_value() && n_group.value() > 1)
|
||||
{
|
||||
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
|
||||
"Routing kernel with groups implies DeepSeekV3 routing method.");
|
||||
|
||||
@ -114,7 +114,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor>
|
||||
TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape.");
|
||||
}
|
||||
|
||||
if (n_group.has_value() && n_group.value() != 0)
|
||||
if (n_group.has_value() && n_group.value() > 1)
|
||||
{
|
||||
TORCH_CHECK(static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
|
||||
"Routing kernel with groups implies DeepSeekV3 routing method.");
|
||||
|
||||
@ -244,6 +244,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384)
|
||||
this->runTest(param);
|
||||
};
|
||||
|
||||
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization512)
|
||||
{
|
||||
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024
|
||||
/*numExperts=*/512, /*topK=*/22,
|
||||
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
|
||||
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
|
||||
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
|
||||
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
|
||||
this->runTest(param);
|
||||
};
|
||||
|
||||
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
|
||||
{
|
||||
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10
|
||||
@ -310,6 +321,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384)
|
||||
this->runTest(param);
|
||||
};
|
||||
|
||||
TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization512)
|
||||
{
|
||||
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030,
|
||||
/*numExperts=*/512, /*topK=*/22,
|
||||
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
|
||||
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
|
||||
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
|
||||
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
|
||||
this->runTest(param);
|
||||
};
|
||||
|
||||
TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
|
||||
{
|
||||
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300,
|
||||
@ -332,6 +354,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384)
|
||||
this->runTest(param);
|
||||
};
|
||||
|
||||
TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization512)
|
||||
{
|
||||
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300,
|
||||
/*numExperts=*/512, /*topK=*/22,
|
||||
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
|
||||
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
|
||||
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
|
||||
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
|
||||
this->runTest(param);
|
||||
};
|
||||
|
||||
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
|
||||
{
|
||||
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
|
||||
|
||||
@ -263,7 +263,8 @@ class Deepseekv3RoutingImpl:
|
||||
)
|
||||
self.is_fused = False
|
||||
else:
|
||||
if num_experts > 384 or self.top_k > 8:
|
||||
# We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3.
|
||||
if num_experts > 512 or (self.top_k > 8 and self.top_k != 22):
|
||||
if (self.is_fused):
|
||||
warnings.warn(
|
||||
"The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation."
|
||||
@ -292,7 +293,11 @@ class Deepseekv3RoutingImpl:
|
||||
score_mask = group_mask.unsqueeze(-1).expand(
|
||||
scores_shape[:-1] +
|
||||
[n_group, scores_shape[-1] // n_group]).reshape(scores_shape)
|
||||
scores_with_bias = scores_with_bias * score_mask
|
||||
scores_with_bias = torch.where(
|
||||
score_mask.bool(), scores_with_bias,
|
||||
torch.tensor(float('-inf'),
|
||||
dtype=scores_with_bias.dtype,
|
||||
device=scores_with_bias.device))
|
||||
_, topk_idx = torch.topk(scores_with_bias,
|
||||
k=self.top_k,
|
||||
dim=-1,
|
||||
|
||||
@ -9,6 +9,7 @@ from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate
|
||||
(256, 8, 4, 8),
|
||||
(72, 1, 1, 6),
|
||||
(384, 1, 1, 8),
|
||||
(512, 1, 1, 22),
|
||||
])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float16, torch.bfloat16, torch.float32])
|
||||
|
||||
@ -1008,6 +1008,17 @@ class TestMoeFp4:
|
||||
"routing_method_type": RoutingMethodType.DeepSeekV3
|
||||
},
|
||||
id="RoutingDSv3"),
|
||||
pytest.param(
|
||||
{
|
||||
"num_experts": 512,
|
||||
"top_k": 22,
|
||||
"n_groups": 1,
|
||||
"top_k_groups": 1,
|
||||
"routed_scaling": 2.5,
|
||||
"has_routing_bias": True,
|
||||
"routing_method_type": RoutingMethodType.DeepSeekV3
|
||||
},
|
||||
id="RoutingDS_SuperV3"),
|
||||
pytest.param(
|
||||
{
|
||||
"num_experts": 72,
|
||||
@ -1238,7 +1249,7 @@ class TestMoeFp4:
|
||||
pytest.skip("https://nvbugs/5434352")
|
||||
|
||||
assert top_k <= num_experts
|
||||
assert top_k <= 10
|
||||
assert top_k <= 22
|
||||
assert num_experts % 4 == 0
|
||||
|
||||
if use_topk_as_input:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user