Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Refactor the rest routing part for the routing kernels in the MoE TRT-LLM backend (#5771)
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
This commit is contained in: parent 37293e4dfd, commit c5fb692a7d
@@ -31,7 +31,7 @@ namespace moe::dev
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define LAUCNCH_ESC(...) __VA_ARGS__
#define LAUNCH_ESC(...) __VA_ARGS__
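// Editor's illustrative sketch (not part of this diff): LAUNCH_ESC is a plain
// pass-through, so a comma-separated template argument list can travel through a
// single macro parameter. MY_WRAP below is a hypothetical stand-in for the
// launch macros that follow.
#define MY_WRAP(types) KernelParams<types, /*mUsePdl=*/true>
// MY_WRAP(cutlass::half_t, float)             -> error: 2 arguments passed to a 1-argument macro
// MY_WRAP(LAUNCH_ESC(cutlass::half_t, float)) -> KernelParams<cutlass::half_t, float, /*mUsePdl=*/true>
#undef MY_WRAP
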
#define LAUNCH_PDL(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
cudaLaunchConfig_t config{}; \
|
||||
@@ -64,53 +64,6 @@ namespace moe::dev
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
|
||||
}
|
||||
|
||||
#define LAUNCH_PDL_QWEN3(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
cudaLaunchConfig_t config{}; \
|
||||
config.gridDim = numBlocks; \
|
||||
config.blockDim = numThreads; \
|
||||
config.dynamicSmemBytes = smemSize; \
|
||||
config.stream = (cudaStream_t) stream; \
|
||||
\
|
||||
cudaLaunchAttribute attributes[2] = {}; \
|
||||
attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \
|
||||
attributes[0].val.programmaticStreamSerializationAllowed = int(data.mUsePdl); \
|
||||
attributes[1].id = cudaLaunchAttributeCooperative; \
|
||||
attributes[1].val.cooperative = int(coopLaunch); \
|
||||
config.attrs = attributes; \
|
||||
config.numAttrs = 2; \
|
||||
if (data.mUsePdl && data.mDoSoftmaxBeforeTopK) \
|
||||
{ \
|
||||
auto params = KernelParams<types, /*mUsePdl=*/true>::setKernelParams(data); \
|
||||
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/true>, /*mDoSoftmaxBeforeTopK=*/true>; \
|
||||
if (smemSize > 48 * 1024) \
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
|
||||
} \
|
||||
else if (data.mUsePdl && !data.mDoSoftmaxBeforeTopK) \
|
||||
{ \
|
||||
auto params = KernelParams<types, /*mUsePdl=*/true>::setKernelParams(data); \
|
||||
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/true>, /*mDoSoftmaxBeforeTopK=*/false>; \
|
||||
if (smemSize > 48 * 1024) \
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
|
||||
} \
|
||||
else if (!data.mUsePdl && data.mDoSoftmaxBeforeTopK) \
|
||||
{ \
|
||||
auto params = KernelParams<types, /*mUsePdl=*/false>::setKernelParams(data); \
|
||||
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/false>, /*mDoSoftmaxBeforeTopK=*/true>; \
|
||||
if (smemSize > 48 * 1024) \
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
auto params = KernelParams<types, /*mUsePdl=*/false>::setKernelParams(data); \
|
||||
auto kernelTyped = kernel<KernelParams<types, /*mUsePdl=*/false>, /*mDoSoftmaxBeforeTopK=*/false>; \
|
||||
if (smemSize > 48 * 1024) \
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \
|
||||
}
|
||||
|
||||
#define LAUNCH(data, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
if (data.mDtypeElt == tg::Dtype::Fp16) \
|
||||
{ \
|
||||
@@ -132,31 +85,31 @@ namespace moe::dev
|
||||
#define LAUNCH_EXPW(data, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL( \
|
||||
data, false, LAUCNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
data, false, LAUNCH_ESC(cutlass::float_e4m3_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL( \
|
||||
data, false, LAUCNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
data, false, LAUNCH_ESC(cutlass::bfloat16_t, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeElt == tg::Dtype::Fp16 && data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
|
||||
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::half_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeElt == tg::Dtype::E4m3 && data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, numBlocks, \
|
||||
numThreads, smemSize, stream); \
|
||||
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::float_e4m3_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeElt == tg::Dtype::Bfloat16 && data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, false, LAUCNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
|
||||
LAUNCH_PDL(data, false, LAUNCH_ESC(cutlass::bfloat16_t, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
} \
|
||||
else \
|
||||
@@ -164,68 +117,51 @@ namespace moe::dev
|
||||
TLLM_LOG_ERROR("Unsupported pair"); \
|
||||
}
|
||||
|
||||
#define LAUNCH_EXPW_QWEN3(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
#define LAUNCH_ROUTING(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
if (data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL_QWEN3(data, coopLaunch, LAUCNCH_ESC(void, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL_QWEN3( \
|
||||
data, coopLaunch, LAUCNCH_ESC(void, __nv_bfloat16), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TLLM_LOG_ERROR("Unsupported dtypeExpW: "); \
|
||||
}
|
||||
|
||||
#define LAUNCH_EXPW_ONLY_QWEN3(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
if (data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(void, float), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL( \
|
||||
data, coopLaunch, LAUCNCH_ESC(void, __nv_bfloat16), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TLLM_LOG_ERROR("Unsupported dtypeExpW: "); \
|
||||
}
|
||||
|
||||
#define LAUNCH_EXPW_ONLY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
if (data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, float, kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, __nv_bfloat16, kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
|
||||
}
|
||||
|
||||
#define LAUNCH_EXPW_ONLY_GROUPS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \
|
||||
if (data.mDtypeExpW == tg::Dtype::Fp32 && data.mNumExpertGroups > 1) \
|
||||
#define LAUNCH_ROUTING_WITH_EXTRA_FLAG( \
|
||||
data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput) \
|
||||
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(float, true), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, float, true), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Fp32) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(float, false), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && data.mNumExpertGroups > 1) \
|
||||
{ \
|
||||
LAUNCH_PDL( \
|
||||
data, coopLaunch, LAUCNCH_ESC(__nv_bfloat16, true), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
data, coopLaunch, LAUNCH_ESC(float, float, false), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, true), kernel, numBlocks, numThreads, smemSize, \
|
||||
stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, true), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \
|
||||
{ \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, false), kernel, numBlocks, numThreads, smemSize, \
|
||||
stream); \
|
||||
} \
|
||||
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
|
||||
{ \
|
||||
LAUNCH_PDL( \
|
||||
data, coopLaunch, LAUCNCH_ESC(__nv_bfloat16, false), kernel, numBlocks, numThreads, smemSize, stream); \
|
||||
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, false), kernel, numBlocks, numThreads, \
|
||||
smemSize, stream); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
|
||||
@@ -0,0 +1,573 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "RoutingKernel.cuh"
|
||||
|
||||
namespace moe::dev::routing
|
||||
{
|
||||
|
||||
namespace routingDeepSeek
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static constexpr int NumThreads = 256;
|
||||
static constexpr int NumWarps = NumThreads / WarpSize;
|
||||
static constexpr int NumTopGroupScores = 2;
|
||||
static constexpr int MaxNumTopExperts = 8;
|
||||
static constexpr int MaxNumTopGroups = 4;
|
||||
|
||||
template <typename KernelParams>
|
||||
__global__ void routingMainKernel(KernelParams params)
|
||||
{
|
||||
// declare types
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
using InputT = typename KernelParams::InputT;
|
||||
|
||||
// declare shared memory structure
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ float __attribute((aligned(128))) smemScoreSigmoid[NumThreads];
|
||||
__shared__ float __attribute((aligned(128))) smemScoreBias[NumThreads];
|
||||
// number of expert groups is bounded by number of warps
|
||||
__shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];
|
||||
|
||||
// needed for warp reduce
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||||
// for the final reduction of weight norm, only some lanes need to participate
|
||||
int32_t laneIdx = threadIdx.x % WarpSize;
|
||||
int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||||
// warps outside the range of expert groups do not participate
|
||||
if constexpr (KernelParams::UseGroups)
|
||||
{
|
||||
if (warpIdx >= params.mNumExpertGroups)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// note that for invalid scores, we simply use a negative value:
|
||||
// they work well even with the compacted format used in topK, and
|
||||
// sigmoid / bias activated scores cannot be negative
|
||||
static constexpr float invalidScoreFloat = -1.F;
|
||||
const OutputT invalidScore = OutputT{invalidScoreFloat};
|
||||
|
||||
// load bias already; each warp represents one expert group
|
||||
auto threadExpert = threadIdx.x;
|
||||
bool expertSelected = threadExpert < params.mNumExperts;
|
||||
if constexpr (KernelParams::UseGroups)
|
||||
{
|
||||
threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx;
|
||||
expertSelected = laneIdx < params.mNumExpertsPerGroup;
|
||||
}
|
||||
auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
|
||||
auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;
|
||||
|
||||
// initialize the mPtrExpertCounts
|
||||
if (params.mPtrExpertCounts)
|
||||
{
|
||||
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
|
||||
int32_t globalThreadStride = gridDim.x * NumThreads;
|
||||
int32_t expertCountsNum = 2 * params.mNumExperts;
|
||||
initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
// trigger the secondary kernel when using PDL, then wait on primary
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
#endif
|
||||
|
||||
// get our assigned thread score; each warp represents one expert group
|
||||
float score = expertSelected ? static_cast<float>(params.mPtrScores[scoreIdx]) : invalidScoreFloat;
|
||||
// get the sigmoid score
|
||||
// note that for invalid values, we simply use a negative value:
|
||||
// sigmoid scores are always strictly positive
|
||||
auto scoreSigmoid = sigmoid_accurate(score);
|
||||
// write the sigmoid score to shared for later use
|
||||
if (expertSelected)
|
||||
{
|
||||
smemScoreSigmoid[threadExpert] = scoreSigmoid;
|
||||
}
|
||||
// get the score with bias
|
||||
// note that with invalid values, because sigmoid is < 1 and bias is -1,
|
||||
// we must get a negative value, which is smaller than any valid value
|
||||
auto scoreBias = float{scoreSigmoid + float{biasVal}};
|
||||
|
||||
if (expertSelected)
|
||||
{
|
||||
smemScoreBias[threadExpert] = scoreBias;
|
||||
}
|
||||
|
||||
// registers for top group score reduction
|
||||
float topExpGroupScores[NumTopGroupScores];
|
||||
[[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
|
||||
float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups
|
||||
int32_t topGroupIdx[MaxNumTopGroups];
|
||||
float expertScoreGroup[MaxNumTopGroups];
|
||||
int32_t expertIdxGroup[MaxNumTopGroups];
|
||||
float topScores[MaxNumTopExperts]; // bound of params.mTopK
|
||||
int32_t topExperts[MaxNumTopExperts];
|
||||
|
||||
if constexpr (KernelParams::UseGroups)
|
||||
{
|
||||
topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
|
||||
// get the final group score and write it to shared
|
||||
if (cute::elect_one_sync())
|
||||
{
|
||||
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
|
||||
smemGroupScores[warpIdx] = groupScore;
|
||||
}
|
||||
}
|
||||
|
||||
// make group scores available to all warps
|
||||
__syncthreads();
|
||||
|
||||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||||
if (warpIdx == 0)
|
||||
{
|
||||
// a single warp performs the selection of top groups, and goes on to select the final experts
|
||||
if constexpr (KernelParams::UseGroups)
|
||||
{
|
||||
float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat;
|
||||
|
||||
topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
|
||||
// final expert selection: get relevant indexes and scores from shared
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
|
||||
{ // bound of params.mNumLimitedGroups
|
||||
auto groupIdx = topGroupIdx[ii];
|
||||
expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx;
|
||||
// note: expertSelected implies laneIdx < params.mNumExpertsPerGroup.
|
||||
// we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups,
|
||||
// thus groupIdx <= params.mNumExpertGroups - 1 =>
|
||||
// groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup
|
||||
// => expertIdxGroup[ii] < params.mNumExperts <= NumThreads,
|
||||
// so the access is safe here
|
||||
expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected
|
||||
? smemScoreBias[expertIdxGroup[ii]]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// without groups, each thread just takes `MaxNumTopGroups` experts
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
|
||||
{
|
||||
auto expertIdx = ii * WarpSize + laneIdx;
|
||||
expertIdxGroup[ii] = expertIdx;
|
||||
expertScoreGroup[ii] = expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
|
||||
}
|
||||
}
|
||||
|
||||
topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
|
||||
/* minValue */ invalidScoreFloat, params.mTopK);
|
||||
|
||||
// determine our lane's expert index and write to output
|
||||
int32_t expertIdx = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < params.mTopK; ++ii)
|
||||
{ // bound of params.mTopK
|
||||
expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx;
|
||||
}
|
||||
// determine whether our expert is local to this GPU
|
||||
auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
|
||||
float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F;
|
||||
auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
|
||||
auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};
|
||||
|
||||
// write expert idx out already
|
||||
auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
|
||||
if (laneIdx < params.mTopK && params.mPtrExpertIdx != nullptr)
|
||||
{
|
||||
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore), static_cast<int16_t>(expertIdx)};
|
||||
params.mPtrExpertIdx[idxTopK] = packedScore;
|
||||
}
|
||||
|
||||
if (laneIdx < params.mTopK && params.mPtrExpertWeights != nullptr)
|
||||
{
|
||||
params.mPtrExpertWeights[idxTopK] = finalScore;
|
||||
}
|
||||
}
|
||||
}
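// Editor's sketch: a standalone host-side reference of the scoring performed by
// routingMainKernel above (grouped path), useful for unit-testing the kernel.
// It is a simplification: no PDL, no invalid-score handling, and it assumes
// numExperts % numGroups == 0, expertsPerGroup >= 2 and topGroups <= numGroups.
// All names below are illustrative, not from the original file.
#include <algorithm>
#include <cmath>
#include <functional>
#include <utility>
#include <vector>

inline std::vector<std::pair<int, float>> referenceDeepSeekRouting(std::vector<float> const& logits,
    std::vector<float> const& bias, int numGroups, int topGroups, int topK, float routeScale)
{
    int const numExperts = static_cast<int>(logits.size());
    int const expertsPerGroup = numExperts / numGroups;
    std::vector<float> sig(numExperts), biased(numExperts);
    for (int e = 0; e < numExperts; ++e)
    {
        sig[e] = 1.f / (1.f + std::exp(-logits[e])); // sigmoid score
        biased[e] = sig[e] + bias[e];                // score used for selection
    }
    // Group score = sum of the two largest biased scores in the group (NumTopGroupScores == 2).
    std::vector<std::pair<float, int>> groupScores;
    for (int g = 0; g < numGroups; ++g)
    {
        std::vector<float> grp(biased.begin() + g * expertsPerGroup, biased.begin() + (g + 1) * expertsPerGroup);
        std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(), std::greater<float>());
        groupScores.push_back({grp[0] + grp[1], g});
    }
    std::partial_sort(groupScores.begin(), groupScores.begin() + topGroups, groupScores.end(),
        std::greater<std::pair<float, int>>());
    // Top-K experts by biased score, restricted to the selected groups.
    std::vector<std::pair<float, int>> cand;
    for (int i = 0; i < topGroups; ++i)
    {
        for (int e = 0; e < expertsPerGroup; ++e)
        {
            int const idx = groupScores[i].second * expertsPerGroup + e;
            cand.push_back({biased[idx], idx});
        }
    }
    std::partial_sort(cand.begin(), cand.begin() + topK, cand.end(), std::greater<std::pair<float, int>>());
    // Final weights use the *sigmoid* scores of the selected experts, normalized and scaled by routeScale.
    float norm = 0.f;
    for (int k = 0; k < topK; ++k)
    {
        norm += sig[cand[k].second];
    }
    std::vector<std::pair<int, float>> out;
    for (int k = 0; k < topK; ++k)
    {
        out.push_back({cand[k].second, sig[cand[k].second] * routeScale / norm});
    }
    return out;
}
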
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename KernelParams>
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
|
||||
routingIndicesClusterKernel(KernelParams params)
|
||||
{
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
|
||||
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||||
int32_t const clusterBlockRank = blockIdx.x;
|
||||
|
||||
//@todo: try to move it into routingPermutation
|
||||
// then wait on primary grid
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
|
||||
routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(
|
||||
params, nullptr, warpIdx, clusterBlockRank);
|
||||
}
|
||||
#else
|
||||
__global__ void routingIndicesClusterKernel(KernelParams params)
|
||||
{
|
||||
assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename KernelParams>
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
__global__ void __launch_bounds__(NumThreads) routingIndicesCoopKernel(KernelParams params)
|
||||
{
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
|
||||
// needed for the exclusive sum of token offsets
|
||||
using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
__shared__ typename Scan::TempStorage tempStorage;
|
||||
// 64 elements -> 128+ registers. Above that we may start to see spilling to local memory.
|
||||
static constexpr int MaxExpandedIdxPerThread = 64;
|
||||
|
||||
// Initialize grid.
|
||||
cg::grid_group grid = cg::this_grid();
|
||||
// Note: the following is more efficient than grid.block_index() because we don't use y and z.
|
||||
int32_t const gridBlockIdx = blockIdx.x;
|
||||
int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x;
|
||||
int32_t const numBlocks = gridDim.x;
|
||||
int32_t const numThreadsPerGrid = numBlocks * NumThreads;
|
||||
|
||||
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||||
|
||||
auto expandedIdxSize = params.mNumTokens * params.mTopK;
|
||||
|
||||
// pre-fill the counts with 0
|
||||
smemExpertCount[threadIdx.x] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// then wait on primary grid
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
|
||||
// each thread keeps some number of "expanded indexes" assigned to it
|
||||
// for each of these, we keep the associated expert and offset within expert in registers
|
||||
int32_t expertIndexes[MaxExpandedIdxPerThread];
|
||||
int32_t expertOffsets[MaxExpandedIdxPerThread];
|
||||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||||
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
|
||||
// time, and branch between a fast path without bound checks and a slow path with bound checks.
|
||||
int constexpr IterStride = 4;
|
||||
static_assert(MaxExpandedIdxPerThread % IterStride == 0);
|
||||
|
||||
// Define a lambda to avoid code duplication in both branches.
|
||||
auto loopBody = [&](int ii, int expandedIdx)
|
||||
{
|
||||
int32_t expertIdx = params.mPtrExpertIdx[expandedIdx].idx;
|
||||
expertIndexes[ii] = expertIdx;
|
||||
// check whether this expert is local to our GPU at all and ignore if not
|
||||
auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0;
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
|
||||
{
|
||||
// Whether it's safe to do multiple iterations without bound checks.
|
||||
bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize;
|
||||
if (takeFastPath)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||||
{
|
||||
int const ii = ii0 + jj;
|
||||
auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bool doBreak = false;
|
||||
#pragma unroll
|
||||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||||
{
|
||||
int const ii = ii0 + jj;
|
||||
auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
|
||||
if (expandedIdx >= expandedIdxSize)
|
||||
{
|
||||
doBreak = true;
|
||||
break;
|
||||
}
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
if (doBreak)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make histogram (token counts per expert) available to all threads in the block.
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Each thread now represents one expert
|
||||
//
|
||||
|
||||
// Add the local bin count to the common bin count and get a per-CTA offset.
|
||||
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
|
||||
|
||||
int32_t blockExpertOffset = 0;
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount);
|
||||
}
|
||||
|
||||
// Sync to wait for completion of the histogram reduction.
|
||||
grid.sync();
|
||||
|
||||
// Get total count for this expert.
|
||||
int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
|
||||
|
||||
// Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency.
|
||||
|
||||
// Compute the runtime config for projections
|
||||
// Whether or not an expert is local is taken into account when smemExpertCount is computed
|
||||
// so we do not need to take it into account here.
|
||||
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||||
int32_t ctaOffset;
|
||||
int32_t numNonExitingCtas;
|
||||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||||
|
||||
for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks)
|
||||
{
|
||||
const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
|
||||
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
|
||||
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
|
||||
}
|
||||
|
||||
// get the padded offset associated with this expert
|
||||
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||||
|
||||
// write out padded count
|
||||
if (gridBlockIdx == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
|
||||
{
|
||||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||||
}
|
||||
|
||||
// write expert offsets to shared
|
||||
smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
|
||||
|
||||
// make expert offsets available to all threads
|
||||
__syncthreads();
|
||||
|
||||
// trigger the secondary kernel when using PDL
|
||||
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
|
||||
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
|
||||
// TODO: this is not sufficient to ensure visibility in the next kernel!
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
|
||||
// each thread has the same "expanded indexes" assigned to it as above
|
||||
// at this point, we know the final offsets of experts and the offsets within
|
||||
// experts, which allows writing the final index values
|
||||
#pragma unroll
|
||||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
|
||||
{
|
||||
auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
|
||||
if (expandedIdx >= expandedIdxSize)
|
||||
{
|
||||
break;
|
||||
}
|
||||
auto expertIdx = expertIndexes[ii];
|
||||
// check whether this expert is local to our GPU at all
|
||||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
auto tokenIdx = expandedIdx / params.mTopK;
|
||||
auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
|
||||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
__global__ void routingIndicesCoopKernel(KernelParams params)
|
||||
{
|
||||
assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures");
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run(Data& data, void* stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr || data.mPtrExpertWeights != nullptr,
|
||||
"Routing kernel requires at least one output parameter");
|
||||
if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
|
||||
TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr && data.mPtrPermutedIdxSize,
|
||||
"If permuted index is required, `mPtrExpertIdx` is also required");
|
||||
TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
|
||||
TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d",
|
||||
MaxNumTopGroups, data.mNumLimitedGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
|
||||
MaxNumTopExperts, data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK * data.mNumLimitedGroups <= WarpSize,
|
||||
"Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK,
|
||||
data.mNumLimitedGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects %d to be at most #experts %d",
|
||||
MaxNumTopExperts, data.mNumExperts);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumThreads, "Routing kernel expects #experts %d <= #threads %d",
|
||||
data.mNumExperts, NumThreads);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups,
|
||||
"Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups,
|
||||
data.mNumExpertGroups);
|
||||
if (data.mNumExpertGroups > 1)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= NumWarps,
|
||||
"Routing kernel expects #experts groups %d to be <= #warps %d", data.mNumExpertGroups, NumWarps);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
|
||||
"Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
|
||||
data.mNumExpertGroups);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
|
||||
"Routing kernel expects #experts per group <= warp size, got %d", data.mNumExperts / data.mNumExpertGroups);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= WarpSize * MaxNumTopGroups,
|
||||
"Routing kernel expects #experts %d <= WarpSize * MaxNumTopGroups %d", data.mNumExperts,
|
||||
WarpSize * MaxNumTopGroups);
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mTopK <= NumWarps, "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, NumWarps);
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
|
||||
TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
|
||||
int const numBlocks = data.mNumTokens;
|
||||
|
||||
bool const useSingleCluster = data.mNumTokens <= 1024;
|
||||
if (!useSingleCluster)
|
||||
{
|
||||
// Reset the global histograms (not used in single-cluster code path).
|
||||
// Cover both for the cooperative and two-kernel code paths.
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
|
||||
}
|
||||
else
|
||||
{
|
||||
data.mPtrExpertCounts = nullptr; // Set it to nullptr for single-cluster code path, as it won't be used
|
||||
}
|
||||
|
||||
// Number of blocks we can use in the cooperative kernel
|
||||
// The number of blocks must be:
|
||||
// >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉
|
||||
// <= numSms, assuming an occupancy of 1 block/SM
|
||||
//
|
||||
// If too small for the given numTokens, fall back to the less performant two-step method.
|
||||
//
|
||||
// The upper bound is a strict requirement. The number of blocks should be determined by querying
|
||||
// the device properties, or conservatively low.
|
||||
// /!\ The following number is not portable!! (but works on H100 and B200)
|
||||
int const numBlocksCoop = 128;
|
||||
|
||||
// Maximum number of tokens supported by the kernel using a cooperative launch.
|
||||
int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK;
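// Editor's note (worked example): the literal 64 matches MaxExpandedIdxPerThread in
// routingIndicesCoopKernel. With numBlocksCoop = 128, NumThreads = 256 and, e.g.,
// mTopK = 8, the cooperative path handles up to 128 * 256 * 64 / 8 = 262144 tokens
// before run() falls back to the two-kernel histogram + offsets path below.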
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
|
||||
/*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
|
||||
|
||||
if (data.mPtrPermutedIdxSize != nullptr)
|
||||
{
|
||||
if (useSingleCluster)
|
||||
{
|
||||
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
|
||||
/*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
|
||||
}
|
||||
else if (data.mNumTokens <= maxTokensCoop)
|
||||
{
|
||||
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
|
||||
/*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, NumThreads,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
|
||||
}
|
||||
else
|
||||
{
|
||||
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
|
||||
|
||||
const int32_t histogramEltsPerBlock = 8 * NumThreads;
|
||||
const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreads;
|
||||
|
||||
// Limit grid size (both kernels use a grid-stride loop).
|
||||
const int32_t maxNumBlocks = 1024;
|
||||
|
||||
int const numBlocksHistogram
|
||||
= std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
|
||||
int const numBlocksOffsets
|
||||
= std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
|
||||
|
||||
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
|
||||
/*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreads,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
|
||||
LAUNCH_ROUTING_WITH_EXTRA_FLAG(data,
|
||||
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreads,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace routingDeepSeek
|
||||
} // namespace moe::dev::routing
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,779 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "DevKernel.h"
|
||||
#include "RoutingKernel.h"
|
||||
#include "RoutingKernelTopK.cuh"
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "tensorrt_llm/kernels/archCondition.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace moe::dev
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
namespace routing
|
||||
{
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static constexpr int WarpSize = 32;
|
||||
static constexpr int NumBlocksPerCluster = 8;
|
||||
// Performance tuning knob.
|
||||
static constexpr int NumEltsPerOffsetTilePerThread = 8;
|
||||
|
||||
static constexpr int NumThreadsHist = 256;
|
||||
static constexpr int NumWarpsHist = NumThreadsHist / WarpSize;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static __device__ inline float sigmoid_accurate(float x)
|
||||
{
|
||||
return 0.5f * tanhf(0.5f * x) + 0.5f;
|
||||
}
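// Editor's note: the tanh form above is algebraically identical to the logistic
// sigmoid, 1 / (1 + exp(-x)) == 0.5 * tanh(0.5 * x) + 0.5, while keeping the
// intermediate values bounded. A minimal host-side check (illustrative):
//   float ref = 1.f / (1.f + std::exp(-x));
//   assert(std::fabs(sigmoid_accurate(x) - ref) < 1e-6f);
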
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T mulLog2(T a, T bLog2)
|
||||
{
|
||||
return a << bLog2;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T divUpLog2(T a, T bLog2)
|
||||
{
|
||||
return ((a + (1 << bLog2) - 1) >> bLog2);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2)
|
||||
{
|
||||
return mulLog2<T>(divUpLog2<T>(a, bLog2), bLog2);
|
||||
}
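// Editor's illustration of the log2-based padding helpers above, with bLog2 = 5
// (i.e. a tile/padding granularity of 32):
//   mulLog2<int32_t>(3, 5)        == 96   // 3 * 32
//   divUpLog2<int32_t>(100, 5)    == 4    // ceil(100 / 32)
//   divUpMulLog2<int32_t>(100, 5) == 128  // 100 rounded up to a multiple of 32
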
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__host__ __device__ constexpr int32_t getBits(int32_t value, int idx)
|
||||
{
|
||||
int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000;
|
||||
return (value & mask) >> (idx * 8);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool IsZero = false>
|
||||
__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx)
|
||||
{
|
||||
if constexpr (!IsZero)
|
||||
{
|
||||
int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;
|
||||
value &= mask;
|
||||
}
|
||||
value |= (newBits << (idx * 8));
|
||||
}
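// Editor's illustration: getBits/setBits treat an int32_t as four packed bytes,
// with idx selecting the byte. IsZero=true skips the clear step when the target
// byte is already known to be zero.
//   int32_t packed = 0;
//   setBits</*IsZero=*/true>(packed, 0x2A, 1); // packed == 0x00002A00
//   int32_t b = getBits(packed, 1);            // b == 0x2A
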
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename DataType>
|
||||
__device__ void initArr(int startIdx, int numElts, int stride, DataType* arr, DataType value)
|
||||
{
|
||||
if (arr != nullptr)
|
||||
{
|
||||
for (int i = startIdx; i < numElts; i += stride)
|
||||
{
|
||||
arr[i] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename DataType, int VecSize>
|
||||
__device__ void calcSoftmax(cg::thread_block_tile<WarpSize> const& warp, DataType (&scores)[VecSize])
|
||||
{
|
||||
DataType maxScore = DataType{-INFINITY};
|
||||
DataType sumScore = DataType{0.f};
|
||||
|
||||
// Get the max score for each token
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
|
||||
}
|
||||
maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
|
||||
|
||||
// Get the summation of scores for each token
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
|
||||
sumScore += scores[i];
|
||||
}
|
||||
sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());
|
||||
|
||||
// Normalize the scores
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
scores[i] = static_cast<DataType>(scores[i] / sumScore);
|
||||
}
|
||||
}
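// Editor's sketch: a scalar reference of the warp-cooperative softmax above. The
// warp version keeps VecSize scores per lane and reduces max/sum across the 32
// lanes; the math is the usual numerically-stable softmax.
#include <algorithm>
#include <cmath>
#include <vector>

inline void referenceSoftmax(std::vector<float>& scores)
{
    float maxScore = -INFINITY;
    float sumScore = 0.f;
    for (float s : scores)
    {
        maxScore = std::max(maxScore, s); // subtract the max for numerical stability
    }
    for (float& s : scores)
    {
        s = std::exp(s - maxScore);
        sumScore += s;
    }
    for (float& s : scores)
    {
        s /= sumScore; // normalize
    }
}
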
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename DataType>
|
||||
__device__ DataType calcSoftmax(
|
||||
cg::thread_block_tile<WarpSize> const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts)
|
||||
{
|
||||
DataType maxScore = DataType{-INFINITY};
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
maxScore = score >= maxScore ? score : maxScore;
|
||||
}
|
||||
maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
|
||||
|
||||
float sumScore = float{0.f};
|
||||
float newScore;
|
||||
// Get the summation of scores for each token
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
newScore = static_cast<float>(score) - static_cast<float>(maxScore);
|
||||
newScore = static_cast<float>(exp(newScore));
|
||||
sumScore += newScore;
|
||||
}
|
||||
sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
|
||||
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
score = static_cast<DataType>(newScore / sumScore);
|
||||
}
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename KernelParams, typename BaseType, int NumThreads, int NumWarps, int MaxNumTopExperts,
|
||||
bool LoadExpertIdxFromGlobal = false>
|
||||
__device__ void routingPermutation(KernelParams params, PackedScoreIdx<BaseType>* smemPackedScoreIdx,
|
||||
int32_t const warpIdx, uint32_t const clusterBlockRank)
|
||||
{
|
||||
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
using TypePacked = PackedScoreIdx<BaseType>;
|
||||
|
||||
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
|
||||
// Number of threads in the cluster.
|
||||
static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster;
|
||||
// same as max num tokens
|
||||
static constexpr int MaxExpandedIdxPerThread
|
||||
= (MaxNumTokensSingleCluster * MaxNumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster;
|
||||
|
||||
// Needed for the exclusive sum of token offsets.
|
||||
// Note: the scan might include more bins than needed, with bin counts of 0 to pad
|
||||
using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
__shared__ typename Scan::TempStorage tempStorage;
|
||||
|
||||
uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;
|
||||
auto expandedIdxSize = params.mNumTokens * params.mTopK;
|
||||
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
|
||||
|
||||
// pre-fill the counts with 0
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
smemExpertCount[threadIdx.x] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// each thread keeps some number of "expanded indexes" assigned to it
|
||||
// note that expanded indexes simply represent tokens here.
|
||||
// for each of these, we keep the associated expert and offset within expert in registers
|
||||
int32_t expertIndexes[MaxExpandedIdxPerThread];
|
||||
int32_t expertOffsets[MaxExpandedIdxPerThread];
|
||||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||||
|
||||
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
|
||||
// time, and branch between a fast path without bound checks and a slow path with bound checks.
|
||||
// TODO(mjoux): potentially add this back for perf tuning
|
||||
// int constexpr IterStride = 4;
|
||||
// static_assert(MaxExpandedIdxPerThread % IterStride == 0);
|
||||
|
||||
// Define a lambda to avoid code duplication in both branches.
|
||||
auto loopBody = [&](int ii, int expandedIdx)
|
||||
{
|
||||
TypePacked scoreIdx;
|
||||
if constexpr (LoadExpertIdxFromGlobal)
|
||||
{
|
||||
scoreIdx = TypePacked{static_cast<BaseType>(params.mPtrExpertIdx[expandedIdx].score),
|
||||
static_cast<int16_t>(params.mPtrExpertIdx[expandedIdx].idx)};
|
||||
}
|
||||
else
|
||||
{
|
||||
TypePacked const* remoteSmem
|
||||
= cg::cluster_group::map_shared_rank(smemPackedScoreIdx, expandedIdx / (NumWarps * params.mTopK));
|
||||
scoreIdx = remoteSmem[expandedIdx % (NumWarps * params.mTopK)];
|
||||
}
|
||||
|
||||
expertIndexes[ii] = scoreIdx.idx;
|
||||
// check whether this expert is local to our GPU at all and ignore if not
|
||||
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
|
||||
if (params.mPtrExpertWeights != nullptr)
|
||||
{
|
||||
params.mPtrExpertWeights[expandedIdx] = OutputT{scoreIdx.score};
|
||||
}
|
||||
};
|
||||
|
||||
int constexpr IterStride = 4;
|
||||
#pragma unroll
|
||||
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
|
||||
{
|
||||
// Whether it's safe to do multiple iterations without bound checks.
|
||||
bool const takeFastPath = (ii0 + IterStride) * NumThreadsPerCluster <= expandedIdxSize;
|
||||
if (takeFastPath)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||||
{
|
||||
int const ii = ii0 + jj;
|
||||
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bool doBreak = false;
|
||||
#pragma unroll
|
||||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||||
{
|
||||
int const ii = ii0 + jj;
|
||||
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
|
||||
if (expandedIdx >= expandedIdxSize)
|
||||
{
|
||||
doBreak = true;
|
||||
break;
|
||||
}
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
if (doBreak)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Make local histogram (token counts per expert) available to all threads in the cluster.
|
||||
__cluster_barrier_arrive();
|
||||
__cluster_barrier_wait();
|
||||
|
||||
//
|
||||
// Each thread now represents one expert
|
||||
//
|
||||
|
||||
// Total number of tokens for this expert.
|
||||
int32_t count = 0;
|
||||
// Per-expert offset for this block.
|
||||
int32_t blockExpertOffset = 0;
|
||||
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
// Get the histogram bin from each rank for this expert.
|
||||
int32_t expertCounts[NumBlocksPerCluster];
|
||||
#pragma unroll
|
||||
for (int rank = 0; rank < NumBlocksPerCluster; rank++)
|
||||
{
|
||||
int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
|
||||
expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0;
|
||||
}
|
||||
|
||||
// Compute an exclusive prefix sum of the block-local count.
|
||||
#pragma unroll
|
||||
for (int rank = 0; rank < NumBlocksPerCluster; rank++)
|
||||
{
|
||||
if (rank == clusterBlockRank)
|
||||
{
|
||||
blockExpertOffset = count;
|
||||
}
|
||||
count += expertCounts[rank];
|
||||
}
|
||||
}
|
||||
|
||||
// Arrive: we do not access distributed shared memory after this point.
|
||||
__cluster_barrier_arrive();
|
||||
|
||||
// Compute the runtime config for projections
|
||||
// Whether or not an expert is local is taken into account when smemExpertCount is computed
|
||||
// so we do not need to take it into account here.
|
||||
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||||
int32_t ctaOffset;
|
||||
int32_t numNonExitingCtas;
|
||||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||||
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
// Strided loop to share this work between blocks.
|
||||
for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster)
|
||||
{
|
||||
const int32_t localExpertIdx
|
||||
= (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
|
||||
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
|
||||
= min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
|
||||
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
|
||||
}
|
||||
|
||||
// get the padded offset associated with this expert
|
||||
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||||
|
||||
// write expert offsets to shared
|
||||
smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
|
||||
}
|
||||
|
||||
// write out padded count
|
||||
if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
|
||||
{
|
||||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||||
}
|
||||
|
||||
// make expert offsets available to all threads
|
||||
__syncthreads();
|
||||
|
||||
// Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
|
||||
// Note: I observed a perf benefit to doing this before the final loop so the compiler can
|
||||
// implement break with EXIT.
|
||||
__cluster_barrier_wait();
|
||||
|
||||
// trigger the secondary kernel when using PDL
|
||||
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
|
||||
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
|
||||
// TODO: this is not sufficient to ensure visibility in the next kernel!
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
|
||||
// each thread has the same "expanded indexes" assigned to it as above
|
||||
// at this point, we know the final offsets of experts and the offsets within
|
||||
// experts, which allows writing the final index values
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
|
||||
{
|
||||
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
|
||||
if (expandedIdx >= expandedIdxSize)
|
||||
{
|
||||
break;
|
||||
}
|
||||
auto expertIdx = expertIndexes[ii];
|
||||
// check whether this expert is local to our GPU at all
|
||||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
auto tokenIdx = expandedIdx / params.mTopK;
|
||||
auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
|
||||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Two-step approach (if number of tokens exceed limits of what cluster / cooperative launch
|
||||
// variants can handle): in order to minimize the amount of data to exchange through global memory,
|
||||
// we will compute the local histograms in smem twice: the first kernel will get us the total number
|
||||
// of tokens per expert. The second kernel will use the smem and L2 atomics to get corresponding
|
||||
// element and tile offsets.
|
||||
//
|
||||
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
|
||||
// inefficient if we have one CTA per token doing a single global atomic.
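// Editor's sketch of the two-pass idea in scalar form: pass 1 builds the
// per-expert histogram, an exclusive scan turns counts into offsets, and pass 2
// scatters each expanded index into its permuted slot. The real kernels
// additionally pad each expert's segment to a multiple of 2^mPaddingLog2 and
// skip experts that are not local to this GPU.
#include <vector>

inline std::vector<int> referencePermutation(std::vector<int> const& expertOfExpandedIdx, int numExperts)
{
    std::vector<int> count(numExperts, 0);
    std::vector<int> offset(numExperts, 0);
    for (int e : expertOfExpandedIdx)
    {
        ++count[e]; // pass 1: histogram
    }
    for (int e = 1; e < numExperts; ++e)
    {
        offset[e] = offset[e - 1] + count[e - 1]; // exclusive prefix sum
    }
    int const n = static_cast<int>(expertOfExpandedIdx.size());
    std::vector<int> permutedIdx(n);
    for (int i = 0; i < n; ++i)
    {
        permutedIdx[i] = offset[expertOfExpandedIdx[i]]++; // pass 2: scatter
    }
    return permutedIdx;
}
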
template <typename KernelParams>
|
||||
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(KernelParams params)
|
||||
{
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
|
||||
|
||||
// For unrolling.
|
||||
uint32_t constexpr NumEltsPerThread = 8;
|
||||
|
||||
// Pre-fill the counts with 0
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
smemExpertCount[threadIdx.x] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// Wait on primary grid and trigger secondary kernel.
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
|
||||
uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
|
||||
uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||||
|
||||
uint32_t const gridBlockOffset = blockIdx.x * NumThreadsHist;
|
||||
uint32_t const gridStride = gridDim.x * NumThreadsHist;
|
||||
|
||||
// Define a lambda to avoid code duplication in branches.
|
||||
auto loopBody = [&](int expandedIdx)
|
||||
{
|
||||
PackedScoreIdx<OutputT> scoreIdx = params.mPtrExpertIdx[expandedIdx];
|
||||
// check whether this expert is local to our GPU at all and ignore if not
|
||||
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
if (isLocalExpert)
|
||||
{
|
||||
atomicAdd(&smemExpertCount[scoreIdx.idx], 1);
|
||||
}
|
||||
|
||||
if (params.mPtrExpertWeights != nullptr)
|
||||
{
|
||||
params.mPtrExpertWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
|
||||
}
|
||||
};
|
||||
|
||||
// Grid-stride loop.
|
||||
for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
|
||||
expandedIdx0 += gridStride * NumEltsPerThread)
|
||||
{
|
||||
// Fast path if bound checks aren't necessary
|
||||
if (expandedIdx0 + NumEltsPerThread * NumThreadsHist <= expandedIdxSize)
|
||||
{
|
||||
#pragma unroll
|
||||
for (uint32_t ii = 0; ii < NumEltsPerThread; ii++)
|
||||
{
|
||||
uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsHist + threadIdx.x;
|
||||
loopBody(expandedIdx);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
|
||||
expandedIdx += NumThreadsHist)
|
||||
{
|
||||
loopBody(expandedIdx);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Each thread now represents one expert
|
||||
//
|
||||
|
||||
// Reduce histograms with atomics.
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
|
||||
atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename KernelParams>
|
||||
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params)
|
||||
{
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreadsHist];
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
|
||||
__shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[NumThreadsHist];
|
||||
// needed for the exclusive sum of token offsets
|
||||
using Scan = cub::BlockScan<int32_t, NumThreadsHist, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
__shared__ typename Scan::TempStorage tempStorage;
|
||||
static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
|
||||
static constexpr int MaxExpandedIdxPerBlock = NumThreadsHist * MaxExpandedIdxPerThread;
|
||||
|
||||
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||||
|
||||
uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
|
||||
uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// Wait on primary grid.
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
|
||||
// The expert offsets are common to all tiles of all blocks.
|
||||
// Load the histogram, scan it and write offsets to shared memory.
|
||||
// Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
|
||||
// the scan, with PDL?
|
||||
|
||||
//
|
||||
// Each thread represents one expert.
|
||||
//
|
||||
|
||||
// Get total count for this expert.
|
||||
int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
|
||||
|
||||
// Compute the runtime config for projections
|
||||
// Whether or not an expert is local is taken into account when the histogram is computed
|
||||
// so we do not need to take it into account here.
|
||||
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||||
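// Worked example (hypothetical values): with count == 70 and mPaddingLog2 == 5 (padding to
// multiples of 32), numCta == divUpLog2(70, 5) == 3 and the padded extent for this expert is
// mulLog2(3, 5) == 96 elements.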
int32_t ctaOffset;
|
||||
int32_t numNonExitingCtas;
|
||||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||||
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
// Get the padded offset associated with this expert
|
||||
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||||
|
||||
// Write expert offsets to shared
|
||||
smemExpertOffset[threadIdx.x] = offset;
|
||||
}
|
||||
|
||||
// Sync to make expert offsets available to all threads.
|
||||
__syncthreads();
|
||||
|
||||
// The first block writes out the padded count.
|
||||
if (blockIdx.x == 0 && warpIdx == NumWarpsHist - 1 && cute::elect_one_sync())
|
||||
{
|
||||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||||
}
|
||||
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
// Strided loop to share this work between blocks.
|
||||
for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x)
|
||||
{
|
||||
const int32_t localExpertIdx
|
||||
= (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
|
||||
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
|
||||
= min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
|
||||
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Now loop on indices and compute offsets.
|
||||
//
|
||||
|
||||
// Grid-stride loop on 1D "tiles" of input indices.
|
||||
for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x)
|
||||
{
|
||||
if (tileIdx > 0)
|
||||
{
|
||||
// Sync for safe reuse of smem buffers.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Pre-fill the counts with 0
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
smemExpertCount[threadIdx.x] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// each thread has some number of "expanded indexes" assigned to it;
// for each of these, we keep the associated expert and offset within the expert in registers
|
||||
int32_t expertIndexes[MaxExpandedIdxPerThread];
|
||||
int32_t expertOffsets[MaxExpandedIdxPerThread];
|
||||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||||
|
||||
// Define a lambda to avoid code duplication in branches.
|
||||
auto loopBody = [&](int ii, int expandedIdx)
|
||||
{
|
||||
PackedScoreIdx<OutputT> scoreIdx = params.mPtrExpertIdx[expandedIdx];
|
||||
expertIndexes[ii] = scoreIdx.idx;
|
||||
// check whether this expert is local to our GPU at all and ignore if not
|
||||
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
|
||||
};
|
||||
|
||||
// For all tiles but the last, all indices are in bounds.
|
||||
if (tileIdx < numTiles - 1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
|
||||
{
|
||||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// For the last tile, we need to exit the loop when out of bounds.
|
||||
// In order to avoid a serialized LDG-ATOMS-LDG-ATOMS-... pattern, we skip multiple iterations at a
// time, and branch between a fast path without bound checks and a slow path with bound checks.
|
||||
int constexpr IterStride = 4;
|
||||
static_assert(MaxExpandedIdxPerThread % IterStride == 0);
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
|
||||
{
|
||||
// Whether it's safe to do multiple iterations without bound checks.
|
||||
bool const takeFastPath
|
||||
= tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsHist <= expandedIdxSize;
|
||||
if (takeFastPath)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||||
{
|
||||
int const ii = ii0 + jj;
|
||||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
bool doBreak = false;
|
||||
#pragma unroll
|
||||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||||
{
|
||||
int const ii = ii0 + jj;
|
||||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||||
if (expandedIdx >= expandedIdxSize)
|
||||
{
|
||||
doBreak = true;
|
||||
break;
|
||||
}
|
||||
loopBody(ii, expandedIdx);
|
||||
}
|
||||
if (doBreak)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make local histogram (token counts per expert) available to all threads in the block.
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// Each thread now represents one expert
|
||||
//
|
||||
|
||||
if (threadIdx.x < params.mNumExperts)
|
||||
{
|
||||
// Add the local bin count to the common bin count and get a per-CTA offset. We use the second
|
||||
// half of the histogram buffer for this histogram, because the first half already holds the
|
||||
// reduced histogram from the previous kernel.
|
||||
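// (mPtrExpertCounts is laid out as [2, mNumExperts]: the first half holds the per-expert totals
// written by the histogram kernel, while the second half accumulates the running per-tile
// offsets used here.)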
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
|
||||
int32_t const tileExpertOffset
|
||||
= atomicAdd(¶ms.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount);
|
||||
|
||||
// Make per-expert tile offsets available to all threads in the block.
|
||||
smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Add tile offset and element offset and write to global memory.
|
||||
auto storeLoopBody = [&](int ii, int expandedIdx)
|
||||
{
|
||||
int32_t expertIdx = expertIndexes[ii];
|
||||
// check whether this expert is local to our GPU at all
|
||||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
auto tokenIdx = expandedIdx / params.mTopK;
|
||||
auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
|
||||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||||
}
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
}
|
||||
};
|
||||
// Bound checks only in last tile.
|
||||
if (tileIdx < numTiles - 1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
|
||||
{
|
||||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||||
storeLoopBody(ii, expandedIdx);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
|
||||
{
|
||||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||||
if (expandedIdx >= expandedIdxSize)
|
||||
{
|
||||
break;
|
||||
}
|
||||
storeLoopBody(ii, expandedIdx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// Trigger secondary kernel.
|
||||
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
|
||||
// dependency sync.
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
}
|
||||
|
||||
} // namespace routing
|
||||
} // namespace moe::dev
|
||||
@ -31,33 +31,21 @@ namespace routing
|
||||
|
||||
namespace tg = batchedGemm::trtllm::gen;
|
||||
|
||||
template <typename TypeExpW>
|
||||
template <typename DataType>
|
||||
struct PackedScoreIdx
|
||||
{
|
||||
TypeExpW score;
|
||||
int16_t idx; // @TODO: Might use int8_t as the number of experts is 128
|
||||
DataType score;
|
||||
int16_t idx;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace routingDeepSeek
|
||||
struct DataBase
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Data
|
||||
{
|
||||
tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
|
||||
bool mUsePdl{false};
|
||||
|
||||
// note: at least one of the optional outputs below must be provided
|
||||
// note: if one of the indexes using "PermutedIdx" is provided,
|
||||
// then `mPtrExpertIdx` and `mPtrPermutedIdxSize` must be provided
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens, mTopK]
|
||||
int32_t* mPtrExpertIdx{nullptr};
|
||||
// optional: only used as an intermediate buffer when the number of tokens is large.
|
||||
// dim: [2*NumThreads] = [512]
|
||||
// dim: max([2*NumThreads] = [512], mNumExperts*2)
|
||||
int32_t* mPtrExpertCounts{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [1]
|
||||
@ -67,13 +55,25 @@ struct Data
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts]
|
||||
// Note: this array (mPtrPermutedIdxToTokenIdx) is uninitialized;
// entries that are not written by the kernel (e.g. padding slots) hold undefined values.
|
||||
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumLocalExperts * (2 ^ mLocalExpertsStrideLog2), mNumTokens]
|
||||
void* mPtrExpertWeightsFull{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens, mTopK]
|
||||
void* mPtrExpertWeights{nullptr};
|
||||
// optional: if `nullptr`, scores are used directly as input.
|
||||
// If it is given, it must represent a packed value s.t. the most significant
|
||||
// 16/32 bits represent the score without sigmoid activation and
|
||||
// the least significant 16 bits represent the index of the chosen expert (unsigned).
|
||||
// note: this is required if the number of tokens is large.
|
||||
// dim: [mNumTokens, mTopK]
|
||||
void* mPtrExpertIdx{nullptr};
|
||||
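// Example (illustrative, assuming a 16-bit score type such as bfloat16): for raw score bits s16
// and expert index 7, the packed 32-bit value described above would be (uint32_t(s16) << 16) | 7u.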
// optional: if `nullptr`, `mPtrExpertIdx` must be provided.
|
||||
// If it is given, it represents the scores without sigmoid activation for
|
||||
// each token and expert.
|
||||
// note: if it is provided, we always re-compute the top1 scores
|
||||
// dim: [mNumTokens, mNumExperts]
|
||||
void const* mPtrScores{nullptr};
|
||||
|
||||
//
|
||||
// Grouped Gemm Launch Config Buffers
|
||||
@ -81,110 +81,137 @@ struct Data
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
|
||||
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
|
||||
int32_t* mPtrNumNonExitingCtas{nullptr};
|
||||
// mPtrPermutedIdxSize is ptrTotalNumPaddedTokens
|
||||
bool mAllToAllRouteAct{false};
|
||||
|
||||
void const* mPtrRoutingWeights;
|
||||
void const* mPtrRoutingBias;
|
||||
void const* mPtrIn;
|
||||
float* mPtrScores;
|
||||
|
||||
//
|
||||
// Metadata
|
||||
//
|
||||
int32_t mNumTokens;
|
||||
int32_t mHiddenDim;
|
||||
int32_t mNumExperts;
|
||||
int32_t mNumExpertGroups;
|
||||
int32_t mNumLimitedGroups;
|
||||
int32_t mTopK;
|
||||
int32_t mPaddingLog2;
|
||||
|
||||
/// For expert parallelization
|
||||
int32_t mLocalExpertsStartIdx;
|
||||
int32_t mLocalExpertsStrideLog2;
|
||||
int32_t mNumLocalExperts;
|
||||
float mRouteScale;
|
||||
bool mUseRoutingSoftmax;
|
||||
|
||||
int32_t* mPtrNumTokensPerExpert{nullptr};
|
||||
int32_t* mPtrPermutedIdxToExpandedIdx{nullptr};
|
||||
};
|
||||
|
||||
template <typename TypeExpW_, bool UseGroups_, bool UsePdl_>
|
||||
struct KernelParams
|
||||
template <typename InputT_, typename OutputT_, bool UsePdl_>
|
||||
struct KernelParamsBase
|
||||
{
|
||||
using TypeExpW = TypeExpW_;
|
||||
static constexpr bool UseGroups = UseGroups_;
|
||||
using InputT = InputT_;
|
||||
using OutputT = OutputT_;
|
||||
static constexpr bool UsePdl = UsePdl_;
|
||||
|
||||
int32_t* mPtrExpertIdx;
|
||||
int32_t* mPtrExpertCounts;
|
||||
int32_t* mPtrPermutedIdxSize;
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx;
|
||||
int32_t* mPtrPermutedIdxToTokenIdx;
|
||||
int32_t* mPtrPermutedIdxToExpandedIdx;
|
||||
int32_t* mPtrNumTokensPerExpert;
|
||||
// Public pointer members
|
||||
int32_t* mPtrExpertCounts = nullptr;
|
||||
int32_t* mPtrPermutedIdxSize = nullptr;
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx = nullptr;
|
||||
int32_t* mPtrPermutedIdxToTokenIdx = nullptr;
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx = nullptr;
|
||||
int32_t* mPtrCtaIdxXyToMnLimit = nullptr;
|
||||
int32_t* mPtrNumNonExitingCtas = nullptr;
|
||||
OutputT* mPtrExpertWeights = nullptr;
|
||||
InputT const* mPtrScores = nullptr;
|
||||
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx;
|
||||
int32_t* mPtrCtaIdxXyToMnLimit;
|
||||
int32_t* mPtrNumNonExitingCtas;
|
||||
// Public scalar members
|
||||
int32_t mNumTokens = 0;
|
||||
int32_t mNumExperts = 0;
|
||||
|
||||
TypeExpW* mPtrExpertWeightsFull;
|
||||
TypeExpW* mPtrExpertWeights;
|
||||
TypeExpW const* mPtrRoutingWeights;
|
||||
TypeExpW const* mPtrRoutingBias;
|
||||
float const* mPtrScores;
|
||||
int32_t mPaddingLog2 = 0;
|
||||
int32_t mLocalExpertsStartIdx = 0;
|
||||
int32_t mLocalExpertsStrideLog2 = 0;
|
||||
int32_t mNumLocalExperts = 0;
|
||||
|
||||
int32_t mHiddenDim;
|
||||
int32_t mNumExperts;
|
||||
// Public initialization function - make it a template to accept different Data types
|
||||
template <typename DataType>
|
||||
void setBaseParams(DataType const& data)
|
||||
{
|
||||
mPtrExpertCounts = data.mPtrExpertCounts;
|
||||
mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
|
||||
mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
|
||||
mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
|
||||
mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
|
||||
mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
|
||||
mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
|
||||
mPtrExpertWeights = static_cast<OutputT*>(data.mPtrExpertWeights);
|
||||
mPtrScores = (InputT const*) data.mPtrScores;
|
||||
|
||||
mNumTokens = data.mNumTokens;
|
||||
mNumExperts = data.mNumExperts;
|
||||
|
||||
mPaddingLog2 = data.mPaddingLog2;
|
||||
mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
|
||||
mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
|
||||
mNumLocalExperts = data.mNumLocalExperts;
|
||||
}
|
||||
};
|
||||
|
||||
namespace routingDeepSeek
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct Data : public DataBase
|
||||
{
|
||||
tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
|
||||
|
||||
//
|
||||
// Grouped Gemm Launch Config Buffers
|
||||
//
|
||||
void const* mPtrRoutingBias;
|
||||
|
||||
int32_t mHiddenDim; // not used
|
||||
int32_t mNumExpertGroups;
|
||||
int32_t mNumExpertsPerGroup;
|
||||
int32_t mNumLimitedGroups;
|
||||
trtllm::dev::IntFastDiv mTopK;
|
||||
int32_t mPaddingLog2;
|
||||
int32_t mLocalExpertsStartIdx;
|
||||
int32_t mLocalExpertsStrideLog2;
|
||||
int32_t mNumLocalExperts;
|
||||
int32_t mNumTokens;
|
||||
|
||||
float mRouteScale;
|
||||
bool mAllToAllRouteAct;
|
||||
bool mUseRoutingSoftmax;
|
||||
};
|
||||
|
||||
template <typename InputT_, typename OutputT_, bool UseGroups_, bool UsePdl_>
|
||||
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, UsePdl_>
|
||||
{
|
||||
using InputT = InputT_;
|
||||
using OutputT = OutputT_;
|
||||
|
||||
static constexpr bool UseGroups = UseGroups_;
|
||||
|
||||
PackedScoreIdx<OutputT>* mPtrExpertIdx = nullptr;
|
||||
|
||||
// OutputT* mPtrExpertWeightsFull = nullptr;
|
||||
// Note: this variable (mPtrExpertWeightsFull) might need to be added back for the low-latency
// MoE kernels in tllm-gen in the future.
|
||||
|
||||
OutputT const* mPtrRoutingBias = nullptr;
|
||||
|
||||
int32_t mNumExpertGroups = 0;
|
||||
int32_t mNumExpertsPerGroup = 0;
|
||||
int32_t mNumLimitedGroups = 0;
|
||||
|
||||
trtllm::dev::IntFastDiv mTopK;
|
||||
float mRouteScale = 0.f;
|
||||
|
||||
static KernelParams setKernelParams(Data const& data)
|
||||
{
|
||||
KernelParams params;
|
||||
params.setBaseParams(data);
|
||||
|
||||
params.mPtrExpertIdx = data.mPtrExpertIdx;
|
||||
params.mPtrExpertCounts = data.mPtrExpertCounts;
|
||||
params.mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
|
||||
params.mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
|
||||
params.mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
|
||||
params.mPtrPermutedIdxToExpandedIdx = data.mPtrPermutedIdxToExpandedIdx;
|
||||
params.mPtrNumTokensPerExpert = data.mPtrNumTokensPerExpert;
|
||||
params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*) data.mPtrExpertIdx;
|
||||
|
||||
params.mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
|
||||
params.mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
|
||||
// params.mPtrExpertWeightsFull = static_cast<OutputT*>(data.mPtrExpertWeightsFull);
|
||||
params.mPtrRoutingBias = static_cast<OutputT const*>(data.mPtrRoutingBias);
|
||||
|
||||
params.mPtrExpertWeightsFull = (TypeExpW*) data.mPtrExpertWeightsFull;
|
||||
params.mPtrExpertWeights = (TypeExpW*) data.mPtrExpertWeights;
|
||||
params.mPtrRoutingWeights = (TypeExpW*) data.mPtrRoutingWeights;
|
||||
params.mPtrRoutingBias = (TypeExpW*) data.mPtrRoutingBias;
|
||||
params.mPtrScores = data.mPtrScores;
|
||||
|
||||
params.mHiddenDim = data.mHiddenDim;
|
||||
params.mNumExperts = data.mNumExperts;
|
||||
params.mNumExpertGroups = data.mNumExpertGroups;
|
||||
params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups;
|
||||
params.mNumLimitedGroups = data.mNumLimitedGroups;
|
||||
params.mTopK = trtllm::dev::IntFastDiv(data.mTopK);
|
||||
params.mPaddingLog2 = data.mPaddingLog2;
|
||||
params.mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
|
||||
params.mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
|
||||
params.mNumLocalExperts = data.mNumLocalExperts;
|
||||
params.mNumTokens = data.mNumTokens;
|
||||
params.mRouteScale = data.mRouteScale;
|
||||
params.mAllToAllRouteAct = data.mAllToAllRouteAct;
|
||||
|
||||
return params;
|
||||
}
|
||||
};
|
||||
|
||||
void run(Data const& data, void* stream);
|
||||
void run(Data& data, void* stream);
|
||||
|
||||
} // namespace routingDeepSeek
|
||||
|
||||
@ -195,102 +222,28 @@ namespace routingLlama4
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Data
|
||||
struct Data : public DataBase
|
||||
{
|
||||
tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16};
|
||||
bool mUsePdl{false};
|
||||
|
||||
// optional: if `nullptr`, `mPtrExpertIdx` must be provided.
|
||||
// If it is given, it represents the scores without sigmoid activation for
|
||||
// each token and expert.
|
||||
// note: if it is provided, we always re-compute the top1 scores
|
||||
// dim: [mNumTokens, mNumExperts]
|
||||
void const* mPtrScores{nullptr};
|
||||
// optional: if `nullptr`, scores are used directly as input.
|
||||
// If it is given, it must represent a packed value s.t. the most significant
|
||||
// 16/32 bits represent the score without sigmoid activation and
|
||||
// the least significant 16 bits represent the index of the chosen expert (unsigned).
|
||||
// note: this is required if the number of tokens is large.
|
||||
// dim: [mNumTokens, mTopK]
|
||||
void* mPtrExpertIdx{nullptr};
|
||||
|
||||
// note: at least one of the optional outputs below must be provided
|
||||
// optional: only used as an intermediate buffer when the number of tokens is large.
|
||||
// dim: [2, mNumExperts]
|
||||
int32_t* mPtrExpertCounts{nullptr};
|
||||
// dim: [1]
|
||||
int32_t* mPtrPermutedIdxSize{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens * mTopK]
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts]
|
||||
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens, mTopK]
|
||||
void* mPtrExpertWeights{nullptr};
|
||||
//
|
||||
// Grouped Gemm Launch Config Buffers
|
||||
//
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
|
||||
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
|
||||
int32_t* mPtrNumNonExitingCtas{nullptr};
|
||||
|
||||
int32_t mNumTokens;
|
||||
int32_t mNumExperts;
|
||||
int32_t mTopK;
|
||||
int32_t mPaddingLog2;
|
||||
int32_t mLocalExpertsStartIdx;
|
||||
int32_t mLocalExpertsStrideLog2;
|
||||
int32_t mNumLocalExperts;
|
||||
};
|
||||
|
||||
template <typename TypeExpW_, bool UsePdl_>
|
||||
struct KernelParams
|
||||
template <typename InputT_, typename OutputT_, bool UsePdl_>
|
||||
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, UsePdl_>
|
||||
{
|
||||
using TypeExpW = TypeExpW_;
|
||||
static constexpr bool UsePdl = UsePdl_;
|
||||
using InputT = InputT_;
|
||||
using OutputT = OutputT_;
|
||||
|
||||
PackedScoreIdx<TypeExpW>* mPtrExpertIdx;
|
||||
TypeExpW const* mPtrScores;
|
||||
int32_t* mPtrExpertCounts;
|
||||
int32_t* mPtrPermutedIdxSize;
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx;
|
||||
int32_t* mPtrPermutedIdxToTokenIdx;
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx;
|
||||
int32_t* mPtrCtaIdxXyToMnLimit;
|
||||
int32_t* mPtrNumNonExitingCtas;
|
||||
TypeExpW* mPtrExpertWeights;
|
||||
PackedScoreIdx<OutputT>* mPtrExpertIdx = nullptr;
|
||||
|
||||
int32_t mNumTokens;
|
||||
int32_t mNumExperts;
|
||||
int32_t mPaddingLog2;
|
||||
int32_t mLocalExpertsStartIdx;
|
||||
int32_t mLocalExpertsStrideLog2;
|
||||
int32_t mNumLocalExperts;
|
||||
int32_t mTopK;
|
||||
|
||||
static KernelParams setKernelParams(Data const& data)
|
||||
{
|
||||
KernelParams params;
|
||||
params.setBaseParams(data);
|
||||
|
||||
params.mPtrExpertIdx = (PackedScoreIdx<TypeExpW>*) data.mPtrExpertIdx;
|
||||
params.mPtrScores = (TypeExpW const*) data.mPtrScores;
|
||||
params.mPtrExpertCounts = data.mPtrExpertCounts;
|
||||
params.mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
|
||||
params.mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
|
||||
params.mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
|
||||
params.mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
|
||||
params.mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
|
||||
params.mPtrExpertWeights = (TypeExpW*) data.mPtrExpertWeights;
|
||||
|
||||
params.mNumTokens = data.mNumTokens;
|
||||
params.mNumExperts = data.mNumExperts;
|
||||
params.mPaddingLog2 = data.mPaddingLog2;
|
||||
params.mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
|
||||
params.mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
|
||||
params.mNumLocalExperts = data.mNumLocalExperts;
|
||||
|
||||
params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*) data.mPtrExpertIdx;
|
||||
params.mTopK = data.mTopK;
|
||||
return params;
|
||||
}
|
||||
};
|
||||
@ -306,105 +259,37 @@ namespace routingRenormalize
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Data
|
||||
struct Data : public DataBase
|
||||
{
|
||||
tg::Dtype mDtypeExpW{tg::Dtype::Fp32};
|
||||
tg::Dtype mDtypeElt{tg::Dtype::Bfloat16};
|
||||
bool mUsePdl{false};
|
||||
|
||||
bool mDoSoftmaxBeforeTopK{false};
|
||||
bool mNormTopkProb{true}; // Default value is true for Qwen3 model
|
||||
// optional: if `nullptr`, `mPtrExpertIdx` must be provided.
|
||||
// If it is given, it represents the scores without sigmoid activation for
|
||||
// each token and expert.
|
||||
// note: if it is provided, we always re-compute the top1 scores
|
||||
// dim: [mNumTokens, mNumExperts]
|
||||
void const* mPtrScores{nullptr};
|
||||
// optional: if `nullptr`, scores are used directly as input.
|
||||
// If it is given, it must represent a packed value s.t. the most significant
|
||||
// 16/32 bits represent the score without sigmoid activation and
|
||||
// the least significant 16 bits represent the index of the chosen expert (unsigned).
|
||||
// note: this is required if the number of tokens is large.
|
||||
// dim: [mNumTokens, mTopK]
|
||||
void* mPtrExpertIdx{nullptr};
|
||||
|
||||
// note: at least one of the optional outputs below must be provided
|
||||
// optional: only used as an intermediate buffer when the number of tokens is large.
|
||||
// dim: [2, mNumExperts]
|
||||
int32_t* mPtrExpertCounts{nullptr};
|
||||
// dim: [1]
|
||||
int32_t* mPtrPermutedIdxSize{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens * mTopK]
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens * mTopK + (mNumExperts << mPaddingLog2) - mNumExperts]
|
||||
int32_t* mPtrPermutedIdxToTokenIdx{nullptr};
|
||||
// optional: if `nullptr`, it is not filled
|
||||
// dim: [mNumTokens, mTopK]
|
||||
void* mPtrExpertWeights{nullptr};
|
||||
//
|
||||
// Grouped Gemm Launch Config Buffers
|
||||
//
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
|
||||
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
|
||||
int32_t* mPtrNumNonExitingCtas{nullptr};
|
||||
|
||||
int32_t mNumTokens;
|
||||
int32_t mNumExperts;
|
||||
int32_t mTopK;
|
||||
int32_t mPaddingLog2;
|
||||
int32_t mLocalExpertsStartIdx;
|
||||
int32_t mLocalExpertsStrideLog2;
|
||||
int32_t mNumLocalExperts;
|
||||
};
|
||||
|
||||
template <typename Type_, typename TypeExpW_, bool UsePdl_>
|
||||
struct KernelParams
|
||||
template <typename InputT_, typename OutputT_, bool DoSoftmaxBeforeTopK_, bool UsePdl_>
|
||||
struct KernelParams : public KernelParamsBase<InputT_, OutputT_, UsePdl_>
|
||||
{
|
||||
using Type = Type_;
|
||||
using TypeExpW = TypeExpW_;
|
||||
static constexpr bool UsePdl = UsePdl_;
|
||||
bool mNormTopkProb = true;
|
||||
PackedScoreIdx<TypeExpW>* mPtrExpertIdx;
|
||||
TypeExpW const* mPtrScores;
|
||||
int32_t* mPtrExpertCounts;
|
||||
int32_t* mPtrPermutedIdxSize;
|
||||
int32_t* mPtrExpandedIdxToPermutedIdx;
|
||||
int32_t* mPtrPermutedIdxToTokenIdx;
|
||||
int32_t* mPtrCtaIdxXyToBatchIdx;
|
||||
int32_t* mPtrCtaIdxXyToMnLimit;
|
||||
int32_t* mPtrNumNonExitingCtas;
|
||||
TypeExpW* mPtrExpertWeights;
|
||||
using InputT = InputT_;
|
||||
using OutputT = OutputT_;
|
||||
|
||||
int32_t mNumTokens;
|
||||
int32_t mNumExperts;
|
||||
int32_t mPaddingLog2;
|
||||
int32_t mLocalExpertsStartIdx;
|
||||
int32_t mLocalExpertsStrideLog2;
|
||||
int32_t mNumLocalExperts;
|
||||
static constexpr bool DoSoftmaxBeforeTopK = DoSoftmaxBeforeTopK_;
|
||||
|
||||
PackedScoreIdx<OutputT>* mPtrExpertIdx = nullptr;
|
||||
|
||||
int32_t mTopK = 0;
|
||||
|
||||
bool mNormTopkProb = true;
|
||||
|
||||
static KernelParams setKernelParams(Data const& data)
|
||||
{
|
||||
KernelParams params;
|
||||
params.setBaseParams(data);
|
||||
|
||||
params.mPtrExpertIdx = (PackedScoreIdx<OutputT>*) data.mPtrExpertIdx;
|
||||
params.mNormTopkProb = data.mNormTopkProb;
|
||||
params.mPtrExpertIdx = (PackedScoreIdx<TypeExpW>*) data.mPtrExpertIdx;
|
||||
params.mPtrScores = (TypeExpW const*) data.mPtrScores;
|
||||
params.mPtrExpertCounts = data.mPtrExpertCounts;
|
||||
params.mPtrPermutedIdxSize = data.mPtrPermutedIdxSize;
|
||||
params.mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx;
|
||||
params.mPtrPermutedIdxToTokenIdx = data.mPtrPermutedIdxToTokenIdx;
|
||||
params.mPtrCtaIdxXyToBatchIdx = data.mPtrCtaIdxXyToBatchIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit = data.mPtrCtaIdxXyToMnLimit;
|
||||
params.mPtrNumNonExitingCtas = data.mPtrNumNonExitingCtas;
|
||||
params.mPtrExpertWeights = (TypeExpW*) data.mPtrExpertWeights;
|
||||
|
||||
params.mNumTokens = data.mNumTokens;
|
||||
params.mNumExperts = data.mNumExperts;
|
||||
params.mPaddingLog2 = data.mPaddingLog2;
|
||||
params.mLocalExpertsStartIdx = data.mLocalExpertsStartIdx;
|
||||
params.mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2;
|
||||
params.mNumLocalExperts = data.mNumLocalExperts;
|
||||
|
||||
params.mTopK = data.mTopK;
|
||||
return params;
|
||||
}
|
||||
};
|
||||
|
||||
@ -162,7 +162,7 @@ struct Sort<4, RedType>
|
||||
|
||||
template <int K, typename Type>
|
||||
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
|
||||
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue)
|
||||
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
|
||||
{
|
||||
static_assert(K > 0, "Top K must have K > 0");
|
||||
static_assert(K < WarpSize, "Top K must have K < WarpSize");
|
||||
@ -170,7 +170,7 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
|
||||
RedType topK{value, idx};
|
||||
typename RedType::TypeCmp packedMax{};
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < K; ++kk)
|
||||
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
|
||||
{
|
||||
topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK;
|
||||
// get the next largest value
|
||||
@ -181,7 +181,7 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
|
||||
|
||||
template <int K, typename Type, int N>
|
||||
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
|
||||
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue)
|
||||
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
|
||||
{
|
||||
static_assert(K > 0, "Top K must have K > 0");
|
||||
static_assert(K < WarpSize, "Top K must have K < WarpSize");
|
||||
@ -199,7 +199,7 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const
|
||||
|
||||
typename RedType::TypeCmp packedMax{};
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < K; ++kk)
|
||||
for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
|
||||
{
|
||||
bool update = kk > 0 && packedMax == topK[0].compVal;
|
||||
#pragma unroll
|
||||
|
||||
@ -0,0 +1,501 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "RoutingKernel.cuh"
|
||||
|
||||
namespace moe::dev::routing
|
||||
{
|
||||
namespace routingLlama4
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static constexpr int NumThreads = 1024;
|
||||
static constexpr int NumWarps = NumThreads / WarpSize;
|
||||
static constexpr int MaxNumTopExperts = 1;
|
||||
static constexpr int MaxNumExperts = 128;
|
||||
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
|
||||
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
|
||||
static constexpr int WarpKernelSmemStride = 33;
|
||||
// with further optimization to `routingIndicesWarpKernel`, this limit may
|
||||
// increase. For now, it is a good cut-off point for when the block-wise
|
||||
// operations are more efficient end-to-end.
|
||||
static constexpr int WarpKernelMaxNumTokens = 4;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename DataType, int VecSize>
|
||||
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
|
||||
DataType (&warpMaxScore)[MaxNumTopExperts], int32_t (&warpMaxExpertIdx)[MaxNumTopExperts], int32_t const laneIdx,
|
||||
int32_t const numExperts, DataType const* ptrScores)
|
||||
{
|
||||
DataType minScore = DataType{-INFINITY};
|
||||
DataType maxScore = minScore;
|
||||
int32_t maxExpertIdx{-1};
|
||||
using DataTypeVec = std::conditional_t<sizeof(DataType) == 2, float2, float4>;
|
||||
|
||||
// Non-vectorized loading: directly access ptrScores with expertIdx
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
auto expertIdx = i * WarpSize + laneIdx;
|
||||
auto newScore = expertIdx < numExperts ? ptrScores[expertIdx] : minScore;
|
||||
// note: the comparison below uses `>`, so the first (lowest-index) local maximum wins per thread;
// ties across lanes are then resolved in `reduceTopK`
|
||||
if (newScore > maxScore)
|
||||
{
|
||||
maxScore = newScore;
|
||||
maxExpertIdx = expertIdx;
|
||||
}
|
||||
}
|
||||
|
||||
topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename KernelParams>
|
||||
__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params)
|
||||
{
|
||||
// types used in this kernel
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
using InputT = typename KernelParams::InputT;
|
||||
using TypePacked = PackedScoreIdx<OutputT>;
|
||||
// use the default cub warp-scan, with shfl
|
||||
using Scan = cub::WarpScan<int32_t>;
|
||||
__shared__ typename Scan::TempStorage tempStorage;
|
||||
|
||||
// each thread encodes 4 experts in one `int32_t`. The assumption is that
// we don't have more than 127 tokens; `WarpKernelMaxNumTokens` must be
// smaller than that anyway, since other approaches are more efficient
// once the token count gets close to 127.
|
||||
static constexpr int ExpertsPerThread = sizeof(int32_t);
|
||||
static_assert(WarpKernelMaxNumTokens <= 127);
|
||||
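// Illustrative example (assuming setBits/getBits address one byte per expert, which matches
// ExpertsPerThread == sizeof(int32_t) and counts < 128): setBits</*IsZero=*/true>(x, 1, 2)
// stores a count of 1 into byte 2 of x, and getBits(x, 2) reads it back.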
// this is a full table of which token is routed to which expert.
|
||||
// the assumption here is that there are no more than 128 experts.
|
||||
// we use a stride of 33 instead of 32 to avoid shared memory bank conflicts.
|
||||
__shared__ int32_t __attribute((aligned(128)))
|
||||
smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride];
|
||||
static_assert(WarpKernelSmemStride == WarpSize + 1);
|
||||
static_assert(MaxNumExperts / sizeof(int32_t) <= WarpSize);
|
||||
|
||||
// values needed for the top-1 reduction, if required
|
||||
InputT minScore = InputT{-INFINITY};
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||||
|
||||
#pragma unroll
|
||||
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
|
||||
{
|
||||
// reset full shared memory field to 0
|
||||
smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// then wait on primary grid
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
#endif
|
||||
|
||||
if (params.mPtrScores != nullptr)
|
||||
{
|
||||
// if we use `mPtrScores` as input, we need to perform the top-1 reduction:
// for each token, we load the scores and then use `reduceTopK` for this.
// each thread works on 4 experts, so a local reduction is done before the warp-wide reduction
|
||||
for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx)
|
||||
{
|
||||
auto scoreOffset = tokenIdx * params.mNumExperts;
|
||||
int32_t warpMaxExpertIdx[MaxNumTopExperts];
|
||||
InputT warpMaxScore[MaxNumTopExperts];
|
||||
|
||||
// Use routingTopKExperts function instead of inline logic
|
||||
routingTopKExperts<InputT, ExpertsPerThread>(
|
||||
warp, warpMaxScore, warpMaxExpertIdx, threadIdx.x, params.mNumExperts, params.mPtrScores + scoreOffset);
|
||||
|
||||
if (cute::elect_one_sync())
|
||||
{
|
||||
// one thread updates the count linking token to chosen expert
|
||||
auto expertTokenCount = 0;
|
||||
setBits</* IsZero= */ true>(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread);
|
||||
smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] = expertTokenCount;
|
||||
// we also compute the final score here and write it out if required
|
||||
auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
|
||||
if (params.mPtrExpertWeights != nullptr)
|
||||
{
|
||||
params.mPtrExpertWeights[tokenIdx] = finalScore;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// if we do not have `mPtrScores` as input, we expect that `mPtrExpertIdx`
// already contains the top-1 packed score and index.
// Each thread represents a token here, and we extract its packed score and index.
// The assumption is that #tokens is limited by the warp size.
|
||||
static_assert(WarpKernelMaxNumTokens <= WarpSize);
|
||||
TypePacked scoreIdx = threadIdx.x < params.mNumTokens ? params.mPtrExpertIdx[threadIdx.x] : TypePacked{};
|
||||
int32_t expertTokenCount = 0;
|
||||
setBits</* IsZero= */ true>(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread);
|
||||
if (threadIdx.x < params.mNumTokens)
|
||||
{
|
||||
smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount;
|
||||
}
|
||||
// we also compute the final score here and write it out if required
|
||||
auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})};
|
||||
if (params.mPtrExpertWeights != nullptr && threadIdx.x < params.mNumTokens)
|
||||
{
|
||||
params.mPtrExpertWeights[threadIdx.x] = finalScore;
|
||||
}
|
||||
}
|
||||
|
||||
// make the full table available to all threads
|
||||
__syncwarp();
|
||||
|
||||
// at this point, each thread keeps a count of its 4 assigned experts in
|
||||
// `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts
|
||||
// in `expertOffset`.
|
||||
int32_t expertCount = 0;
|
||||
int32_t expertOffset[WarpKernelMaxNumTokens + 1];
|
||||
#pragma unroll
|
||||
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx)
|
||||
{
|
||||
if (tokenIdx > params.mNumTokens)
|
||||
break;
|
||||
// simple reduction for `expertCount`, and scan for `expertOffset`
|
||||
auto expertTokenCount = tokenIdx < params.mNumTokens ? smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0;
|
||||
expertOffset[tokenIdx] = expertCount;
|
||||
expertCount += expertTokenCount;
|
||||
}
|
||||
|
||||
// at this point, we are ready for the scan across all experts to get the
|
||||
// thread-wise offsets across experts
|
||||
// first, we need to reduce across our 4 experts into `numCta`
|
||||
int32_t numCta = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||||
{
|
||||
auto count = getBits(expertCount, ii);
|
||||
numCta += divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||||
}
|
||||
// second, we perform the exclusive sum across the warp
|
||||
int32_t ctaOffset;
|
||||
int32_t numNonExitingCtas;
|
||||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||||
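// Worked example (hypothetical values): if the per-lane numCta values are {2, 0, 1, 3, 0, ...},
// the exclusive sum yields ctaOffset == {0, 2, 2, 3, 6, ...} and numNonExitingCtas is the total
// across the warp.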
|
||||
// finally, we perform a scan across our local experts, starting with the
|
||||
// warp-wide scan result (`ctaOffset`)
|
||||
auto ctaOffsetExp = ctaOffset;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||||
{
|
||||
auto count = getBits(expertCount, ii);
|
||||
auto finalNumCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||||
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
|
||||
// during the scan for expert offsets, we can already write out
|
||||
// both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit`
|
||||
for (int cta = 0; cta < finalNumCta; ++cta)
|
||||
{
|
||||
params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx;
|
||||
params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta]
|
||||
= min(mulLog2<int32_t>(ctaOffsetExp + cta + 1, params.mPaddingLog2),
|
||||
mulLog2<int32_t>(ctaOffsetExp, params.mPaddingLog2) + count);
|
||||
}
|
||||
ctaOffsetExp += finalNumCta;
|
||||
}
|
||||
|
||||
// at this point, we can write out the padded count from the warp aggregate
|
||||
if (cute::elect_one_sync())
|
||||
{
|
||||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
// we can trigger the next kernel at this point
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// at this point, all values for offsets are ready, except the final offsets
// within the padded index (`permutedIdx`).
// for this, we perform a scan similar to the one directly after the warp-scan:
// here, we keep the local offset for each of the thread's experts in an array
// of registers
|
||||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||||
int32_t finalExpertOffset[ExpertsPerThread];
|
||||
finalExpertOffset[0] = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < ExpertsPerThread; ++ii)
|
||||
{
|
||||
finalExpertOffset[ii]
|
||||
= finalExpertOffset[ii - 1] + divUpMulLog2<int32_t>(getBits(expertCount, ii - 1), params.mPaddingLog2);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
|
||||
{
|
||||
// at this point, we can calculate the final index:
|
||||
// we simply loop over all tokens, and all experts assigned to this thread.
|
||||
// For each pair, we determine whether that token was routed to that expert
|
||||
// based on whether the offset for that token changed.
|
||||
// we can then easily compute the final `expertIdx` and `permutedIdx` relative
|
||||
// to this token and expert, and write them out.
|
||||
if (tokenIdx >= params.mNumTokens)
|
||||
break;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||||
{
|
||||
// determine whether the offset for this expert and token changes
|
||||
auto localOffsetToken = getBits(expertOffset[tokenIdx], ii);
|
||||
auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken;
|
||||
// the expert index of this expert
|
||||
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
|
||||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||||
// the permuted index: we add the local offset relative to this expert and token
|
||||
// to the global offset from the scan for this expert
|
||||
auto permutedIdx = isLocalExpert ? finalExpertOffset[ii] + localOffsetToken : int32_t{-1};
|
||||
// write out `mPtrExpandedIdxToPermutedIdx` if required
|
||||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted)
|
||||
{
|
||||
params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
|
||||
}
|
||||
// write out `mPtrPermutedIdxToTokenIdx` if required
|
||||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted)
|
||||
{
|
||||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename KernelParams>
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
|
||||
routingIndicesClusterKernel(KernelParams params)
|
||||
{
|
||||
// number of tokens/expanded idx is bounded by total number of warps
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
using InputT = typename KernelParams::InputT;
|
||||
using TypePacked = PackedScoreIdx<OutputT>;
|
||||
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps];
|
||||
|
||||
uint32_t const clusterBlockRank = blockIdx.x;
|
||||
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||||
int32_t const laneIdx = cutlass::arch::LaneId();
|
||||
|
||||
// TODO(mjoux): expand to more tokens (possibly)
|
||||
auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
|
||||
auto scoreOffset = warpTokenIdx * params.mNumExperts;
|
||||
bool validToken = warpTokenIdx < params.mNumTokens;
|
||||
InputT minScore = InputT{-INFINITY};
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||||
|
||||
// then wait on primary grid
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
}
|
||||
|
||||
if (params.mPtrScores != nullptr)
|
||||
{
|
||||
// in this case, each warp represents a token
|
||||
// we then exchange all token max scores, s.t. afterwards, each thread
|
||||
// represents a token
|
||||
InputT warpMaxScore[MaxNumTopExperts];
|
||||
int32_t warpMaxExpertIdx[MaxNumTopExperts];
|
||||
|
||||
if (validToken)
|
||||
{
|
||||
routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
|
||||
warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, params.mPtrScores + scoreOffset);
|
||||
if (cute::elect_one_sync())
|
||||
{
|
||||
auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
|
||||
TypePacked packedScore{finalScore, static_cast<int16_t>(warpMaxExpertIdx[0])};
|
||||
smemPackedScoreIdx[warpIdx] = packedScore;
|
||||
}
|
||||
}
|
||||
// make packed scores available to all threads in cluster
|
||||
__cluster_barrier_arrive();
|
||||
__cluster_barrier_wait();
|
||||
}
|
||||
|
||||
routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts,
|
||||
/*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
|
||||
}
|
||||
#else
|
||||
__global__ void routingIndicesClusterKernel(KernelParams params)
|
||||
{
|
||||
assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// this kernel is needed in case we have scores as input for the histogram kernel
|
||||
template <typename KernelParams>
|
||||
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
|
||||
{
|
||||
using OutputT = typename KernelParams::OutputT;
|
||||
using InputT = typename KernelParams::InputT;
|
||||
using TypePacked = PackedScoreIdx<OutputT>;
|
||||
static constexpr int VecSize = MaxNumExperts / WarpSize;
|
||||
// with MaxNumExperts == 128 and WarpSize == 32, VecSize must be 4; we also assume #experts is a multiple of 4.
|
||||
static_assert(VecSize == 4);
|
||||
|
||||
int32_t const laneIdx = cutlass::arch::LaneId();
|
||||
int32_t const warpIdx = threadIdx.x / WarpSize;
|
||||
int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
|
||||
int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
|
||||
InputT minScore = InputT{-INFINITY};
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||||
|
||||
// initialize the mPtrExpertCounts
|
||||
int32_t expertCountsNum = 2 * params.mNumExperts;
|
||||
int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
|
||||
int32_t globalThreadStride = gridDim.x * NumThreads;
|
||||
initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// Wait on primary grid and trigger secondary kernel.
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaGridDependencySynchronize();
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
|
||||
// in this case, each warp represents a token, and we use a grid-stride loop
|
||||
// over all warps/tokens
|
||||
for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
|
||||
{
|
||||
auto scoreOffset = tokenIdx * params.mNumExperts;
|
||||
int32_t warpMaxExpertIdx[MaxNumTopExperts];
|
||||
InputT warpMaxScore[MaxNumTopExperts];
|
||||
|
||||
routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
|
||||
warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, params.mPtrScores + scoreOffset);
|
||||
|
||||
if (cute::elect_one_sync())
|
||||
{
|
||||
auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
|
||||
TypePacked packedScore{finalScore, static_cast<int16_t>(warpMaxExpertIdx[0])};
|
||||
params.mPtrExpertIdx[tokenIdx] = packedScore;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run(Data const& data, void* stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
|
||||
"Routing kernel requires at least one input parameter");
|
||||
TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
|
||||
&& data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
|
||||
"Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
|
||||
TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
|
||||
MaxNumTopExperts, data.mTopK);
|
||||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
|
||||
"Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
|
||||
static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
|
||||
static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
|
||||
TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
|
||||
|
||||
bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens)
|
||||
|| data.mNumTokens < WarpKernelMaxNumTokens;
|
||||
bool const useSingleCluster
|
||||
= data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
|
||||
if (!useSingleCluster)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
|
||||
}
|
||||
|
||||
if (useSingleWarp)
|
||||
{
|
||||
LAUNCH_ROUTING(data,
|
||||
/*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream);
|
||||
}
|
||||
else if (useSingleCluster)
|
||||
{
|
||||
LAUNCH_ROUTING(data,
|
||||
/*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK;
|
||||
|
||||
const uint32_t histogramEltsPerBlock = 8 * NumThreadsHist;
|
||||
const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;
|
||||
|
||||
// Limit grid size (all kernels use a grid-stride loop).
|
||||
const uint32_t maxNumBlocks = 1024;
|
||||
|
||||
int const numBlocksHistogram
|
||||
= std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
|
||||
int const numBlocksOffsets
|
||||
= std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
|
||||
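// Worked example (hypothetical sizes, assuming NumThreadsHist == 256): with 16384 tokens and
// topK == 1, expandedIdxSize == 16384 and histogramEltsPerBlock == 2048, so
// numBlocksHistogram == 8, well below maxNumBlocks == 1024.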
|
||||
if (data.mPtrScores != nullptr)
|
||||
{
|
||||
LAUNCH_ROUTING(data,
|
||||
/*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, NumThreadsHist,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Reset the global histograms.
|
||||
TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
|
||||
static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
|
||||
}
|
||||
LAUNCH_ROUTING(data,
|
||||
/*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream);
|
||||
LAUNCH_ROUTING(data,
|
||||
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
|
||||
/*smemSize=*/0, // No dynamic smem
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace routingLlama4
|
||||
} // namespace moe::dev::routing
|
||||
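The dispatch above picks one of three launch strategies based on token count: a single warp for very small batches, a single thread-block cluster for mid-sized batches, and a multi-kernel histogram/offsets path (grid-stride loops, capped at 1024 blocks) for anything larger. A minimal caller-side sketch of driving this entry point is shown below; the field names are the ones visible in this diff, but the allocation sizes, the float score type, and the helper name are assumptions, not part of the commit.

// Hypothetical caller-side sketch; buffer allocation and the exact Data layout
// are assumptions, only the field names come from the diff above.
#include <cuda_runtime.h>

void launchLlama4Routing(float const* dScores, int numTokens, int numExperts, int topK, cudaStream_t stream)
{
    moe::dev::routing::routingLlama4::Data data{};
    data.mNumTokens = numTokens;
    data.mNumExperts = numExperts; // must be a multiple of 4 and <= MaxNumExperts
    data.mTopK = topK;             // must be <= MaxNumTopExperts
    data.mPaddingLog2 = 3;         // must be < 8
    data.mPtrScores = dScores;     // raw routing logits; alternatively provide mPtrExpertIdx

    // For large token counts the histogram path additionally requires
    // mPtrExpertIdx and mPtrExpertCounts, plus the permuted-index and
    // grouped-GEMM launch-config buffers checked at the top of run().

    moe::dev::routing::routingLlama4::run(data, stream);
}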
@ -0,0 +1,294 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "RoutingKernel.cuh"

namespace moe::dev::routing
{
namespace routingRenormalize
{
////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int NumThreads = 1024;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int MaxNumTopExperts = 8;
static constexpr int MaxNumExperts = 128;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;

template <typename DataType, typename InputType, int VecSize, bool DoSoftmaxBeforeTopK>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
    DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts],
    int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK,
    InputType const* ptrScores, bool const normTopkProb)
{
    DataType minScore = DataType{-INFINITY};

    for (int i = 0; i < VecSize; i++)
    {
        auto expertIdx = i * WarpSize + laneIdx;
        auto newScore = expertIdx < numExperts ? static_cast<DataType>(ptrScores[expertIdx]) : minScore;
        score[i] = newScore;
        idx[i] = expertIdx;
    }
    if constexpr (DoSoftmaxBeforeTopK)
    {
        calcSoftmax(warp, score);
    }

    // Get the top-k scores and their corresponding expert indices
    topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK);

    // Normalize the scores
    if constexpr (DoSoftmaxBeforeTopK)
    {
        float sum = float{1.f};
        if (normTopkProb)
        {
            sum = static_cast<float>(laneIdx < topK ? warpTopKScore[laneIdx] : 0);
            sum = cg::reduce(warp, sum, cg::plus<float>());
        }
        if (laneIdx < topK)
        {
            warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
        }
    }
    else
    {
        auto softmaxScore = calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK);
        if (laneIdx < topK)
        {
            warpTopKScore[laneIdx] = softmaxScore;
        }
    }
}
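routingTopKExperts interleaves the softmax and the warp-level top-k reduction, which can obscure the intended math. For reference, a minimal host-side sketch of the same two normalization modes: softmax over all experts followed by an optional renormalization of the selected top-k probabilities, versus top-k selection followed by a softmax over only the selected logits. The helper name referenceTopKWeights and the use of std::partial_sort are illustrative assumptions, not part of this commit.

// Host-side reference sketch (not part of the diff) of the two normalization
// modes used by routingTopKExperts. Assumes scores holds one token's logits
// and topK <= scores.size().
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>

std::vector<float> referenceTopKWeights(
    std::vector<float> scores, int topK, bool doSoftmaxBeforeTopK, bool normTopkProb)
{
    auto softmax = [](std::vector<float>& v)
    {
        float const maxVal = *std::max_element(v.begin(), v.end());
        float sum = 0.f;
        for (auto& x : v)
        {
            x = std::exp(x - maxVal);
            sum += x;
        }
        for (auto& x : v)
        {
            x /= sum;
        }
    };

    if (doSoftmaxBeforeTopK)
    {
        softmax(scores); // softmax over all experts first
    }
    // Keep the topK largest entries.
    std::vector<float> topVals(scores);
    std::partial_sort(topVals.begin(), topVals.begin() + topK, topVals.end(), std::greater<float>());
    topVals.resize(topK);

    if (doSoftmaxBeforeTopK)
    {
        if (normTopkProb) // renormalize the selected probabilities to sum to 1
        {
            float const sum = std::accumulate(topVals.begin(), topVals.end(), 0.f);
            for (auto& x : topVals)
            {
                x /= sum;
            }
        }
    }
    else
    {
        softmax(topVals); // softmax only over the selected topK logits
    }
    return topVals;
}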
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
    routingIndicesClusterKernel(KernelParams params)
{
    // number of tokens/expanded idx is bounded by total number of warps
    using OutputT = typename KernelParams::OutputT;
    using InputT = typename KernelParams::InputT;

    using BaseType = std::conditional_t<KernelParams::DoSoftmaxBeforeTopK, float, InputT>;
    using TypePacked = PackedScoreIdx<BaseType>;

    static constexpr int VecSize = MaxNumExperts / WarpSize;
    // we assume that #experts is a multiple of 4, so VecSize must be 4.
    static_assert(VecSize == 4);

    __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts];

    uint32_t const clusterBlockRank = blockIdx.x;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
    int32_t const laneIdx = cutlass::arch::LaneId();

    auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
    auto scoreOffset = warpTokenIdx * params.mNumExperts;
    bool validToken = warpTokenIdx < params.mNumTokens;

    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

    // then wait on primary grid
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }

    if (params.mPtrScores != nullptr)
    {
        // in this case, each warp represents a token
        BaseType score[VecSize];
        int32_t idx[VecSize];

        BaseType warpTopKScore[MaxNumTopExperts];
        int32_t warpTopKExpertIdx[MaxNumTopExperts];

        BaseType minScore = BaseType{-INFINITY};
        if (validToken)
        {
            routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(warp, score, idx,
                warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK,
                params.mPtrScores + scoreOffset, params.mNormTopkProb);

            if (laneIdx < params.mTopK)
            {
                smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx]
                    = TypePacked{warpTopKScore[laneIdx], static_cast<int16_t>(warpTopKExpertIdx[laneIdx])};
            }
        } // end if (validToken)

        // make packed scores available to all threads in cluster
        __cluster_barrier_arrive();
        __cluster_barrier_wait();
    }

    routingPermutation<KernelParams, BaseType, NumThreads, NumWarps, MaxNumTopExperts,
        /*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
}
#else
__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */)
{
    assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
////////////////////////////////////////////////////////////////////////////////////////////////////

// this kernel is needed in case we have scores as input for the histogram kernel
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
{
    using OutputT = typename KernelParams::OutputT;
    using InputT = typename KernelParams::InputT;
    using BaseType = std::conditional_t<KernelParams::DoSoftmaxBeforeTopK, float, InputT>;

    static constexpr int VecSize = MaxNumExperts / WarpSize;
    // we assume that #experts is a multiple of 4, so VecSize must be 4.
    static_assert(VecSize == 4);

    int32_t const laneIdx = cutlass::arch::LaneId();
    int32_t const warpIdx = threadIdx.x / WarpSize;
    int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
    int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
    BaseType minScore = BaseType{-INFINITY};
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

    // initialize the mPtrExpertCounts
    int32_t expertCountsNum = 2 * params.mNumExperts;
    int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
    int32_t globalThreadStride = gridDim.x * NumThreads;
    initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

    // in this case, each warp represents a token, and we use a grid-stride loop
    // over all warps/tokens
    BaseType allScores[VecSize];
    int32_t allExpertIdx[VecSize];
    BaseType warpTopKScore[MaxNumTopExperts];
    int32_t warpTopKExpertIdx[MaxNumTopExperts];
    for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
    {
        auto scoreOffset = tokenIdx * params.mNumExperts;

        routingTopKExperts<BaseType, InputT, VecSize, KernelParams::DoSoftmaxBeforeTopK>(warp, allScores, allExpertIdx,
            warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK,
            params.mPtrScores + scoreOffset, params.mNormTopkProb);

        if (laneIdx < params.mTopK)
        {
            PackedScoreIdx<OutputT> packedScore{
                static_cast<OutputT>(warpTopKScore[laneIdx]), static_cast<int16_t>(warpTopKExpertIdx[laneIdx])};
            params.mPtrExpertIdx[tokenIdx * params.mTopK + laneIdx] = packedScore;
        }
    }
}
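The two architecture-guarded blocks in the kernel above are the programmatic dependent launch (PDL) handshake: wait until the previous grid's writes are visible, do the dependent work, and signal completion as early as possible so the next grid in the stream can start its prologue. A stripped-down sketch of the same idiom, with hypothetical buffer names and outside the TRT-LLM parameter machinery:

// Minimal PDL sketch (assumes sm_90+ and a launch configured with
// cudaLaunchAttributeProgrammaticStreamSerialization, as in the macros above).
__global__ void pdlStyleCopyKernel(int const* in, int* out, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Block until the preceding grid's memory operations are visible.
    cudaGridDependencySynchronize();
#endif
    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
    {
        out[idx] = in[idx];
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Tell the runtime the dependent grid may begin its prologue.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}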
////////////////////////////////////////////////////////////////////////////////////////////////////

void run(Data const& data, void* stream)
{
    TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
        "Routing kernel requires at least one input parameter");
    TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
            && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
        "Routing kernel expects permuted idx and grouped Gemm launch config buffers");
    TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
        MaxNumTopExperts, data.mTopK);
    TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
        "Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
    static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
    static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
    TLLM_CHECK_WITH_INFO(
        data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);

    bool const useSingleCluster
        = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);

    if (!useSingleCluster)
    {
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
    }

    if (useSingleCluster)
    {
        LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
            /*smemSize=*/0, // No dynamic smem
            stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
    }
    else
    {
        uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK;

        uint32_t const histogramEltsPerBlock = 8 * NumThreadsHist;
        uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;

        // Limit grid size (all kernels use a grid-stride loop).
        uint32_t const maxNumBlocks = 1024;

        int const numBlocksHistogram
            = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
        int const numBlocksOffsets
            = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

        if (data.mPtrScores != nullptr)
        {
            LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks,
                NumThreadsHist,
                /*smemSize=*/0, // No dynamic smem
                stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
        }
        else
        {
            // Reset the global histograms.
            TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
                static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
        }
        LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
            /*smemSize=*/0, // No dynamic smem
            stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
        LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
            /*smemSize=*/0, // No dynamic smem
            stream, data.mDoSoftmaxBeforeTopK, /*forceFloatInput=*/false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace routingRenormalize
} // namespace moe::dev::routing
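For a sense of the launch shapes in the histogram path of run() above: with, say, 20,000 tokens and topK = 8, expandedIdxSize is 160,000, and the block counts are ceil-divisions of that by the per-block element counts, clamped to 1024 because every kernel iterates with a grid-stride loop. A small standalone sketch; the values of NumThreadsHist and NumEltsPerOffsetTilePerThread are assumptions here, since they come from the shared routing header rather than this diff.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    // Assumed for illustration only; the real constants live in RoutingKernel.cuh.
    uint32_t const NumThreadsHist = 256;
    uint32_t const NumEltsPerOffsetTilePerThread = 8;

    uint32_t const numTokens = 20000, topK = 8;
    uint32_t const expandedIdxSize = numTokens * topK;

    uint32_t const histogramEltsPerBlock = 8 * NumThreadsHist;
    uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;
    uint32_t const maxNumBlocks = 1024;

    // Same ceil-division + clamp as in run().
    uint32_t const numBlocksHistogram
        = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
    uint32_t const numBlocksOffsets
        = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

    std::printf("histogram blocks: %u, offsets blocks: %u\n", numBlocksHistogram, numBlocksOffsets);
    return 0;
}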
@ -77,23 +77,17 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
    routingData.mPtrExpertCounts = expertCountHistogram;
    routingData.mPtrPermutedIdxSize = permutedIdxSize;
    routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
    routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx;
    routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
    routingData.mPtrNumTokensPerExpert = numTokensPerExpert;
    routingData.mPtrExpertWeights = expertWeights;

    routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;
    routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit;
    routingData.mPtrNumNonExitingCtas = numNonExitingCtas;
    routingData.mAllToAllRouteAct = false;

    // input:
    // routingData.mPtrRoutingWeights = args.mRoutingWeights; // routing weights (don't need if not using gemm)
    routingData.mPtrRoutingBias = routingBias;
    routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
    // routingData.mPtrIn = args.mInputActs;
    routingData.mNumTokens = numTokens;
    // routingData.mHiddenDim = args.mHiddenDim;
    routingData.mNumExperts = numExperts;
    routingData.mNumExpertGroups = nGroup;
    routingData.mNumLimitedGroups = topkGroup;
@ -122,9 +116,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
    routingData.mPtrExpertCounts = expertCountHistogram;
    routingData.mPtrPermutedIdxSize = permutedIdxSize;
    routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx;
    // routingData.mPtrPermutedIdxToExpandedIdx = permuted_idx_to_expanded_idx;
    routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx;
    // routingData.mPtrNumTokensPerExpert = num_tokens_per_expert;
    routingData.mPtrExpertWeights = expertWeights;

    routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx;

@ -37,11 +37,6 @@ protected:
    TensorPtr mPtrRoutingBiasHost;
    TensorPtr mPtrRoutingBiasDevice;

    // Add this variable because the definition of mPtrExpertIdx is "int32_t*" for Deepseek
    //@TODO: remove this variable after refactoring
    TensorPtr mPtrDeepseekExpertIdxHost;
    TensorPtr mPtrDeepseekExpertIdxDevice;

private:
    // private methods
    static inline float sigmoid_accurate(float x)
@ -56,7 +51,6 @@ private:
        // sigmoid / bias activated scores cannot be negative
        static constexpr float invalidScoreFloat = -1.F;
        const T invalidScore = T{invalidScoreFloat};
        int32_t* expIdxHostPtr = bufferCast<int32_t>(*mPtrDeepseekExpertIdxHost);

        float scoreSigmoid[param.numExperts];
        for (int it = 0; it < param.numTokens; ++it)
@ -129,7 +123,6 @@ private:
            // Convert back to io_dtype and store the topk expert results in hostData.mPtrExpertIdx
            for (int ie = 0; ie < param.topK; ++ie)
            {
                expIdxHostPtr[it * param.topK + ie] = static_cast<int32_t>(finalTopkExperts[ie].idx);
                if (param.getExpWeights)
                {
                    bufferCast<T>(*this->mPtrExpertWeightsHost)[it * param.topK + ie]
@ -152,11 +145,6 @@ private:
            = mBufferManager->pinned(ITensor::makeShape({param.numExperts}), TRTDataType<T>::value);
        this->mPtrRoutingBiasDevice
            = mBufferManager->gpu(ITensor::makeShape({param.numExperts}), TRTDataType<T>::value);

        this->mPtrDeepseekExpertIdxHost
            = mBufferManager->pinned(ITensor::makeShape({param.numTokens * param.topK}), TRTDataType<int32_t>::value);
        this->mPtrDeepseekExpertIdxDevice
            = mBufferManager->gpu(ITensor::makeShape({param.numTokens * param.topK}), TRTDataType<int32_t>::value);
    }

    void setupBuffers(RoutingKernelTestParam const& param) override
@ -177,8 +165,6 @@ private:
        routingData.mDtypeExpW = btg::Dtype::Bfloat16;
        routingData.mPtrScores = bufferCast<float>(*this->mPtrScoresDevice);
        routingData.mPtrRoutingBias = bufferCast<T>(*this->mPtrRoutingBiasDevice);
        //@todo: remove this line after refactoring
        routingData.mPtrExpertIdx = bufferCast<int32_t>(*this->mPtrDeepseekExpertIdxDevice);

        routingData.mNumExpertGroups = param.nGroup;
        routingData.mNumLimitedGroups = param.topkGroup;
@ -193,62 +179,13 @@ private:
        setParams(param, routingData);
        moe::dev::routing::routingDeepSeek::run(routingData, mStream->get());
    }

    void verifyExpertRoutingIndices(RoutingKernelTestParam const& param)
    {
        // for permuted index, there is non-determinism, thus we check set-equality
        // for this, we go over every expert and retrieve the tokens routed to it
        // we then get the associated indexes and check set equality
        auto const expandedIdxToPermutedIdxHost
            = mBufferManager->copyFrom(*this->mPtrExpandedIdxToPermutedIdxDevice, MemoryType::kCPU);
        auto const hostExpToPermTest = bufferCast<int32_t>(*expandedIdxToPermutedIdxHost);

        auto const permutedIdxToTokenIdxHost
            = mBufferManager->copyFrom(*this->mPtrPermutedIdxToTokenIdxDevice, MemoryType::kCPU);
        auto const hostPermToTokTest = bufferCast<int32_t>(*permutedIdxToTokenIdxHost);
        mStream->synchronize();

        int32_t* expIdxToPermHostptr = bufferCast<int32_t>(*this->mPtrExpandedIdxToPermutedIdxHost);
        int32_t* expIdxHostPtr = bufferCast<int32_t>(*this->mPtrDeepseekExpertIdxHost);

        for (int ie = 0; ie < param.numExperts; ++ie)
        {
            std::set<int32_t> permutedIdx, permutedIdxTest;
            std::set<int32_t> tokenIdx, tokenIdxTest;
            auto localExpertIdx = ie - param.localExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < param.numLocalExperts
                && (localExpertIdx & param.localExpertsStrideLog2) == 0;

            for (int it = 0; it < param.numTokens * param.topK; ++it)
            {
                if (expIdxHostPtr[it] == ie)
                {
                    int const permIdx = isLocalExpert ? expIdxToPermHostptr[it] : int32_t{-1};
                    permutedIdx.insert(permIdx);
                    if (isLocalExpert)
                    {
                        tokenIdx.insert(it / param.topK);
                    }

                    int const permIdxTest = hostExpToPermTest[it];
                    permutedIdxTest.insert(permIdxTest);
                    if (isLocalExpert)
                    {
                        tokenIdxTest.insert(hostPermToTokTest[permIdxTest]);
                    }
                }
            }
            EXPECT_EQ(checkSetEqual(ie, permutedIdx, permutedIdxTest, "permuted idx"), true);
            EXPECT_EQ(checkSetEqual(ie, tokenIdx, tokenIdxTest, "token idx"), true);
        }
    }
};

TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types);

TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
{
    RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
    RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10
        /*numExperts=*/128, /*topK=*/8,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
@ -279,4 +216,48 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
    this->runTest(param);
};

// TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
// {
//     RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300,
//         /*numExperts=*/128, /*topK=*/8,
//         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
//         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
//         /*usePdl=*/true, /*getExpWeights=*/true,
//         /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
//     this->runTest(param);
// };

TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
{
    RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
        /*numExperts=*/128, /*topK=*/2,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
    this->runTest(param);
};

TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop2)
{
    RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/100,
        /*numExperts=*/128, /*topK=*/2,
        /*expertParallelization=*/2, /*expertParallelizationId=*/1,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
    this->runTest(param);
};

TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2)
{
    RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030,
        /*numExperts=*/128, /*topK=*/2,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
    this->runTest(param);
};

} // namespace

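verifyExpertRoutingIndices above compares reference and kernel outputs as sets because the permuted-index assignment within an expert is non-deterministic. The checkSetEqual helper it calls is defined elsewhere in the shared test fixture; a minimal sketch of what such a helper needs to do (name and signature inferred from the call sites, not taken from this diff) is:

// Hypothetical sketch of a set-equality helper matching the call sites above;
// the real helper lives in the common test fixture, not in this diff.
#include <cstdint>
#include <cstdio>
#include <set>

static bool checkSetEqual(
    int expertIdx, std::set<int32_t> const& expected, std::set<int32_t> const& actual, char const* label)
{
    if (expected == actual)
    {
        return true;
    }
    std::printf("Mismatch for expert %d (%s): expected %zu entries, got %zu\n", expertIdx, label, expected.size(),
        actual.size());
    return false;
}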
@ -51,33 +51,10 @@ private:
            int16_t newIdx = static_cast<int16_t>(ie);
            score = static_cast<float>(bufferCast<T>(*this->mPtrScoresHost)[it * param.numExperts + ie]);

            if (param.doSoftmaxBeforeTopK && score > maxScore)
            {
                maxScore = score;
            }

            PackedFloat si{static_cast<float>(score), newIdx};
            expWeightsIdx[ie] = si;
        }

        if (param.doSoftmaxBeforeTopK)
        {
            // Run softmax before topk
            for (int ie = 0; ie < param.numExperts; ++ie)
            {
                expWeightsIdx[ie].score
                    = static_cast<float>(std::exp(static_cast<float>(expWeightsIdx[ie].score) - maxScore));
                sum += expWeightsIdx[ie].score;
            }

            for (int ie = 0; ie < param.numExperts; ++ie)
            {
                float score = static_cast<float>(expWeightsIdx[ie].score);
                score /= sum;
                expWeightsIdx[ie].score = static_cast<float>(score);
            }
        }

        // Calculate the top-k scores and indices
        std::partial_sort_copy(expWeightsIdx, expWeightsIdx + param.numExperts, expIdx, expIdx + param.topK,
            [](PackedFloat const& a, PackedFloat const& b)
@ -148,7 +125,7 @@ TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization)

TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization)
{
    RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/100,
    RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10,
        /*numExperts=*/128, /*topK=*/1,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,

@ -165,7 +165,6 @@ private:
        routingData.mNormTopkProb = param.routingMethod == RoutingMethodType::RenormalizeNaive;

        routingData.mPtrScores = bufferCast<T>(*this->mPtrScoresDevice);
        routingData.mPtrExpertIdx = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*this->mPtrExpertIdxDevice));
    }

    void callTestedFunction(
@ -214,7 +213,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal

TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
{
    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/300,
    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000,
        /*numExperts=*/128, /*topK=*/8,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
@ -223,4 +222,47 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
    this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4)
{
    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10,
        /*numExperts=*/128, /*topK=*/4,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
    this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop4)
{
    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100,
        /*numExperts=*/128, /*topK=*/4,
        /*expertParallelization=*/2, /*expertParallelizationId=*/1,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
    this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormalizeNaiveTop4)
{
    RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/10,
        /*numExperts=*/128, /*topK=*/4,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
    this->runTest(param);
};

TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4)
{
    RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000,
        /*numExperts=*/128, /*topK=*/4,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
    this->runTest(param);
};
} // end namespace

@ -336,6 +336,8 @@ void RoutingKernelTest<T>::runTest(RoutingKernelTestParam const& param)
    {
        GTEST_SKIP() << "Skip test due to compute capability requirement.";
    }
    // Set seed to time-based seed
    resetToTimeBasedSeed();

    // Allocate buffers
    allocateBuffers(param);

@ -17,6 +17,7 @@

#include <gtest/gtest.h>

#include <chrono>
#include <memory> //@todo check the usage of this
#include <random> //@todo check the usage of this

@ -320,6 +321,27 @@ template <typename T>
class RoutingKernelTest : public testing::Test
{
public:
    // Add a method to generate time-based seed
    static uint32_t generateTimeBasedSeed()
    {
        std::random_device rd;
        uint32_t seed = rd();
        TLLM_LOG_DEBUG("Random device seed: %u", seed);
        return seed;
    }

    // Method to set seed after construction
    void setSeed(uint32_t seed)
    {
        mSeed = seed;
    }

    // Method to reset to time-based seed
    void resetToTimeBasedSeed()
    {
        mSeed = generateTimeBasedSeed();
    }

    void SetUp() override;
    void TearDown() override;

@ -372,7 +394,7 @@ protected:
        routingData.mPtrExpandedIdxToPermutedIdx = bufferCast<int32_t>(*mPtrExpandedIdxToPermutedIdxDevice);
        routingData.mPtrPermutedIdxToTokenIdx = bufferCast<int32_t>(*mPtrPermutedIdxToTokenIdxDevice);
        routingData.mPtrExpertWeights = bufferCast<T>(*mPtrExpertWeightsDevice);
        // routingData.mPtrExpertIdx = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*mPtrExpertIdxDevice));
        routingData.mPtrExpertIdx = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*mPtrExpertIdxDevice));

        // Set grouped gemm launch config buffers
        routingData.mPtrCtaIdxXyToBatchIdx = bufferCast<int32_t>(*mPtrCtaIdxXyToBatchIdxDevice);