Refactor the rest routing part for the routing kernels in the MoE TRT-LLM backend (#5771)

Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in:
ChristinaZ 2025-07-11 16:37:56 +08:00 committed by GitHub
parent 37293e4dfd
commit c5fb692a7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 2443 additions and 3601 deletions

View File

@ -31,7 +31,7 @@ namespace moe::dev
////////////////////////////////////////////////////////////////////////////////////////////////////
#define LAUCNCH_ESC(...) __VA_ARGS__
#define LAUNCH_ESC(...) __VA_ARGS__
#define LAUNCH_PDL(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
cudaLaunchConfig_t config{}; \
@ -64,53 +64,6 @@ namespace moe::dev
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
}
#define LAUNCH_PDL_QWEN3(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
cudaLaunchConfig_t config{}; \
config.gridDim = numBlocks; \
config.blockDim = numThreads; \
config.dynamicSmemBytes = smemSize; \
config.stream = (cudaStream_t) stream; \
\
cudaLaunchAttribute attributes[2] = {}; \
attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \
attributes[0].val.programmaticStreamSerializationAllowed = int(data.mUsePdl); \
attributes[1].id = cudaLaunchAttributeCooperative; \
attributes[1].val.cooperative = int(coopLaunch); \
config.attrs = attributes; \
config.numAttrs = 2; \
if (data.mUsePdl && data.mDoSoftmaxBeforeTopK) \
{ \
auto params = KernelParams<types, /*mUsePdl=*/true>::setKernelParams(data); \
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/true>, /*mDoSoftmaxBeforeTopK=*/true>; \
if (smemSize > 48 * 1024) \
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
} \
else if (data.mUsePdl && !data.mDoSoftmaxBeforeTopK) \
{ \
auto params = KernelParams<types, /*mUsePdl=*/true>::setKernelParams(data); \
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/true>, /*mDoSoftmaxBeforeTopK=*/false>; \
if (smemSize > 48 * 1024) \
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
} \
else if (!data.mUsePdl && data.mDoSoftmaxBeforeTopK) \
{ \
auto params = KernelParams<types, /*mUsePdl=*/false>::setKernelParams(data); \
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/false>, /*mDoSoftmaxBeforeTopK=*/true>; \
if (smemSize > 48 * 1024) \
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
} \
else \
{ \
auto params = KernelParams<types, /*mUsePdl=*/false>::setKernelParams(data); \
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/false>, /*mDoSoftmaxBeforeTopK=*/false>; \
if (smemSize > 48 * 1024) \
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
}
#define LAUNCH(data, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeElt == tg::Dtype::Fp16) \
{ \
@ -132,31 +85,31 @@ namespace moe::dev
#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL( \
data, false, LAUCNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL( \
data, false, LAUCNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
smemSize, stream); \
} \
else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
smemSize, stream); \
} \
else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
smemSize, stream); \
} \
else \
@ -164,68 +117,51 @@ namespace moe::dev
TLLM_LOG_ERROR("Unsupported pair"); \
}
#define LAUNCH_EXPW_QWEN3(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
#define LAUNCH_ROUTING(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL_QWEN3(data, coopLaunch, LAUCNCH_ESC(void, float), kernel, numBlocks, numThreads, smemSize, stream); \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL_QWEN3( \
data, coopLaunch, LAUCNCH_ESC(void, __nv_bfloat16), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else \
{ \
TLLM_LOG_ERROR("Unsupported dtypeExpW: "); \
}
#define LAUNCH_EXPW_ONLY_QWEN3(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(void, float), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL( \
data, coopLaunch, LAUCNCH_ESC(void, __nv_bfloat16), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else \
{ \
TLLM_LOG_ERROR("Unsupported dtypeExpW: "); \
}
#define LAUNCH_EXPW_ONLY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL(data, coopLaunch, float, kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL(data, coopLaunch, __nv_bfloat16, kernel, numBlocks, numThreads, smemSize, stream); \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16), kernel, numBlocks, numThreads, \
smemSize, stream); \
} \
else \
{ \
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
}
#define LAUNCH_EXPW_ONLY_GROUPS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && data.mNumExpertGroups > 1) \
#define LAUNCH_ROUTING_WITH_EXTRA_FLAG( \
data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \
{ \
LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(float, true), kernel, numBlocks, numThreads, smemSize, stream); \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, true), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(float, false), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && data.mNumExpertGroups > 1) \
{ \
LAUNCH_PDL( \
data, coopLaunch, LAUCNCH_ESC(__nv_bfloat16, true), kernel, numBlocks, numThreads, smemSize, stream); \
data, coopLaunch, LAUNCH_ESC(float, float, false), kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \
{ \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, true), kernel, numBlocks, numThreads, smemSize, \
stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \
{ \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, true), kernel, numBlocks, numThreads, \
smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \
{ \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, false), kernel, numBlocks, numThreads, smemSize, \
stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_PDL( \
data, coopLaunch, LAUCNCH_ESC(__nv_bfloat16, false), kernel, numBlocks, numThreads, smemSize, stream); \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, false), kernel, numBlocks, numThreads, \
smemSize, stream); \
} \
else \
{ \

View File

@ -0,0 +1,573 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RoutingKernel.cuh"
namespace moe::dev::routing
{
namespace routingDeepSeek
{
////////////////////////////////////////////////////////////////////////////////////////////////////
static constexpr int NumThreads = 256;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int NumTopGroupScores = 2;
static constexpr int MaxNumTopExperts = 8;
static constexpr int MaxNumTopGroups = 4;
template <typename KernelParams>
__global__ void routingMainKernel(KernelParams params)
{
// declare types
using OutputT = typename KernelParams::OutputT;
using InputT = typename KernelParams::InputT;
// declare shared memory structure
// number of experts is bounded by number of threads
__shared__ float __attribute((aligned(128))) smemScoreSigmoid[NumThreads];
__shared__ float __attribute((aligned(128))) smemScoreBias[NumThreads];
// number of expert groups is bounded by number of warps
__shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];
// needed for warp reduce
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WarpSize>(block);
// for the final reduction of weight norm, only some lanes need to participate
int32_t laneIdx = threadIdx.x % WarpSize;
int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
// warps outside the range of expert groups do not participate
if constexpr (KernelParams::UseGroups)
{
if (warpIdx >= params.mNumExpertGroups)
{
return;
}
}
// note that for invalid scores, we simply use a negative value:
// they work well even with the compacted format used in topK, and
// sigmoid / bias activated scores cannot be negative
static constexpr float invalidScoreFloat = -1.F;
const OutputT invalidScore = OutputT{invalidScoreFloat};
// load bias already; each warp represents one expert group
auto threadExpert = threadIdx.x;
bool expertSelected = threadExpert < params.mNumExperts;
if constexpr (KernelParams::UseGroups)
{
threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx;
expertSelected = laneIdx < params.mNumExpertsPerGroup;
}
auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;
// initialize the mPtrExpertCounts
if (params.mPtrExpertCounts)
{
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
int32_t globalThreadStride = gridDim.x * NumThreads;
int32_t expertCountsNum = 2 * params.mNumExperts;
initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// trigger the secondary kernel when using PDL, then wait on primary
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
cudaGridDependencySynchronize();
}
#endif
// get our assigned thread score; each warp represents one expert group
float score = expertSelected ? static_cast<float>(params.mPtrScores[scoreIdx]) : invalidScoreFloat;
// get the sigmoid score
// note that for invalid values, we simply use a negative value:
// sigmoig scores are always strictly positive
auto scoreSigmoid = sigmoid_accurate(score);
// write the sigmoid score to shared for later use
if (expertSelected)
{
smemScoreSigmoid[threadExpert] = scoreSigmoid;
}
// get the score with bias
// note that with invalid values, because sigmoid is < 1 and bias is -1,
// we must get a negative value, which is smaller than any valid value
auto scoreBias = float{scoreSigmoid + float{biasVal}};
if (expertSelected)
{
smemScoreBias[threadExpert] = scoreBias;
}
// registers for top group score reduction
float topExpGroupScores[NumTopGroupScores];
[[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups
int32_t topGroupIdx[MaxNumTopGroups];
float expertScoreGroup[MaxNumTopGroups];
int32_t expertIdxGroup[MaxNumTopGroups];
float topScores[MaxNumTopExperts]; // bound of params.mTopK
int32_t topExperts[MaxNumTopExperts];
if constexpr (KernelParams::UseGroups)
{
topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert,
/* minValue */ invalidScoreFloat);
// get the final group score and write it to shared
if (cute::elect_one_sync())
{
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
smemGroupScores[warpIdx] = groupScore;
}
}
// make group scores available to all warps
__syncthreads();
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
if (warpIdx == 0)
{
// a single warp performs the selection of top groups, and goes on to select the final experts
if constexpr (KernelParams::UseGroups)
{
float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat;
topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
/* minValue */ invalidScoreFloat);
// final expert selection: get relevant indexes and scores from shared
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
{ // bound of params.mNumLimitedGroups
auto groupIdx = topGroupIdx[ii];
expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx;
// note: expertSelected implies laneIdx < params.mNumExpertsPerGroup.
// we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups,
// thus groupIdx <= params.mNumExpertGroups - 1 =>
// groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup
// => expertIdxGroup[ii] < params.mNumExperts <= NumThreads,
// so the access is safe here
expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected
? smemScoreBias[expertIdxGroup[ii]]
: invalidScoreFloat;
}
}
else
{
// without groups, each thread just takes `MaxNumTopGroups` experts
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
{
auto expertIdx = ii * WarpSize + laneIdx;
expertIdxGroup[ii] = expertIdx;
expertScoreGroup[ii] = expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
}
}
topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
/* minValue */ invalidScoreFloat, params.mTopK);
// determine our lane's expert index and write to output
int32_t expertIdx = 0;
#pragma unroll
for (int ii = 0; ii < params.mTopK; ++ii)
{ // bound of params.mTopK
expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx;
}
// determine whether our expert is local to this GPU
auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F;
auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};
// write expert idx out already
auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
if (laneIdx < params.mTopK && params.mPtrExpertIdx != nullptr)
{
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore), static_cast<int16_t>(expertIdx)};
params.mPtrExpertIdx[idxTopK] = packedScore;
}
if (laneIdx < params.mTopK && params.mPtrExpertWeights != nullptr)
{
params.mPtrExpertWeights[idxTopK] = finalScore;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
routingIndicesClusterKernel(KernelParams params)
{
using OutputT = typename KernelParams::OutputT;
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
int32_t const clusterBlockRank = blockIdx.x;
//@todo: try to move it into routingPermutation
// then wait on primary grid
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(
params, nullptr, warpIdx, clusterBlockRank);
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params)
{
assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __launch_bounds__(NumThreads) routingIndicesCoopKernel(KernelParams params)
{
// number of experts is bounded by number of threads
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
// needed for the exclusive sum of token offsets
using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
__shared__ typename Scan::TempStorage tempStorage;
// 64 elements -> 128+ registers. Above that we may start to see spilling to local memory.
static constexpr int MaxExpandedIdxPerThread = 64;
// Initialize grid.
cg::grid_group grid = cg::this_grid();
// Note: the following is more efficient than grid.block_index() because we don't use y and z.
int32_t const gridBlockIdx = blockIdx.x;
int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x;
int32_t const numBlocks = gridDim.x;
int32_t const numThreadsPerGrid = numBlocks * NumThreads;
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
auto expandedIdxSize = params.mNumTokens * params.mTopK;
// pre-fill the counts with 0
smemExpertCount[threadIdx.x] = 0;
__syncthreads();
// then wait on primary grid
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
// each thread keeps has some number of "expanded indexes" assigned to it
// for each of these, we keep the associated expert and offset within expert in registers
int32_t expertIndexes[MaxExpandedIdxPerThread];
int32_t expertOffsets[MaxExpandedIdxPerThread];
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
// time, and branch between a fast path without bound checks and a slow path with bound checks.
int constexpr IterStride = 4;
static_assert(MaxExpandedIdxPerThread % IterStride == 0);
// Define a lambda to avoid code duplication in both branches.
auto loopBody = [&](int ii, int expandedIdx)
{
int32_t expertIdx = params.mPtrExpertIdx[expandedIdx].idx;
expertIndexes[ii] = expertIdx;
// check whether this expert is local to our GPU at all and ignore if not
auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0;
};
#pragma unroll
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
{
// Whether it's safe to do multiple iterations without bound checks.
bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize;
if (takeFastPath)
{
#pragma unroll
for (int32_t jj = 0; jj < IterStride; jj++)
{
int const ii = ii0 + jj;
auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
loopBody(ii, expandedIdx);
}
}
else
{
bool doBreak = false;
#pragma unroll
for (int32_t jj = 0; jj < IterStride; jj++)
{
int const ii = ii0 + jj;
auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
if (expandedIdx >= expandedIdxSize)
{
doBreak = true;
break;
}
loopBody(ii, expandedIdx);
}
if (doBreak)
{
break;
}
}
}
// Make histogram (token counts per expert) available to all threads in the block.
__syncthreads();
//
// Each thread now represents one expert
//
// Add the local bin count to the common bin count and get a per-CTA offset.
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
int32_t blockExpertOffset = 0;
if (threadIdx.x < params.mNumExperts)
{
blockExpertOffset = atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
}
// Sync to wait for completion of the histogram reduction.
grid.sync();
// Get total count for this expert.
int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
// Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency.
// Compute the runtime config for projections
// Whether or not an expert is local is taken into account when smemExpertCount is computed
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
int32_t ctaOffset;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks)
{
const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
}
// get the padded offset associated with this expert
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
// write out padded count
if (gridBlockIdx == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
{
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
}
// write expert offsets to shared
smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
// make expert offsets available to all threads
__syncthreads();
// trigger the secondary kernel when using PDL
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
// TODO: this is not sufficient to ensure visibility in the next kernel!
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
// each thread has the same "expanded indexes" assigned to it as above
// at this point, we know the final offsets of experts and the offsets within
// experts, which allows writing the final index values
#pragma unroll
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
{
auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
if (expandedIdx >= expandedIdxSize)
{
break;
}
auto expertIdx = expertIndexes[ii];
// check whether this expert is local to our GPU at all
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
auto tokenIdx = expandedIdx / params.mTopK;
auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
{
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
}
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
{
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
}
}
#else
__global__ void routingIndicesCoopKernel(KernelParams params)
{
assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures");
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
void run(Data& data, void* stream)
{
TLLM_CHECK_WITH_INFO(
data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr || data.mPtrExpertWeights != nullptr,
"Routing kernel requires at least one output parameter");
if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr && data.mPtrPermutedIdxSize,
"If permuted index is required, `mPtrExpertIdx` is also required");
TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d",
MaxNumTopGroups, data.mNumLimitedGroups);
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxNumTopExperts, data.mTopK);
TLLM_CHECK_WITH_INFO(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK);
TLLM_CHECK_WITH_INFO(data.mTopK * data.mNumLimitedGroups <= WarpSize,
"Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK,
data.mNumLimitedGroups);
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects %d to be at most #experts %d",
MaxNumTopExperts, data.mNumExperts);
TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumThreads, "Routing kernel expects #experts %d <= #threads %d",
data.mNumExperts, NumThreads);
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
"Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
data.mNumExpertGroups);
if (data.mNumExpertGroups > 1)
{
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= NumWarps,
"Routing kernel expects #experts groups %d to be <= #warps %d", data.mNumExpertGroups, NumWarps);
TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
"Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
data.mNumExpertGroups);
TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
"Routing kernel expects #experts per group <= warp size, got %d", data.mNumExperts / data.mNumExpertGroups);
}
else
{
TLLM_CHECK_WITH_INFO(data.mNumExperts <= WarpSize * MaxNumTopGroups,
"Routing kernel expects #experts %d <= WarpSize * MaxNumTopGroups %d", data.mNumExperts,
WarpSize * MaxNumTopGroups);
TLLM_CHECK_WITH_INFO(
data.mTopK <= NumWarps, "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, NumWarps);
}
TLLM_CHECK_WITH_INFO(
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
int const numBlocks = data.mNumTokens;
bool const useSingleCluster = data.mNumTokens <= 1024;
if (!useSingleCluster)
{
// Reset the global histograms (not used in single-cluster code path).
// Cover both for the cooperative and two-kernel code paths.
TLLM_CHECK_WITH_INFO(
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
}
else
{
data.mPtrExpertCounts = nullptr; // Set it to nullptr for single-cluster code path, as it won't be used
}
// Number of blocks we can use in the cooperative kernel
// The number of blocks must be:
// >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉
// <= numSms, assuming an occupancy of 1 block/SM
//
// If too small for the given numTokens, fall back to the less performant two-step method.
//
// The upper bound is a strict requirement. The number of blocks should be determined by querying
// the device properties, or conservatively low.
// /!\ The following number is not portable!! (but works on H100 and B200)
int const numBlocksCoop = 128;
// Maximum number of tokens supported by the kernel using a cooperative launch.
int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK;
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
if (data.mPtrPermutedIdxSize != nullptr)
{
if (useSingleCluster)
{
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
}
else if (data.mNumTokens <= maxTokensCoop)
{
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
}
else
{
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
const int32_t histogramEltsPerBlock = 8 * NumThreads;
const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreads;
// Limit grid size (both kernels use a grid-stride loop).
const int32_t maxNumBlocks = 1024;
int const numBlocksHistogram
= std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
int const numBlocksOffsets
= std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace routingDeepSeek
} // namespace moe::dev::routing

View File

@ -0,0 +1,779 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "DevKernel.h"
#include "RoutingKernel.h"
#include "RoutingKernelTopK.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cute/arch/cluster_sm90.hpp>
#include <cutlass/arch/arch.h>
#include <type_traits>
#include "tensorrt_llm/kernels/archCondition.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace moe::dev
{
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace routing
{
namespace cg = cooperative_groups;
////////////////////////////////////////////////////////////////////////////////////////////////////
static constexpr int WarpSize = 32;
static constexpr int NumBlocksPerCluster = 8;
// Performance tuning knob.
static constexpr int NumEltsPerOffsetTilePerThread = 8;
static constexpr int NumThreadsHist = 256;
static constexpr int NumWarpsHist = NumThreadsHist / WarpSize;
////////////////////////////////////////////////////////////////////////////////////////////////////
static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
////////////////////////////////////////////////////////////////////////////////////////////////////
static __device__ inline float sigmoid_accurate(float x)
{
return 0.5f * tanhf(0.5f * x) + 0.5f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ __device__ constexpr T mulLog2(T a, T bLog2)
{
return a << bLog2;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ __device__ constexpr T divUpLog2(T a, T bLog2)
{
return ((a + (1 << bLog2) - 1) >> bLog2);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2)
{
return mulLog2<T>(divUpLog2<T>(a, bLog2), bLog2);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
__host__ __device__ constexpr int32_t getBits(int32_t value, int idx)
{
int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000;
return (value & mask) >> (idx * 8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool IsZero = false>
__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx)
{
if constexpr (!IsZero)
{
int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;
value &= mask;
}
value |= (newBits << (idx * 8));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename DataType>
__device__ void initArr(int startIdx, int numElts, int stride, DataType* arr, DataType value)
{
if (arr != nullptr)
{
for (int i = startIdx; i < numElts; i += stride)
{
arr[i] = value;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WarpSize> const& warp, DataType (&scores)[VecSize])
{
DataType maxScore = DataType{-INFINITY};
DataType sumScore = DataType{0.f};
// Get the max score for each token
for (int i = 0; i < VecSize; ++i)
{
maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
}
maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
// Get the summation of scores for each token
#pragma unroll
for (int i = 0; i < VecSize; ++i)
{
scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
sumScore += scores[i];
}
sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());
// Normalize the scores
#pragma unroll
for (int i = 0; i < VecSize; ++i)
{
scores[i] = static_cast<DataType>(scores[i] / sumScore);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename DataType>
__device__ DataType calcSoftmax(
cg::thread_block_tile<WarpSize> const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts)
{
DataType maxScore = DataType{-INFINITY};
if (laneIdx < NumTopExperts)
{
maxScore = score >= maxScore ? score : maxScore;
}
maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
float sumScore = float{0.f};
float newScore;
// Get the summation of scores for each token
if (laneIdx < NumTopExperts)
{
newScore = static_cast<float>(score) - static_cast<float>(maxScore);
newScore = static_cast<float>(exp(newScore));
sumScore += newScore;
}
sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
if (laneIdx < NumTopExperts)
{
score = static_cast<DataType>(newScore / sumScore);
}
return score;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams, typename BaseType, int NumThreads, int NumWarps, int MaxNumTopExperts,
bool LoadExpertIdxFromGlobal = false>
__device__ void routingPermutation(KernelParams params, PackedScoreIdx<BaseType>* smemPackedScoreIdx,
int32_t const warpIdx, uint32_t const clusterBlockRank)
{
using OutputT = typename KernelParams::OutputT;
using TypePacked = PackedScoreIdx<BaseType>;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
// Number of threads in the cluster.
static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster;
// same as max num tokens
static constexpr int MaxExpandedIdxPerThread
= (MaxNumTokensSingleCluster * MaxNumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster;
// Needed for the exclusive sum of token offsets.
// Note: the scan might include more bins than needed, with bin counts of 0 to pad
using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
__shared__ typename Scan::TempStorage tempStorage;
uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;
auto expandedIdxSize = params.mNumTokens * params.mTopK;
// number of experts is bounded by number of threads
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
// pre-fill the counts with 0
if (threadIdx.x < params.mNumExperts)
{
smemExpertCount[threadIdx.x] = 0;
}
__syncthreads();
// each thread keeps some number of "expanded indexes" assigned to it
// note that expanded indexes simply represent tokens here.
// for each of these, we keep the associated expert and offset within expert in registers
int32_t expertIndexes[MaxExpandedIdxPerThread];
int32_t expertOffsets[MaxExpandedIdxPerThread];
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
// time, and branch between a fast path without bound checks and a slow path with bound checks.
// TODO(mjoux): potentially add this back for perf tuning
// int constexpr IterStride = 4;
// static_assert(MaxExpandedIdxPerThread % IterStride == 0);
// Define a lambda to avoid code duplication in both branches.
auto loopBody = [&](int ii, int expandedIdx)
{
TypePacked scoreIdx;
if constexpr (LoadExpertIdxFromGlobal)
{
scoreIdx = TypePacked{static_cast<BaseType>(params.mPtrExpertIdx[expandedIdx].score),
static_cast<int16_t>(params.mPtrExpertIdx[expandedIdx].idx)};
}
else
{
TypePacked const* remoteSmem
= cg::cluster_group::map_shared_rank(smemPackedScoreIdx, expandedIdx / (NumWarps * params.mTopK));
scoreIdx = remoteSmem[expandedIdx % (NumWarps * params.mTopK)];
}
expertIndexes[ii] = scoreIdx.idx;
// check whether this expert is local to our GPU at all and ignore if not
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
if (params.mPtrExpertWeights != nullptr)
{
params.mPtrExpertWeights[expandedIdx] = OutputT{scoreIdx.score};
}
};
int constexpr IterStride = 4;
#pragma unroll
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
{
// Whether it's safe to do multiple iterations without bound checks.
bool const takeFastPath = (ii0 + IterStride) * NumThreadsPerCluster <= expandedIdxSize;
if (takeFastPath)
{
#pragma unroll
for (int32_t jj = 0; jj < IterStride; jj++)
{
int const ii = ii0 + jj;
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
loopBody(ii, expandedIdx);
}
}
else
{
bool doBreak = false;
#pragma unroll
for (int32_t jj = 0; jj < IterStride; jj++)
{
int const ii = ii0 + jj;
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
if (expandedIdx >= expandedIdxSize)
{
doBreak = true;
break;
}
loopBody(ii, expandedIdx);
}
if (doBreak)
{
break;
}
}
}
// Make local histogram (token counts per expert) available to all threads in the cluster.
__cluster_barrier_arrive();
__cluster_barrier_wait();
//
// Each thread now represents one expert
//
// Total number of tokens for this expert.
int32_t count = 0;
// Per-expert offset for this block.
int32_t blockExpertOffset = 0;
if (threadIdx.x < params.mNumExperts)
{
// Get the histogram bin from each rank for this expert.
int32_t expertCounts[NumBlocksPerCluster];
#pragma unroll
for (int rank = 0; rank < NumBlocksPerCluster; rank++)
{
int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0;
}
// Compute an exclusive prefix sum of the block-local count.
#pragma unroll
for (int rank = 0; rank < NumBlocksPerCluster; rank++)
{
if (rank == clusterBlockRank)
{
blockExpertOffset = count;
}
count += expertCounts[rank];
}
}
// Arrive: we do not access distributed shared memory after this point.
__cluster_barrier_arrive();
// Compute the runtime config for projections
// Whether or not an expert is local is taken into account when smemExpertCount is computed
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
int32_t ctaOffset;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
if (threadIdx.x < params.mNumExperts)
{
// Strided loop to share this work between blocks.
for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster)
{
const int32_t localExpertIdx
= (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
= min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
}
// get the padded offset associated with this expert
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
// write expert offsets to shared
smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
}
// write out padded count
if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
{
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
}
// make expert offsets available to all threads
__syncthreads();
// Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
// Note: I observed a perf benefit to doing this before the final loop so the compiler can
// implement break with EXIT.
__cluster_barrier_wait();
// trigger the secondary kernel when using PDL
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
// TODO: this is not sufficient to ensure visibility in the next kernel!
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// each thread has the same "expanded indexes" assigned to it as above
// at this point, we know the final offsets of experts and the offsets within
// experts, which allows writing the final index values
#pragma unroll
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
{
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
if (expandedIdx >= expandedIdxSize)
{
break;
}
auto expertIdx = expertIndexes[ii];
// check whether this expert is local to our GPU at all
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
auto tokenIdx = expandedIdx / params.mTopK;
auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
{
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
}
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
{
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Two-step approach (if number of tokens exceed limits of what cluster / cooperative launch
// variants can handle): in order to minimize the amount of data to exchange through global memory,
// we will compute the local histograms in smem twice: the first kernel will get us the total number
// of tokens per expert. The second kernel will use the smem and L2 atomics to get corresponding
// element and tile offsets.
//
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
// inefficient if we have one CTA per token doing a single global atomic.
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(KernelParams params)
{
using OutputT = typename KernelParams::OutputT;
// number of experts is bounded by number of threads
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
// For unrolling.
uint32_t constexpr NumEltsPerThread = 8;
// Pre-fill the counts with 0
if (threadIdx.x < params.mNumExperts)
{
smemExpertCount[threadIdx.x] = 0;
}
__syncthreads();
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Wait on primary grid and trigger secondary kernel.
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
uint32_t const gridBlockOffset = blockIdx.x * NumThreadsHist;
uint32_t const gridStride = gridDim.x * NumThreadsHist;
// Define a lambda to avoid code duplication in branches.
auto loopBody = [&](int expandedIdx)
{
PackedScoreIdx<OutputT> scoreIdx = params.mPtrExpertIdx[expandedIdx];
// check whether this expert is local to our GPU at all and ignore if not
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
if (isLocalExpert)
{
atomicAdd(&smemExpertCount[scoreIdx.idx], 1);
}
if (params.mPtrExpertWeights != nullptr)
{
params.mPtrExpertWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
}
};
// Grid-stride loop.
for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
expandedIdx0 += gridStride * NumEltsPerThread)
{
// Fast path if bound checks aren't necessary
if (expandedIdx0 + NumEltsPerThread * NumThreadsHist <= expandedIdxSize)
{
#pragma unroll
for (uint32_t ii = 0; ii < NumEltsPerThread; ii++)
{
uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsHist + threadIdx.x;
loopBody(expandedIdx);
}
}
else
{
for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
expandedIdx += NumThreadsHist)
{
loopBody(expandedIdx);
}
}
}
__syncthreads();
//
// Each thread now represents one expert
//
// Reduce histograms with atomics.
if (threadIdx.x < params.mNumExperts)
{
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params)
{
using OutputT = typename KernelParams::OutputT;
// number of experts is bounded by number of threads
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreadsHist];
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
__shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[NumThreadsHist];
// needed for the exclusive sum of token offsets
using Scan = cub::BlockScan<int32_t, NumThreadsHist, cub::BLOCK_SCAN_WARP_SCANS>;
__shared__ typename Scan::TempStorage tempStorage;
static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
static constexpr int MaxExpandedIdxPerBlock = NumThreadsHist * MaxExpandedIdxPerThread;
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Wait on primary grid.
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// The expert offsets are common to all tiles of all blocks.
// Load the histogram, scan it and write offsets to shared memory.
// Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
// the scan, with PDL?
//
// Each thread represents one expert.
//
// Get total count for this expert.
int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
// Compute the runtime config for projections
// Whether or not an expert is local is taken into account when the histogram is computed
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
int32_t ctaOffset;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
if (threadIdx.x < params.mNumExperts)
{
// Get the padded offset associated with this expert
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
// Write expert offsets to shared
smemExpertOffset[threadIdx.x] = offset;
}
// Sync to make expert offsets available to all threads.
__syncthreads();
// The first block writes out padded count
if (blockIdx.x == 0 && warpIdx == NumWarpsHist - 1 && cute::elect_one_sync())
{
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
}
if (threadIdx.x < params.mNumExperts)
{
// Strided loop to share this work between blocks.
for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x)
{
const int32_t localExpertIdx
= (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
= min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
}
}
//
// Now loop on indices and compute offsets.
//
// Grid-stride loop on 1D "tiles" of input indices.
for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x)
{
if (tileIdx > 0)
{
// Sync for safe reuse of smem buffers.
__syncthreads();
}
// Pre-fill the counts with 0
if (threadIdx.x < params.mNumExperts)
{
smemExpertCount[threadIdx.x] = 0;
}
__syncthreads();
// each thread keeps has some number of "expanded indexes" assigned to it
// for each of these, we keep the associated expert and offset within expert in registers
int32_t expertIndexes[MaxExpandedIdxPerThread];
int32_t expertOffsets[MaxExpandedIdxPerThread];
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
// Define a lambda to avoid code duplication in branches.
auto loopBody = [&](int ii, int expandedIdx)
{
PackedScoreIdx<OutputT> scoreIdx = params.mPtrExpertIdx[expandedIdx];
expertIndexes[ii] = scoreIdx.idx;
// check whether this expert is local to our GPU at all and ignore if not
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
};
// For all tiles but the last, all indices are in bounds.
if (tileIdx < numTiles - 1)
{
#pragma unroll
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
{
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
loopBody(ii, expandedIdx);
}
}
else
{
// For the last tile, we need to exit the loop when out of bounds.
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
// time, and branch between a fast path without bound checks and a slow path with bound checks
int constexpr IterStride = 4;
static_assert(MaxExpandedIdxPerThread % IterStride == 0);
#pragma unroll
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
{
// Whether it's safe to do multiple iterations without bound checks.
bool const takeFastPath
= tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsHist <= expandedIdxSize;
if (takeFastPath)
{
#pragma unroll
for (int32_t jj = 0; jj < IterStride; jj++)
{
int const ii = ii0 + jj;
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
loopBody(ii, expandedIdx);
}
}
else
{
bool doBreak = false;
#pragma unroll
for (int32_t jj = 0; jj < IterStride; jj++)
{
int const ii = ii0 + jj;
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
if (expandedIdx >= expandedIdxSize)
{
doBreak = true;
break;
}
loopBody(ii, expandedIdx);
}
if (doBreak)
{
break;
}
}
}
}
// Make local histogram (token counts per expert) available to all threads in the block.
__syncthreads();
//
// Each thread now represents one expert
//
if (threadIdx.x < params.mNumExperts)
{
// Add the local bin count to the common bin count and get a per-CTA offset. We use the second
// half of the histogram buffer for this histogram, because the first half already holds the
// reduced histogram from the previous kernel.
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
int32_t const tileExpertOffset
= atomicAdd(&params.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount);
// Make per-expert tile offsets available to all threads in the block.
smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
}
__syncthreads();
// Add tile offset and element offset and write to global memory.
auto storeLoopBody = [&](int ii, int expandedIdx)
{
int32_t expertIdx = expertIndexes[ii];
// check whether this expert is local to our GPU at all
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
auto tokenIdx = expandedIdx / params.mTopK;
auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
{
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
}
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
{
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
};
// Bound checks only in last tile.
if (tileIdx < numTiles - 1)
{
#pragma unroll
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
{
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
storeLoopBody(ii, expandedIdx);
}
}
else
{
#pragma unroll
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
{
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
if (expandedIdx >= expandedIdxSize)
{
break;
}
storeLoopBody(ii, expandedIdx);
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Trigger secondary kernel.
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
// dependency sync.
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
}
} // namespace routing
} // namespace moe::dev

View File

@ -31,33 +31,21 @@ namespace routing
namespace tg = batchedGemm::trtllm::gen;
template <typename TypeExpW>
template <typename DataType>
struct PackedScoreIdx
{
TypeExpW score;
int16_t idx; // @TODO: Might use int8_t as the number of experts is 128
DataType score;
int16_t idx;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace routingDeepSeek
struct DataBase
{
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Data
{
tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
bool mUsePdl{false};
// note: at least one of the optional outputs below must be provided
// note: if one of the indexes using "PermutedIdx" is provided,
// then `mPtrExpertIdx` and `mPtrPermutedIdxSize` must be provided
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens, mTopK]
int32_t* mPtrExpertIdx{nullptr};
// optional: only used as an intermediate buffer when the number of tokens is large.
// dim: [2*NumThreads] = [512]
// dim: max([2*NumThreads] = [512], mNumExperts*2)
int32_t* mPtrExpertCounts{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [1]
@ -67,13 +55,25 @@ struct Data
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts]
// Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized
// Any out-of-bounds values are undefined.
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumLocalExperts * (2 ^ mLocalExpertsStrideLog2), mNumTokens]
void* mPtrExpertWeightsFull{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens, mTopK]
void* mPtrExpertWeights{nullptr};
// optional: if `nullptr`, scores are used directly as input.
// If it is given, it must represent a packed value s.t. the most significant
// 16/32 bits represent the score without sigmoid activation and
// the least significant 16 bits represent the index of the chosen expert (unsigned).
// note: this is required if the number of tokens is large.
// dim: [mNumTokens, mTopK]
void* mPtrExpertIdx{nullptr};
// optional: if `nullptr`, `mPtrExpertIdx` must be provided.
// If it is given, it represents the scores without sigmoid activation for
// each token and expert.
// note: if it is provided, we always re-compute the top1 scores
// dim: [mNumTokens, mNumExperts]
void const* mPtrScores{nullptr};
//
// Grouped Gemm Launch Config Buffers
@ -81,110 +81,137 @@ struct Data
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
int32_t* mPtrNumNonExitingCtas{nullptr};
// mPtrPermutedIdxSize is ptrTotalNumPaddedTokens
bool mAllToAllRouteAct{false};
void const* mPtrRoutingWeights;
void const* mPtrRoutingBias;
void const* mPtrIn;
float* mPtrScores;
//
// Metadata
//
int32_t mNumTokens;
int32_t mHiddenDim;
int32_t mNumExperts;
int32_t mNumExpertGroups;
int32_t mNumLimitedGroups;
int32_t mTopK;
int32_t mPaddingLog2;
/// For expert parallelization
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;
float mRouteScale;
bool mUseRoutingSoftmax;
int32_t* mPtrNumTokensPerExpert{nullptr};
int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
};
template <typename TypeExpW_, bool UseGroups_, bool UsePdl_>
struct KernelParams
template <typename InputT_, typename OutputT_, bool UsePdl_>
struct KernelParamsBase
{
using TypeExpW = TypeExpW_;
static constexpr bool UseGroups = UseGroups_;
using InputT = InputT_;
using OutputT = OutputT_;
static constexpr bool UsePdl = UsePdl_;
int32_t* mPtrExpertIdx;
int32_t* mPtrExpertCounts;
int32_t* mPtrPermutedIdxSize;
int32_t* mPtrExpandedIdxToPermutedIdx;
int32_t* mPtrPermutedIdxToTokenIdx;
int32_t* mPtrPermutedIdxToExpandedIdx;
int32_t* mPtrNumTokensPerExpert;
// Public pointer members
int32_t* mPtrExpertCounts = nullptr;
int32_t* mPtrPermutedIdxSize = nullptr;
int32_t* mPtrExpandedIdxToPermutedIdx = nullptr;
int32_t* mPtrPermutedIdxToTokenIdx = nullptr;
int32_t* mPtrCtaIdxXyToBatchIdx = nullptr;
int32_t* mPtrCtaIdxXyToMnLimit = nullptr;
int32_t* mPtrNumNonExitingCtas = nullptr;
OutputT* mPtrExpertWeights = nullptr;
InputT const* mPtrScores = nullptr;
int32_t* mPtrCtaIdxXyToBatchIdx;
int32_t* mPtrCtaIdxXyToMnLimit;
int32_t* mPtrNumNonExitingCtas;
// Public scalar members
int32_t mNumTokens = 0;
int32_t mNumExperts = 0;
TypeExpW* mPtrExpertWeightsFull;
TypeExpW* mPtrExpertWeights;
TypeExpW const* mPtrRoutingWeights;
TypeExpW const* mPtrRoutingBias;
float const* mPtrScores;
int32_t mPaddingLog2 = 0;
int32_t mLocalExpertsStartIdx = 0;
int32_t mLocalExpertsStrideLog2 = 0;
int32_t mNumLocalExperts = 0;
int32_t mHiddenDim;
int32_t mNumExperts;
// Public initialization function - make it a template to accept different Data types
template <typename DataType>
void setBaseParams(DataType const& data)
{
mPtrExpertCounts = data.mPtrExpertCounts;
mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
mPtrExpertWeights = static_cast<OutputT*>(data.mPtrExpertWeights);
mPtrScores = (InputT const*) data.mPtrScores;
mNumTokens = data.mNumTokens;
mNumExperts = data.mNumExperts;
mPaddingLog2 = data.mPaddingLog2;
mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
mNumLocalExperts = data.mNumLocalExperts;
}
};
namespace routingDeepSeek
{
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Data : public DataBase
{
tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
//
// Grouped Gemm Launch Config Buffers
//
void const* mPtrRoutingBias;
int32_t mHiddenDim; // not used
int32_t mNumExpertGroups;
int32_t mNumExpertsPerGroup;
int32_t mNumLimitedGroups;
trtllm::dev::IntFastDiv mTopK;
int32_t mPaddingLog2;
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;
int32_t mNumTokens;
float mRouteScale;
bool mAllToAllRouteAct;
bool mUseRoutingSoftmax;
};
template <typename InputT_, typename OutputT_, bool UseGroups_, bool UsePdl_>
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, UsePdl_>
{
using InputT = InputT_;
using OutputT = OutputT_;
static constexpr bool UseGroups = UseGroups_;
PackedScoreIdx<OutputT>* mPtrExpertIdx = nullptr;
// OutputT* mPtrExpertWeightsFull = nullptr;
// Note: this variable(mPtrExpertWeightsFull) might need to be added back for the low-latency kernels for MoE in
// tllm-gen in the future
OutputT const* mPtrRoutingBias = nullptr;
int32_t mNumExpertGroups = 0;
int32_t mNumExpertsPerGroup = 0;
int32_t mNumLimitedGroups = 0;
trtllm::dev::IntFastDiv mTopK;
float mRouteScale = 0.f;
static KernelParams setKernelParams(Data const& data)
{
KernelParams params;
params.setBaseParams(data);
params.mPtrExpertIdx = data.mPtrExpertIdx;
params.mPtrExpertCounts = data.mPtrExpertCounts;
params.mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
params.mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
params.mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
params.mPtrPermutedIdxToExpandedIdx = data.mPtrPermutedIdxToExpandedIdx;
params.mPtrNumTokensPerExpert = data.mPtrNumTokensPerExpert;
params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*) data.mPtrExpertIdx;
params.mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
params.mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
params.mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
// params.mPtrExpertWeightsFull = static_cast<OutputT*>(data.mPtrExpertWeightsFull);
params.mPtrRoutingBias = static_cast<OutputT const*>(data.mPtrRoutingBias);
params.mPtrExpertWeightsFull = (TypeExpW*) data.mPtrExpertWeightsFull;
params.mPtrExpertWeights = (TypeExpW*) data.mPtrExpertWeights;
params.mPtrRoutingWeights = (TypeExpW*) data.mPtrRoutingWeights;
params.mPtrRoutingBias = (TypeExpW*) data.mPtrRoutingBias;
params.mPtrScores = data.mPtrScores;
params.mHiddenDim = data.mHiddenDim;
params.mNumExperts = data.mNumExperts;
params.mNumExpertGroups = data.mNumExpertGroups;
params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups;
params.mNumLimitedGroups = data.mNumLimitedGroups;
params.mTopK = trtllm::dev::IntFastDiv(data.mTopK);
params.mPaddingLog2 = data.mPaddingLog2;
params.mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
params.mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
params.mNumLocalExperts = data.mNumLocalExperts;
params.mNumTokens = data.mNumTokens;
params.mRouteScale = data.mRouteScale;
params.mAllToAllRouteAct = data.mAllToAllRouteAct;
return params;
}
};
void run(Data const& data, void* stream);
void run(Data& data, void* stream);
} // namespace routingDeepSeek
@ -195,102 +222,28 @@ namespace routingLlama4
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Data
struct Data : public DataBase
{
tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
bool mUsePdl{false};
// optional: if `nullptr`, `mPtrExpertIdx` must be provided.
// If it is given, it represents the scores without sigmoid activation for
// each token and expert.
// note: if it is provided, we always re-compute the top1 scores
// dim: [mNumTokens, mNumExperts]
void const* mPtrScores{nullptr};
// optional: if `nullptr`, scores are used directly as input.
// If it is given, it must represent a packed value s.t. the most significant
// 16/32 bits represent the score without sigmoid activation and
// the least significant 16 bits represent the index of the chosen expert (unsigned).
// note: this is required if the number of tokens is large.
// dim: [mNumTokens, mTopK]
void* mPtrExpertIdx{nullptr};
// note: at least one of the optional outputs below must be provided
// optional: only used as an intermediate buffer when the number of tokens is large.
// dim: [2, mNumExperts]
int32_t* mPtrExpertCounts{nullptr};
// dim: [1]
int32_t* mPtrPermutedIdxSize{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens * mTopK]
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts]
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens, mTopK]
void* mPtrExpertWeights{nullptr};
//
// Grouped Gemm Launch Config Buffers
//
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
int32_t* mPtrNumNonExitingCtas{nullptr};
int32_t mNumTokens;
int32_t mNumExperts;
int32_t mTopK;
int32_t mPaddingLog2;
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;
};
template <typename TypeExpW_, bool UsePdl_>
struct KernelParams
template <typename InputT_, typename OutputT_, bool UsePdl_>
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, UsePdl_>
{
using TypeExpW = TypeExpW_;
static constexpr bool UsePdl = UsePdl_;
using InputT = InputT_;
using OutputT = OutputT_;
PackedScoreIdx<TypeExpW>* mPtrExpertIdx;
TypeExpW const* mPtrScores;
int32_t* mPtrExpertCounts;
int32_t* mPtrPermutedIdxSize;
int32_t* mPtrExpandedIdxToPermutedIdx;
int32_t* mPtrPermutedIdxToTokenIdx;
int32_t* mPtrCtaIdxXyToBatchIdx;
int32_t* mPtrCtaIdxXyToMnLimit;
int32_t* mPtrNumNonExitingCtas;
TypeExpW* mPtrExpertWeights;
PackedScoreIdx<OutputT>* mPtrExpertIdx = nullptr;
int32_t mNumTokens;
int32_t mNumExperts;
int32_t mPaddingLog2;
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;
int32_t mTopK;
static KernelParams setKernelParams(Data const& data)
{
KernelParams params;
params.setBaseParams(data);
params.mPtrExpertIdx = (PackedScoreIdx<TypeExpW>*) data.mPtrExpertIdx;
params.mPtrScores = (TypeExpW const*) data.mPtrScores;
params.mPtrExpertCounts = data.mPtrExpertCounts;
params.mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
params.mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
params.mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
params.mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
params.mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
params.mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
params.mPtrExpertWeights = (TypeExpW*) data.mPtrExpertWeights;
params.mNumTokens = data.mNumTokens;
params.mNumExperts = data.mNumExperts;
params.mPaddingLog2 = data.mPaddingLog2;
params.mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
params.mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
params.mNumLocalExperts = data.mNumLocalExperts;
params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*) data.mPtrExpertIdx;
params.mTopK = data.mTopK;
return params;
}
};
@ -306,105 +259,37 @@ namespace routingRenormalize
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Data
struct Data : public DataBase
{
tg::Dtype mDtypeExpW{tg::Dtype::Fp32};
tg::Dtype mDtypeElt{tg::Dtype::Bfloat16};
bool mUsePdl{false};
bool mDoSoftmaxBeforeTopK{false};
bool mNormTopkProb{true}; // Default value is true for Qwen3 model
// optional: if `nullptr`, `mPtrExpertIdx` must be provided.
// If it is given, it represents the scores without sigmoid activation for
// each token and expert.
// note: if it is provided, we always re-compute the top1 scores
// dim: [mNumTokens, mNumExperts]
void const* mPtrScores{nullptr};
// optional: if `nullptr`, scores are used directly as input.
// If it is given, it must represent a packed value s.t. the most significant
// 16/32 bits represent the score without sigmoid activation and
// the least significant 16 bits represent the index of the chosen expert (unsigned).
// note: this is required if the number of tokens is large.
// dim: [mNumTokens, mTopK]
void* mPtrExpertIdx{nullptr};
// note: at least one of the optional outputs below must be provided
// optional: only used as an intermediate buffer when the number of tokens is large.
// dim: [2, mNumExperts]
int32_t* mPtrExpertCounts{nullptr};
// dim: [1]
int32_t* mPtrPermutedIdxSize{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens * mTopK]
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts]
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
// optional: if `nullptr`, it is not filled
// dim: [mNumTokens, mTopK]
void* mPtrExpertWeights{nullptr};
//
// Grouped Gemm Launch Config Buffers
//
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
int32_t* mPtrNumNonExitingCtas{nullptr};
int32_t mNumTokens;
int32_t mNumExperts;
int32_t mTopK;
int32_t mPaddingLog2;
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;
};
template <typename Type_, typename TypeExpW_, bool UsePdl_>
struct KernelParams
template <typename InputT_, typename OutputT_, bool DoSoftmaxBeforeTopK_, bool UsePdl_>
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, UsePdl_>
{
using Type = Type_;
using TypeExpW = TypeExpW_;
static constexpr bool UsePdl = UsePdl_;
bool mNormTopkProb = true;
PackedScoreIdx<TypeExpW>* mPtrExpertIdx;
TypeExpW const* mPtrScores;
int32_t* mPtrExpertCounts;
int32_t* mPtrPermutedIdxSize;
int32_t* mPtrExpandedIdxToPermutedIdx;
int32_t* mPtrPermutedIdxToTokenIdx;
int32_t* mPtrCtaIdxXyToBatchIdx;
int32_t* mPtrCtaIdxXyToMnLimit;
int32_t* mPtrNumNonExitingCtas;
TypeExpW* mPtrExpertWeights;
using InputT = InputT_;
using OutputT = OutputT_;
int32_t mNumTokens;
int32_t mNumExperts;
int32_t mPaddingLog2;
int32_t mLocalExpertsStartIdx;
int32_t mLocalExpertsStrideLog2;
int32_t mNumLocalExperts;
static constexpr bool DoSoftmaxBeforeTopK = DoSoftmaxBeforeTopK_;
PackedScoreIdx<OutputT>* mPtrExpertIdx = nullptr;
int32_t mTopK = 0;
bool mNormTopkProb = true;
static KernelParams setKernelParams(Data const& data)
{
KernelParams params;
params.setBaseParams(data);
params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*) data.mPtrExpertIdx;
params.mNormTopkProb = data.mNormTopkProb;
params.mPtrExpertIdx = (PackedScoreIdx<TypeExpW>*) data.mPtrExpertIdx;
params.mPtrScores = (TypeExpW const*) data.mPtrScores;
params.mPtrExpertCounts = data.mPtrExpertCounts;
params.mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
params.mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
params.mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
params.mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
params.mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
params.mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
params.mPtrExpertWeights = (TypeExpW*) data.mPtrExpertWeights;
params.mNumTokens = data.mNumTokens;
params.mNumExperts = data.mNumExperts;
params.mPaddingLog2 = data.mPaddingLog2;
params.mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
params.mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
params.mNumLocalExperts = data.mNumLocalExperts;
params.mTopK = data.mTopK;
return params;
}
};

View File

@ -162,7 +162,7 @@ struct Sort<4, RedType>
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue)
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < WarpSize, "Top K must have K < WarpSize");
@ -170,7 +170,7 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
RedType topK{value, idx};
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < K; ++kk)
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
{
topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK;
// get the next largest value
@ -181,7 +181,7 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue)
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < WarpSize, "Top K must have K < WarpSize");
@ -199,7 +199,7 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < K; ++kk)
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
{
bool update = kk > 0 && packedMax == topK[0].compVal;
#pragma unroll

View File

@ -0,0 +1,501 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RoutingKernel.cuh"
namespace moe::dev::routing
{
namespace routingLlama4
{
////////////////////////////////////////////////////////////////////////////////////////////////////
static constexpr int NumThreads = 1024;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int MaxNumTopExperts = 1;
static constexpr int MaxNumExperts = 128;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
static constexpr int WarpKernelSmemStride = 33;
// with further optimization to `routingIndicesWarpKernel`, this limit may
// increase. For now, it is a good cut-off point for when the block-wise
// operations are more efficient end-to-end.
static constexpr int WarpKernelMaxNumTokens = 4;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename DataType, int VecSize>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
DataType (&warpMaxScore)[MaxNumTopExperts], int32_t (&warpMaxExpertIdx)[MaxNumTopExperts], int32_t const laneIdx,
int32_t const numExperts, DataType const* ptrScores)
{
DataType minScore = DataType{-INFINITY};
DataType maxScore = minScore;
int32_t maxExpertIdx{-1};
using DataTypeVec = std::conditional_t<sizeof(DataType) == 2, float2, float4>;
// Non-vectorized loading: directly access ptrScores with expertIdx
for (int i = 0; i < VecSize; ++i)
{
auto expertIdx = i * WarpSize + laneIdx;
auto newScore = expertIdx < numExperts ? ptrScores[expertIdx] : minScore;
// note: use `>=` s.t. highest index always wins, just like in `reduceTopK`
if (newScore > maxScore)
{
maxScore = newScore;
maxExpertIdx = expertIdx;
}
}
topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams>
__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params)
{
// types used in this kernel
using OutputT = typename KernelParams::OutputT;
using InputT = typename KernelParams::InputT;
using TypePacked = PackedScoreIdx<OutputT>;
// use the default cub warp-scan, with shfl
using Scan = cub::WarpScan<int32_t>;
__shared__ typename Scan::TempStorage tempStorage;
// each thread encodes 4 experts in one `int32_t`. The assumption is that
// we don't have more than 127 tokens, but `WarpKernelMaxNumTokens` must be
// smaller than that because other approaches will be more efficient for
// 127 tokens.
static constexpr int ExpertsPerThread = sizeof(int32_t);
static_assert(WarpKernelMaxNumTokens <= 127);
// this is a full table of which token is routed to which expert.
// the assumption here is that there are no more than 128 experts.
// we use a stride of 33 instead of 32 to avoid shared memory bank conflicts.
__shared__ int32_t __attribute((aligned(128)))
smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride];
static_assert(WarpKernelSmemStride == WarpSize + 1);
static_assert(MaxNumExperts / sizeof(int32_t) <= WarpSize);
// values needed for the top-1 reduction, if required
InputT minScore = InputT{-INFINITY};
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WarpSize>(block);
#pragma unroll
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
{
// reset full shared memory field to 0
smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0;
}
__syncwarp();
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// then wait on primary grid
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
#endif
if (params.mPtrScores != nullptr)
{
// if we use `mPtrScores` as input, we need to perform the top-1 reduction
// for each token, we load the scores then use `reduceTopK` for this.
// each thread works on 4 experts, so a local reduction is done before
for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx)
{
auto scoreOffset = tokenIdx * params.mNumExperts;
int32_t warpMaxExpertIdx[MaxNumTopExperts];
InputT warpMaxScore[MaxNumTopExperts];
// Use routingTopKExperts function instead of inline logic
routingTopKExperts<InputT, ExpertsPerThread>(
warp, warpMaxScore, warpMaxExpertIdx, threadIdx.x, params.mNumExperts, params.mPtrScores + scoreOffset);
if (cute::elect_one_sync())
{
// one thread updates the count linking token to chosen expert
auto expertTokenCount = 0;
setBits</* IsZero= */ true>(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread);
smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] = expertTokenCount;
// we also compute the final score here and write it out if required
auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
if (params.mPtrExpertWeights != nullptr)
{
params.mPtrExpertWeights[tokenIdx] = finalScore;
}
}
}
}
else
{
// if we do not have `mPtrScores` as input, we expect that `mPtrExpertWeights`
// contains the top-1 packed score and index already.
// Each thread represents a token here, and we extract the relevant score
// The assumption is that the #tokens is limited by warp-size
static_assert(WarpKernelMaxNumTokens <= WarpSize);
TypePacked scoreIdx = threadIdx.x < params.mNumTokens ? params.mPtrExpertIdx[threadIdx.x] : TypePacked{};
int32_t expertTokenCount = 0;
setBits</* IsZero= */ true>(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread);
if (threadIdx.x < params.mNumTokens)
{
smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount;
}
// we also compute the final score here and write it out if required
auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})};
if (params.mPtrExpertWeights != nullptr && threadIdx.x < params.mNumTokens)
{
params.mPtrExpertWeights[threadIdx.x] = finalScore;
}
}
// make the full table available to all threads
__syncwarp();
// at this point, each thread keeps a count of its 4 assigned experts in
// `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts
// in `expertOffset`.
int32_t expertCount = 0;
int32_t expertOffset[WarpKernelMaxNumTokens + 1];
#pragma unroll
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx)
{
if (tokenIdx > params.mNumTokens)
break;
// simple reduction for `expertCount`, and scan for `expertOffset`
auto expertTokenCount = tokenIdx < params.mNumTokens ? smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0;
expertOffset[tokenIdx] = expertCount;
expertCount += expertTokenCount;
}
// at this point, we are ready for the scan across all experts to get the
// thread-wise offsets across experts
// first, we need to reduce across our 4 experts into `numCta`
int32_t numCta = 0;
#pragma unroll
for (int ii = 0; ii < ExpertsPerThread; ++ii)
{
auto count = getBits(expertCount, ii);
numCta += divUpLog2<int32_t>(count, params.mPaddingLog2);
}
// second, we perform the exclusive sum across the warp
int32_t ctaOffset;
int32_t numNonExitingCtas;
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
// finally, we perform a scan across our local experts, starting with the
// warp-wide scan result (`ctaOffset`)
auto ctaOffsetExp = ctaOffset;
#pragma unroll
for (int ii = 0; ii < ExpertsPerThread; ++ii)
{
auto count = getBits(expertCount, ii);
auto finalNumCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
// during the scan for expert offsets, we can already write out
// both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit`
for (int cta = 0; cta < finalNumCta; ++cta)
{
params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx;
params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta]
= min(mulLog2<int32_t>(ctaOffsetExp + cta + 1, params.mPaddingLog2),
mulLog2<int32_t>(ctaOffsetExp, params.mPaddingLog2) + count);
}
ctaOffsetExp += finalNumCta;
}
// at this point, we can write out padded count from the warp-aggregate
if (cute::elect_one_sync())
{
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
// we can trigger the next kernel at this point
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif
// at this point, all values for offsets are ready, except the final offsets
// within the padded index (`permutedIdx`)
// for this, we perform a scan similar to the one directly after the warp-scan:
// here, we keep the local offset for each of the thread's experts in a field
// of registers
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
int32_t finalExpertOffset[ExpertsPerThread];
finalExpertOffset[0] = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
#pragma unroll
for (int ii = 1; ii < ExpertsPerThread; ++ii)
{
finalExpertOffset[ii]
= finalExpertOffset[ii - 1] + divUpMulLog2<int32_t>(getBits(expertCount, ii - 1), params.mPaddingLog2);
}
#pragma unroll
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
{
// at this point, we can calculate the final index:
// we simply loop over all tokens, and all experts assigned to this thread.
// For each pair, we determine whether that token was routed to that expert
// based on whether the offset for that token changed.
// we can then easily compute the final `expertIdx` and `permutedIdx` relative
// to this token and expert, and write them out.
if (tokenIdx >= params.mNumTokens)
break;
#pragma unroll
for (int ii = 0; ii < ExpertsPerThread; ++ii)
{
// determine whether the offset for this expert and token changes
auto localOffsetToken = getBits(expertOffset[tokenIdx], ii);
auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken;
// the expert index of this expert
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
// the permuted index: we add the local offset relative to this expert and token
// to the global offset from the scan for this expert
auto permutedIdx = isLocalExpert ? finalExpertOffset[ii] + localOffsetToken : int32_t{-1};
// write out `mPtrExpandedIdxToPermutedIdx` if required
if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted)
{
params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
}
// write out `mPtrPermutedIdxToTokenIdx` if required
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted)
{
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
routingIndicesClusterKernel(KernelParams params)
{
// number of tokens/expanded idx is bounded by total number of warps
using OutputT = typename KernelParams::OutputT;
using InputT = typename KernelParams::InputT;
using TypePacked = PackedScoreIdx<OutputT>;
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps];
uint32_t const clusterBlockRank = blockIdx.x;
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
int32_t const laneIdx = cutlass::arch::LaneId();
// TODO(mjoux): expand to more tokens (possibly)
auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
auto scoreOffset = warpTokenIdx * params.mNumExperts;
bool validToken = warpTokenIdx < params.mNumTokens;
InputT minScore = InputT{-INFINITY};
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WarpSize>(block);
// then wait on primary grid
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
if (params.mPtrScores != nullptr)
{
// in this case, each warp represents a token
// we then exchange all token max scores, s.t. afterwards, each thread
// represents a token
InputT warpMaxScore[MaxNumTopExperts];
int32_t warpMaxExpertIdx[MaxNumTopExperts];
if (validToken)
{
routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, params.mPtrScores + scoreOffset);
if (cute::elect_one_sync())
{
auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
TypePacked packedScore{finalScore, static_cast<int16_t>(warpMaxExpertIdx[0])};
smemPackedScoreIdx[warpIdx] = packedScore;
}
}
// make packed scores available to all threads in cluster
__cluster_barrier_arrive();
__cluster_barrier_wait();
}
routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts,
/*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params)
{
assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// this kernel is needed in case we have scores as input for the histogram kernel
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
{
using OutputT = typename KernelParams::OutputT;
using InputT = typename KernelParams::InputT;
using TypePacked = PackedScoreIdx<OutputT>;
static constexpr int VecSize = MaxNumExperts / WarpSize;
// we assume that #experts is a multiple of 4, so VecSize must be 4.
static_assert(VecSize == 4);
int32_t const laneIdx = cutlass::arch::LaneId();
int32_t const warpIdx = threadIdx.x / WarpSize;
int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
InputT minScore = InputT{-INFINITY};
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WarpSize>(block);
// initialize the mPtrExpertCounts
int32_t expertCountsNum = 2 * params.mNumExperts;
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
int32_t globalThreadStride = gridDim.x * NumThreads;
initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Wait on primary grid and trigger secondary kernel.
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// in this case, each warp represents a token, and we use a grid-stride loop
// over all warps/tokens
for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
{
auto scoreOffset = tokenIdx * params.mNumExperts;
int32_t warpMaxExpertIdx[MaxNumTopExperts];
InputT warpMaxScore[MaxNumTopExperts];
routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, params.mPtrScores + scoreOffset);
if (cute::elect_one_sync())
{
auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
TypePacked packedScore{finalScore, static_cast<int16_t>(warpMaxExpertIdx[0])};
params.mPtrExpertIdx[tokenIdx] = packedScore;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void run(Data const& data, void* stream)
{
TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
"Routing kernel requires at least one input parameter");
TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
&& data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
"Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxNumTopExperts, data.mTopK);
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
"Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
TLLM_CHECK_WITH_INFO(
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens)
|| data.mNumTokens < WarpKernelMaxNumTokens;
bool const useSingleCluster
= data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
if (!useSingleCluster)
{
TLLM_CHECK_WITH_INFO(
data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
TLLM_CHECK_WITH_INFO(
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
}
if (useSingleWarp)
{
LAUNCH_ROUTING(data,
/*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize,
/*smemSize=*/0, // No dynamic smem
stream);
}
else if (useSingleCluster)
{
LAUNCH_ROUTING(data,
/*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream);
}
else
{
const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK;
const uint32_t histogramEltsPerBlock = 8 * NumThreadsHist;
const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;
// Limit grid size (all kernels use a grid-stride loop).
const uint32_t maxNumBlocks = 1024;
int const numBlocksHistogram
= std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
int const numBlocksOffsets
= std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
if (data.mPtrScores != nullptr)
{
LAUNCH_ROUTING(data,
/*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, NumThreadsHist,
/*smemSize=*/0, // No dynamic smem
stream);
}
else
{
// Reset the global histograms.
TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
}
LAUNCH_ROUTING(data,
/*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
/*smemSize=*/0, // No dynamic smem
stream);
LAUNCH_ROUTING(data,
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
/*smemSize=*/0, // No dynamic smem
stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace routingLlama4
} // namespace moe::dev::routing

View File

@ -0,0 +1,294 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RoutingKernel.cuh"
namespace moe::dev::routing
{
namespace routingRenormalize
{
////////////////////////////////////////////////////////////////////////////////////////////////////
static constexpr int NumThreads = 1024;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int MaxNumTopExperts = 8;
static constexpr int MaxNumExperts = 128;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
template <typename DataType, typename InputType, int VecSize, bool DoSoftmaxBeforeTopK>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts],
int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK,
InputType const* ptrScores, bool const normTopkProb)
{
DataType minScore = DataType{-INFINITY};
for (int i = 0; i < VecSize; i++)
{
auto expertIdx = i * WarpSize + laneIdx;
auto newScore = expertIdx < numExperts ? static_cast<DataType>(ptrScores[expertIdx]) : minScore;
score[i] = newScore;
idx[i] = expertIdx;
}
if constexpr (DoSoftmaxBeforeTopK)
{
calcSoftmax(warp, score);
}
// Get the top-k scores and their corresponding expert indices
topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK);
// Normalize the scores
if constexpr (DoSoftmaxBeforeTopK)
{
float sum = float{1.f};
if (normTopkProb)
{
sum = static_cast<float>(laneIdx < topK ? warpTopKScore[laneIdx] : 0);
sum = cg::reduce(warp, sum, cg::plus<float>());
}
if (laneIdx < topK)
{
warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
}
}
else
{
auto softmaxScore = calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK);
if (laneIdx < topK)
{
warpTopKScore[laneIdx] = softmaxScore;
}
}
}
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
routingIndicesClusterKernel(KernelParams params)
{
// number of tokens/expanded idx is bounded by total number of warps
using OutputT = typename KernelParams::OutputT;
using InputT = typename KernelParams::InputT;
using BaseType = std::conditional_t<KernelParams::DoSoftmaxBeforeTopK, float, InputT>;
using TypePacked = PackedScoreIdx<BaseType>;
static constexpr int VecSize = MaxNumExperts / WarpSize;
// we assume that #experts is a multiple of 4, so VecSize must be 4.
static_assert(VecSize == 4);
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts];
uint32_t const clusterBlockRank = blockIdx.x;
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
int32_t const laneIdx = cutlass::arch::LaneId();
auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
auto scoreOffset = warpTokenIdx * params.mNumExperts;
bool validToken = warpTokenIdx < params.mNumTokens;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WarpSize>(block);
// then wait on primary grid
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
if (params.mPtrScores != nullptr)
{
// in this case, each warp represents a token
BaseType score[VecSize];
int32_t idx[VecSize];
BaseType warpTopKScore[MaxNumTopExperts];
int32_t warpTopKExpertIdx[MaxNumTopExperts];
BaseType minScore = BaseType{-INFINITY};
if (validToken)
{
routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(warp, score, idx,
warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK,
params.mPtrScores + scoreOffset, params.mNormTopkProb);
if (laneIdx < params.mTopK)
{
smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx]
= TypePacked{warpTopKScore[laneIdx], static_cast<int16_t>(warpTopKExpertIdx[laneIdx])};
}
} // end if (validToken)
// make packed scores available to all threads in cluster
__cluster_barrier_arrive();
__cluster_barrier_wait();
}
routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxNumTopExperts,
/*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
}
#else
__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */)
{
assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
////////////////////////////////////////////////////////////////////////////////////////////////////
// this kernel is needed in case we have scores as input for the histogram kernel
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
{
using OutputT = typename KernelParams::OutputT;
using InputT = typename KernelParams::InputT;
using BaseType = std::conditional_t<KernelParams::DoSoftmaxBeforeTopK, float, InputT>;
static constexpr int VecSize = MaxNumExperts / WarpSize;
// we assume that #experts is a multiple of 4, so VecSize must be 4.
static_assert(VecSize == 4);
int32_t const laneIdx = cutlass::arch::LaneId();
int32_t const warpIdx = threadIdx.x / WarpSize;
int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
BaseType minScore = BaseType{-INFINITY};
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WarpSize>(block);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Wait on primary grid.
if constexpr (KernelParams::UsePdl)
{
cudaGridDependencySynchronize();
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// initialize the mPtrExpertCounts
int32_t expertCountsNum = 2 * params.mNumExperts;
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
int32_t globalThreadStride = gridDim.x * NumThreads;
initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Trigger secondary kernel.
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// in this case, each warp represents a token, and we use a grid-stride loop
// over all warps/tokens
BaseType allScores[VecSize];
int32_t allExpertIdx[VecSize];
BaseType warpTopKScore[MaxNumTopExperts];
int32_t warpTopKExpertIdx[MaxNumTopExperts];
for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
{
auto scoreOffset = tokenIdx * params.mNumExperts;
routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(warp, allScores, allExpertIdx,
warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK,
params.mPtrScores + scoreOffset, params.mNormTopkProb);
if (laneIdx < params.mTopK)
{
PackedScoreIdx<OutputT> packedScore{
static_cast<OutputT>(warpTopKScore[laneIdx]), static_cast<int16_t>(warpTopKExpertIdx[laneIdx])};
params.mPtrExpertIdx[tokenIdx * params.mTopK + laneIdx] = packedScore;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void run(Data const& data, void* stream)
{
TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
"Routing kernel requires at least one input parameter");
TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
&& data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
"Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
MaxNumTopExperts, data.mTopK);
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
"Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
TLLM_CHECK_WITH_INFO(
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
bool const useSingleCluster
= data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
if (!useSingleCluster)
{
TLLM_CHECK_WITH_INFO(
data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
TLLM_CHECK_WITH_INFO(
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
}
if (useSingleCluster)
{
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
/*smemSize=*/0, // No dynamic smem
stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
}
else
{
uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK;
uint32_t const histogramEltsPerBlock = 8 * NumThreadsHist;
uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;
// Limit grid size (all kernels use a grid-stride loop).
uint32_t const maxNumBlocks = 1024;
int const numBlocksHistogram
= std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
int const numBlocksOffsets
= std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
if (data.mPtrScores != nullptr)
{
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks,
NumThreadsHist,
/*smemSize=*/0, // No dynamic smem
stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
}
else
{
// Reset the global histograms.
TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
}
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
/*smemSize=*/0, // No dynamic smem
stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
/*smemSize=*/0, // No dynamic smem
stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace routingRenormalize
} // namespace moe::dev::routing

View File

@ -77,23 +77,17 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
routingData.mPtrNumTokensPerExpert = numTokensPerExpert;
routingData.mPtrExpertWeights = expertWeights;
routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
routingData.mPtrNumNonExitingCtas = numNonExitingCtas;
routingData.mAllToAllRouteAct = false;
// input:
// routingData.mPtrRoutingWeights = args.mRoutingWeights; // routing weights (don't need if not using gemm)
routingData.mPtrRoutingBias = routingBias;
routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
// routingData.mPtrIn = args.mInputActs;
routingData.mNumTokens = numTokens;
// routingData.mHiddenDim = args.mHiddenDim;
routingData.mNumExperts = numExperts;
routingData.mNumExpertGroups = nGroup;
routingData.mNumLimitedGroups = topkGroup;
@ -122,9 +116,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrExpertCounts = expertCountHistogram;
routingData.mPtrPermutedIdxSize = permutedIdxSize;
routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
// routingData.mPtrPermutedIdxToExpandedIdx = permuted_idx_to_expanded_idx;
routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
// routingData.mPtrNumTokensPerExpert = num_tokens_per_expert;
routingData.mPtrExpertWeights = expertWeights;
routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;

View File

@ -37,11 +37,6 @@ protected:
TensorPtr mPtrRoutingBiasHost;
TensorPtr mPtrRoutingBiasDevice;
// Add this variable because the definition of mPtrExpertIdx is "int32_t*" for Deepseek
//@TODO: remove this variable after refactoring
TensorPtr mPtrDeepseekExpertIdxHost;
TensorPtr mPtrDeepseekExpertIdxDevice;
private:
// private methods
static inline float sigmoid_accurate(float x)
@ -56,7 +51,6 @@ private:
// sigmoid / bias activated scores cannot be negative
static constexpr float invalidScoreFloat = -1.F;
const T invalidScore = T{invalidScoreFloat};
int32_t* expIdxHostPtr = bufferCast<int32_t>(*mPtrDeepseekExpertIdxHost);
float scoreSigmoid[param.numExperts];
for (int it = 0; it < param.numTokens; ++it)
@ -129,7 +123,6 @@ private:
// Convert back to io_dtype and store the topk expert results in hostData.mPtrExpertIdx
for (int ie = 0; ie < param.topK; ++ie)
{
expIdxHostPtr[it * param.topK + ie] = static_cast<int32_t>(finalTopkExperts[ie].idx);
if (param.getExpWeights)
{
bufferCast<T>(*this->mPtrExpertWeightsHost)[it * param.topK + ie]
@ -152,11 +145,6 @@ private:
= mBufferManager->pinned(ITensor::makeShape({param.numExperts}), TRTDataType<T>::value);
this->mPtrRoutingBiasDevice
= mBufferManager->gpu(ITensor::makeShape({param.numExperts}), TRTDataType<T>::value);
this->mPtrDeepseekExpertIdxHost
= mBufferManager->pinned(ITensor::makeShape({param.numTokens * param.topK}), TRTDataType<int32_t>::value);
this->mPtrDeepseekExpertIdxDevice
= mBufferManager->gpu(ITensor::makeShape({param.numTokens * param.topK}), TRTDataType<int32_t>::value);
}
void setupBuffers(RoutingKernelTestParam const& param) override
@ -177,8 +165,6 @@ private:
routingData.mDtypeExpW = btg::Dtype::Bfloat16;
routingData.mPtrScores = bufferCast<float>(*this->mPtrScoresDevice);
routingData.mPtrRoutingBias = bufferCast<T>(*this->mPtrRoutingBiasDevice);
//@todo: remove this line after refactoring
routingData.mPtrExpertIdx = bufferCast<int32_t>(*this->mPtrDeepseekExpertIdxDevice);
routingData.mNumExpertGroups = param.nGroup;
routingData.mNumLimitedGroups = param.topkGroup;
@ -193,62 +179,13 @@ private:
setParams(param, routingData);
moe::dev::routing::routingDeepSeek::run(routingData, mStream->get());
}
void verifyExpertRoutingIndices(RoutingKernelTestParam const& param)
{
// for permuted index, there is non-determinism, thus we check set-equality
// for this, we go over every expert and retrieve the tokens routed to it
// we then get the associated indexes and check set equality
auto const expandedIdxToPermutedIdxHost
= mBufferManager->copyFrom(*this->mPtrExpandedIdxToPermutedIdxDevice, MemoryType::kCPU);
auto const hostExpToPermTest = bufferCast<int32_t>(*expandedIdxToPermutedIdxHost);
auto const permutedIdxToTokenIdxHost
= mBufferManager->copyFrom(*this->mPtrPermutedIdxToTokenIdxDevice, MemoryType::kCPU);
auto const hostPermToTokTest = bufferCast<int32_t>(*permutedIdxToTokenIdxHost);
mStream->synchronize();
int32_t* expIdxToPermHostptr = bufferCast<int32_t>(*this->mPtrExpandedIdxToPermutedIdxHost);
int32_t* expIdxHostPtr = bufferCast<int32_t>(*this->mPtrDeepseekExpertIdxHost);
for (int ie = 0; ie < param.numExperts; ++ie)
{
std::set<int32_t> permutedIdx, permutedIdxTest;
std::set<int32_t> tokenIdx, tokenIdxTest;
auto localExpertIdx = ie - param.localExpertsStartIdx;
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < param.numLocalExperts
&& (localExpertIdx & param.localExpertsStrideLog2) == 0;
for (int it = 0; it < param.numTokens * param.topK; ++it)
{
if (expIdxHostPtr[it] == ie)
{
int const permIdx = isLocalExpert ? expIdxToPermHostptr[it] : int32_t{-1};
permutedIdx.insert(permIdx);
if (isLocalExpert)
{
tokenIdx.insert(it / param.topK);
}
int const permIdxTest = hostExpToPermTest[it];
permutedIdxTest.insert(permIdxTest);
if (isLocalExpert)
{
tokenIdxTest.insert(hostPermToTokTest[permIdxTest]);
}
}
}
EXPECT_EQ(checkSetEqual(ie, permutedIdx, permutedIdxTest, "permuted idx"), true);
EXPECT_EQ(checkSetEqual(ie, tokenIdx, tokenIdxTest, "token idx"), true);
}
}
};
TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types);
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
@ -279,4 +216,48 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
this->runTest(param);
};
// TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
// {
// RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300,
// /*numExperts=*/128, /*topK=*/8,
// /*expertParallelization=*/1, /*expertParallelizationId=*/0,
// /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
// /*usePdl=*/true, /*getExpWeights=*/true,
// /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
// this->runTest(param);
// };
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
/*numExperts=*/128, /*topK=*/2,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop2)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/100,
/*numExperts=*/128, /*topK=*/2,
/*expertParallelization=*/2, /*expertParallelizationId=*/1,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2)
{
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030,
/*numExperts=*/128, /*topK=*/2,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
this->runTest(param);
};
} // namespace

View File

@ -51,33 +51,10 @@ private:
int16_t newIdx = static_cast<int16_t>(ie);
score = static_cast<float>(bufferCast<T>(*this->mPtrScoresHost)[it * param.numExperts + ie]);
if (param.doSoftmaxBeforeTopK && score > maxScore)
{
maxScore = score;
}
PackedFloat si{static_cast<float>(score), newIdx};
expWeightsIdx[ie] = si;
}
if (param.doSoftmaxBeforeTopK)
{
// Run softmax before topk
for (int ie = 0; ie < param.numExperts; ++ie)
{
expWeightsIdx[ie].score
= static_cast<float>(std::exp(static_cast<float>(expWeightsIdx[ie].score) - maxScore));
sum += expWeightsIdx[ie].score;
}
for (int ie = 0; ie < param.numExperts; ++ie)
{
float score = static_cast<float>(expWeightsIdx[ie].score);
score /= sum;
expWeightsIdx[ie].score = static_cast<float>(score);
}
}
// Calculate the top-k scores and indices
std::partial_sort_copy(expWeightsIdx, expWeightsIdx + param.numExperts, expIdx, expIdx + param.topK,
[](PackedFloat const& a, PackedFloat const& b)
@ -148,7 +125,7 @@ TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization)
TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization)
{
RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/100,
RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10,
/*numExperts=*/128, /*topK=*/1,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,

View File

@ -165,7 +165,6 @@ private:
routingData.mNormTopkProb = param.routingMethod == RoutingMethodType::RenormalizeNaive;
routingData.mPtrScores = bufferCast<T>(*this->mPtrScoresDevice);
routingData.mPtrExpertIdx = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*this->mPtrExpertIdxDevice));
}
void callTestedFunction(
@ -214,7 +213,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/300,
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000,
/*numExperts=*/128, /*topK=*/8,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
@ -223,4 +222,47 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
this->runTest(param);
};
TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10,
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop4)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100,
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/2, /*expertParallelizationId=*/1,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormalizeNaiveTop4)
{
RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/10,
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
this->runTest(param);
};
TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
{
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000,
/*numExperts=*/128, /*topK=*/4,
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
/*usePdl=*/true, /*getExpWeights=*/true,
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
this->runTest(param);
};
} // end namespace

