/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RoutingKernel.cuh"

namespace moe::dev::routing
{
namespace routingLlama4
{

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int NumThreads = 1024;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int MaxNumTopExperts = 1;
static constexpr int MaxNumExperts = 128;
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
static constexpr int WarpKernelSmemStride = 33;
// with further optimization to `routingIndicesWarpKernel`, this limit may
// increase. For now, it is a good cut-off point for when the block-wise
// operations are more efficient end-to-end.
static constexpr int WarpKernelMaxNumTokens = 4;
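// Derived values, for reference: with WarpSize == 32, NumWarps == 32. The single-cluster token
// limits scale with NumBlocksPerCluster (from RoutingKernel.cuh): the cluster kernel reduces
// scores with one warp per token, hence the separate MaxNumTokensSingleClusterScores limit.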

////////////////////////////////////////////////////////////////////////////////////////////////////

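// Warp-wide top-1 reduction helper: each lane scans experts laneIdx, laneIdx + WarpSize, ...,
// laneIdx + (VecSize - 1) * WarpSize (out-of-range slots read as -inf), keeps its local best,
// and `topk::reduceTopK` then selects the single best (score, expert index) pair across the warp.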
template <typename DataType, int VecSize>
__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile<WarpSize> const& warp,
    DataType (&warpMaxScore)[MaxNumTopExperts], int32_t (&warpMaxExpertIdx)[MaxNumTopExperts], int32_t const laneIdx,
    int32_t const numExperts, DataType const* ptrScores)
{
    DataType minScore = DataType{-INFINITY};
    DataType maxScore = minScore;
    int32_t maxExpertIdx{-1};
    using DataTypeVec = std::conditional_t<sizeof(DataType) == 2, float2, float4>;

    // Non-vectorized loading: directly access ptrScores with expertIdx
    for (int i = 0; i < VecSize; ++i)
    {
        auto expertIdx = i * WarpSize + laneIdx;
        auto newScore = expertIdx < numExperts ? ptrScores[expertIdx] : minScore;
        // note: with strict `>`, the lowest expert index wins ties in this local scan
        // (see `reduceTopK` for the warp-wide tie-breaking)
        if (newScore > maxScore)
        {
            maxScore = newScore;
            maxExpertIdx = expertIdx;
        }
    }

    topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

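// Single-warp routing path for very small batches (up to WarpKernelMaxNumTokens tokens): one warp
// computes the top-1 expert per token, builds a byte-packed token x expert table in shared memory,
// and then derives the padded offsets, grouped-GEMM launch info and permuted indices without any
// inter-block communication.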
template <typename KernelParams>
__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params)
{
    // types used in this kernel
    using OutputT = typename KernelParams::OutputT;
    using InputT = typename KernelParams::InputT;
    using TypePacked = PackedScoreIdx<OutputT>;
    // use the default cub warp-scan, with shfl
    using Scan = cub::WarpScan<int32_t>;
    __shared__ typename Scan::TempStorage tempStorage;

    // each thread encodes 4 experts in one `int32_t`. The assumption is that
    // we don't have more than 127 tokens; `WarpKernelMaxNumTokens` is kept far
    // smaller than that because other approaches become more efficient well
    // before 127 tokens.
    static constexpr int ExpertsPerThread = sizeof(int32_t);
    static_assert(WarpKernelMaxNumTokens <= 127);
    // this is a full table of which token is routed to which expert.
    // the assumption here is that there are no more than 128 experts.
    // we use a stride of 33 instead of 32 to avoid shared memory bank conflicts.
    __shared__ int32_t __attribute((aligned(128)))
        smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride];
    static_assert(WarpKernelSmemStride == WarpSize + 1);
    static_assert(MaxNumExperts / sizeof(int32_t) <= WarpSize);
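    // Packed layout: expert `e` lives in word `e / ExpertsPerThread` at byte `e % ExpertsPerThread`
    // of a token's row, so 128 experts occupy 32 words per token; the row stride of 33 adds one
    // padding word so that consecutive rows start in different shared memory banks.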

    // values needed for the top-1 reduction, if required
    InputT minScore = InputT{-INFINITY};
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

#pragma unroll
    for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
    {
        // reset full shared memory field to 0
        smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0;
    }
    __syncwarp();

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // then wait on primary grid
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }
#endif

    if (params.mPtrScores != nullptr)
    {
        // if we use `mPtrScores` as input, we need to perform the top-1 reduction:
        // for each token, we load the scores and then use `reduceTopK` for this.
        // each thread works on 4 experts, so a local reduction is done before the
        // warp-wide reduction.
        for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx)
        {
            auto scoreOffset = tokenIdx * params.mNumExperts;
            int32_t warpMaxExpertIdx[MaxNumTopExperts];
            InputT warpMaxScore[MaxNumTopExperts];

            // Use routingTopKExperts function instead of inline logic
            routingTopKExperts<InputT, ExpertsPerThread>(
                warp, warpMaxScore, warpMaxExpertIdx, threadIdx.x, params.mNumExperts, params.mPtrScores + scoreOffset);

            if (cute::elect_one_sync())
            {
                // one thread updates the count linking token to chosen expert
                auto expertTokenCount = 0;
                setBits</* IsZero= */ true>(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread);
                smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] = expertTokenCount;
                // we also compute the final score here and write it out if required
                auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
                if (params.mPtrExpertWeights != nullptr)
                {
                    params.mPtrExpertWeights[tokenIdx] = finalScore;
                }
            }
        }
    }
    else
    {
        // if we do not have `mPtrScores` as input, we expect that `mPtrExpertIdx`
        // already contains the top-1 packed score and index.
        // Each thread represents a token here, and we extract the relevant score.
        // The assumption is that the #tokens is limited by warp-size.
        static_assert(WarpKernelMaxNumTokens <= WarpSize);
        TypePacked scoreIdx = threadIdx.x < params.mNumTokens ? params.mPtrExpertIdx[threadIdx.x] : TypePacked{};
        int32_t expertTokenCount = 0;
        setBits</* IsZero= */ true>(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread);
        if (threadIdx.x < params.mNumTokens)
        {
            smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount;
        }
        // we also compute the final score here and write it out if required
        auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})};
        if (params.mPtrExpertWeights != nullptr && threadIdx.x < params.mNumTokens)
        {
            params.mPtrExpertWeights[threadIdx.x] = finalScore;
        }
    }

    // make the full table available to all threads
    __syncwarp();

    // at this point, each thread keeps a count of its 4 assigned experts in
    // `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts
    // in `expertOffset`.
    int32_t expertCount = 0;
    int32_t expertOffset[WarpKernelMaxNumTokens + 1];
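    // Note: the per-expert counts and offsets are accumulated byte-wise inside packed int32_t
    // values; this is carry-free because each 8-bit field holds at most WarpKernelMaxNumTokens
    // (<= 127) and therefore never overflows into its neighbor.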
#pragma unroll
    for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx)
    {
        if (tokenIdx > params.mNumTokens)
            break;
        // simple reduction for `expertCount`, and scan for `expertOffset`
        auto expertTokenCount = tokenIdx < params.mNumTokens ? smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0;
        expertOffset[tokenIdx] = expertCount;
        expertCount += expertTokenCount;
    }

    // at this point, we are ready for the scan across all experts to get the
    // thread-wise offsets across experts
    // first, we need to reduce across our 4 experts into `numCta`
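    // `divUpLog2(count, mPaddingLog2)` is ceil(count / 2^mPaddingLog2), i.e. the number of
    // grouped-GEMM CTAs (tiles of 2^mPaddingLog2 rows) needed for this expert; for example,
    // with mPaddingLog2 == 3 (tiles of 8 rows), a count of 3 tokens needs 1 CTA.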
    int32_t numCta = 0;
#pragma unroll
    for (int ii = 0; ii < ExpertsPerThread; ++ii)
    {
        auto count = getBits(expertCount, ii);
        numCta += divUpLog2<int32_t>(count, params.mPaddingLog2);
    }
    // second, we perform the exclusive sum across the warp
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    // finally, we perform a scan across our local experts, starting with the
    // warp-wide scan result (`ctaOffset`)
    auto ctaOffsetExp = ctaOffset;
#pragma unroll
    for (int ii = 0; ii < ExpertsPerThread; ++ii)
    {
        auto count = getBits(expertCount, ii);
        auto finalNumCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
        auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
        // during the scan for expert offsets, we can already write out
        // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit`
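        // the mn-limit of a CTA is the smaller of the end of its padded tile,
        // (ctaOffsetExp + cta + 1) * 2^mPaddingLog2, and the expert's last real row,
        // ctaOffsetExp * 2^mPaddingLog2 + count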
        for (int cta = 0; cta < finalNumCta; ++cta)
        {
            params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx;
            params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta]
                = min(mulLog2<int32_t>(ctaOffsetExp + cta + 1, params.mPaddingLog2),
                    mulLog2<int32_t>(ctaOffsetExp, params.mPaddingLog2) + count);
        }
        ctaOffsetExp += finalNumCta;
    }

    // at this point, we can write out padded count from the warp-aggregate
    if (cute::elect_one_sync())
    {
        const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
    // we can trigger the next kernel at this point
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif
#endif

    // at this point, all values for offsets are ready, except the final offsets
    // within the padded index (`permutedIdx`)
    // for this, we perform a scan similar to the one directly after the warp-scan:
    // here, we keep the local offset for each of the thread's experts in a field
    // of registers
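    // `finalExpertOffset[ii]` is the first permuted (padded) row assigned to this thread's
    // ii-th expert: it starts at the thread's CTA offset converted to rows and advances by
    // each preceding expert's padded token count.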
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
    int32_t finalExpertOffset[ExpertsPerThread];
    finalExpertOffset[0] = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
#pragma unroll
    for (int ii = 1; ii < ExpertsPerThread; ++ii)
    {
        finalExpertOffset[ii]
            = finalExpertOffset[ii - 1] + divUpMulLog2<int32_t>(getBits(expertCount, ii - 1), params.mPaddingLog2);
    }

#pragma unroll
    for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
    {
        // at this point, we can calculate the final index:
        // we simply loop over all tokens, and all experts assigned to this thread.
        // For each pair, we determine whether that token was routed to that expert
        // based on whether the offset for that token changed.
        // we can then easily compute the final `expertIdx` and `permutedIdx` relative
        // to this token and expert, and write them out.
        if (tokenIdx >= params.mNumTokens)
            break;

#pragma unroll
        for (int ii = 0; ii < ExpertsPerThread; ++ii)
        {
            // determine whether the offset for this expert and token changes
            auto localOffsetToken = getBits(expertOffset[tokenIdx], ii);
            auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken;
            // the expert index of this expert
            auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
            auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            // the permuted index: we add the local offset relative to this expert and token
            // to the global offset from the scan for this expert
            auto permutedIdx = isLocalExpert ? finalExpertOffset[ii] + localOffsetToken : int32_t{-1};
            // write out `mPtrExpandedIdxToPermutedIdx` if required
            if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted)
            {
                params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
            }
            // write out `mPtrPermutedIdxToTokenIdx` if required
            if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted)
            {
                params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

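// Single-cluster routing path: each warp in the thread-block cluster computes the top-1 expert
// for one token (when scores are given) and publishes the packed (score, index) pair to shared
// memory; the common `routingPermutation` path then derives the remaining routing outputs for
// the whole cluster.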
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
    routingIndicesClusterKernel(KernelParams params)
{
    // number of tokens/expanded idx is bounded by total number of warps
    using OutputT = typename KernelParams::OutputT;
    using InputT = typename KernelParams::InputT;
    using TypePacked = PackedScoreIdx<OutputT>;
    __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps];

    uint32_t const clusterBlockRank = blockIdx.x;
    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
    int32_t const laneIdx = cutlass::arch::LaneId();

    // TODO(mjoux): expand to more tokens (possibly)
    auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
    auto scoreOffset = warpTokenIdx * params.mNumExperts;
    bool validToken = warpTokenIdx < params.mNumTokens;
    InputT minScore = InputT{-INFINITY};

    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

    // then wait on primary grid
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }

    if (params.mPtrScores != nullptr)
    {
        // in this case, each warp represents a token
        // we then exchange all token max scores, s.t. afterwards, each thread
        // represents a token
        InputT warpMaxScore[MaxNumTopExperts];
        int32_t warpMaxExpertIdx[MaxNumTopExperts];

        if (validToken)
        {
            routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
                warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, params.mPtrScores + scoreOffset);
            if (cute::elect_one_sync())
            {
                auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
                TypePacked packedScore{finalScore, static_cast<int16_t>(warpMaxExpertIdx[0])};
                smemPackedScoreIdx[warpIdx] = packedScore;
            }
        }
        // make packed scores available to all threads in cluster
        __cluster_barrier_arrive();
        __cluster_barrier_wait();
    }

    routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts,
        /*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx, clusterBlockRank);
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params)
{
    assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// this kernel is needed in case we have scores as input for the histogram kernel
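// Grid-stride kernel: one warp per token computes the top-1 expert and sigmoid score and writes
// the packed (score, index) pair to `mPtrExpertIdx`, which the subsequent histogram/offsets
// kernels consume; it also zero-initializes `mPtrExpertCounts` along the way.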
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
{
    using OutputT = typename KernelParams::OutputT;
    using InputT = typename KernelParams::InputT;
    using TypePacked = PackedScoreIdx<OutputT>;
    static constexpr int VecSize = MaxNumExperts / WarpSize;
    // we assume that #experts is a multiple of 4, so VecSize must be 4.
    static_assert(VecSize == 4);

    int32_t const laneIdx = cutlass::arch::LaneId();
    int32_t const warpIdx = threadIdx.x / WarpSize;
    int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
    int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
    InputT minScore = InputT{-INFINITY};
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

    // initialize the mPtrExpertCounts
    int32_t expertCountsNum = 2 * params.mNumExperts;
    int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
    int32_t globalThreadStride = gridDim.x * NumThreads;
    initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait on primary grid and trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif

    // in this case, each warp represents a token, and we use a grid-stride loop
    // over all warps/tokens
    for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
    {
        auto scoreOffset = tokenIdx * params.mNumExperts;
        int32_t warpMaxExpertIdx[MaxNumTopExperts];
        InputT warpMaxScore[MaxNumTopExperts];

        routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
            warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts, params.mPtrScores + scoreOffset);

        if (cute::elect_one_sync())
        {
            auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
            TypePacked packedScore{finalScore, static_cast<int16_t>(warpMaxExpertIdx[0])};
            params.mPtrExpertIdx[tokenIdx] = packedScore;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void run(Data const& data, void* stream)
{
    TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
        "Routing kernel requires at least one input parameter");
    TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
            && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
        "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
    TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d",
        MaxNumTopExperts, data.mTopK);
    TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
        "Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
    static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
    static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
    TLLM_CHECK_WITH_INFO(
        data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);

    bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens)
        || data.mNumTokens < WarpKernelMaxNumTokens;
    bool const useSingleCluster
        = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
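    // Dispatch tiers: a single warp for tiny batches (up to WarpKernelMaxNumTokens tokens), a
    // single thread-block cluster for batches that fit one token per warp (scores input) or one
    // token per thread, and otherwise the multi-kernel histogram + offsets path below.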
    if (!useSingleCluster)
    {
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
    }

    if (useSingleWarp)
    {
        LAUNCH_ROUTING(data,
            /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize,
            /*smemSize=*/0, // No dynamic smem
            stream);
    }
    else if (useSingleCluster)
    {
        LAUNCH_ROUTING(data,
            /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
            /*smemSize=*/0, // No dynamic smem
            stream);
    }
    else
    {
        const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK;

        const uint32_t histogramEltsPerBlock = 8 * NumThreadsHist;
        const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;

        // Limit grid size (all kernels use a grid-stride loop).
        const uint32_t maxNumBlocks = 1024;

        int const numBlocksHistogram
            = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
        int const numBlocksOffsets
            = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
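        // Grid sizes are ceil-divisions of the expanded index count by the per-block workload
        // (8 elements per thread for the histogram, NumEltsPerOffsetTilePerThread per thread for
        // the offsets), capped at maxNumBlocks; smaller grids simply run more grid-stride steps.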

        if (data.mPtrScores != nullptr)
        {
            LAUNCH_ROUTING(data,
                /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, NumThreadsHist,
                /*smemSize=*/0, // No dynamic smem
                stream);
        }
        else
        {
            // Reset the global histograms.
            TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
                static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
        }
        LAUNCH_ROUTING(data,
            /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
            /*smemSize=*/0, // No dynamic smem
            stream);
        LAUNCH_ROUTING(data,
            /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
            /*smemSize=*/0, // No dynamic smem
            stream);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace routingLlama4
} // namespace moe::dev::routing