View File

@ -336,6 +336,8 @@ void RoutingKernelTest<T>::runTest(RoutingKernelTestParam const& param)
{
GTEST_SKIP() << "Skip test due to compute capability requirement.";
}
// Set seed to time-based seed
resetToTimeBasedSeed();
// Allocate buffers
allocateBuffers(param);

View File

@ -17,6 +17,7 @@
#include <gtest/gtest.h>
#include <chrono>
#include <memory> //@todo check the usage of this
#include <random> //@todo check the usage of this
@ -320,6 +321,27 @@ template <typename T>
class RoutingKernelTest : public testing::Test
{
public:
// Add a method to generate time-based seed
static uint32_t generateTimeBasedSeed()
{
std::random_device rd;
uint32_t seed = rd();
TLLM_LOG_DEBUG("Random device seed: %u", seed);
return seed;
}
// Method to set seed after construction
void setSeed(uint32_t seed)
{
mSeed = seed;
}
// Method to reset to time-based seed
void resetToTimeBasedSeed()
{
mSeed = generateTimeBasedSeed();
}
void SetUp() override;
void TearDown() override;
@ -372,7 +394,7 @@ protected:
routingData.mPtrExpandedIdxToPermutedIdx = bufferCast<int32_t>(*mPtrExpandedIdxToPermutedIdxDevice);
routingData.mPtrPermutedIdxToTokenIdx = bufferCast<int32_t>(*mPtrPermutedIdxToTokenIdxDevice);
routingData.mPtrExpertWeights = bufferCast<T>(*mPtrExpertWeightsDevice);
// routingData.mPtrExpertIdx = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*mPtrExpertIdxDevice));
routingData.mPtrExpertIdx = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*mPtrExpertIdxDevice));
// Set grouped gemm launch config buffers
routingData.mPtrCtaIdxXyToBatchIdx = bufferCast<int32_t>(*mPtrCtaIdxXyToBatchIdxDevice);