/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "DevKernel.h"
#include "RoutingKernel.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>

#include <cute/arch/cluster_sm90.hpp>
#include <cutlass/arch/arch.h>

#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace moe::dev
{

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace routing
{

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace tg = batchedGemm::trtllm::gen;
namespace cg = cooperative_groups;

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int NumThreads = 256;
static constexpr int NumBlocksPerCluster = 8;
static constexpr int NumThreadsGemm = 128;
static constexpr int WarpSize = 32;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int NumTopGroups = 4;
static constexpr int NumExpertsPerGroup = 32;
static constexpr int NumTopGroupScores = 2;
static constexpr int NumTopExperts = 8;

// Performance tuning knob.
static constexpr int NumEltsPerOffsetTilePerThread = 8;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL))
#define TLLM_GEN_ENABLE_FAST_REDUX
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TypeExpW_>
struct TopKRedType
{
    using TypeExpW = TypeExpW_;
    static_assert(std::is_same_v<TypeExpW, float> || std::is_same_v<TypeExpW, cutlass::bfloat16_t>,
        "Top K reduction only implemented for float and Bf16");
    using TypeCmp = std::conditional_t<sizeof(TypeExpW) >= 4, double, float>;
    static constexpr int64_t Mask64 = 0x000000000000FFFF;
    static constexpr int32_t Mask32 = 0x0000FFFF;

    TypeCmp compVal;

    static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0)
    {
        auto cmpVal = TypeCmp{val};
        TypeCmp cmpValWithIdx;
        if constexpr (sizeof(TypeExpW) >= 4)
        {
            auto cmpValIdx64 = reinterpret_cast<int64_t&>(cmpVal) | (Mask64 & int64_t{idx});
            cmpValWithIdx = reinterpret_cast<TypeCmp&>(cmpValIdx64);
        }
        else
        {
            auto cmpValIdx32 = reinterpret_cast<int32_t&>(cmpVal) | (Mask32 & idx);
            cmpValWithIdx = reinterpret_cast<TypeCmp&>(cmpValIdx32);
        }
        return cmpValWithIdx;
    }

    static __host__ __device__ inline void unpack(TypeExpW& val, int32_t& idx, TypeCmp cmp)
    {
        if constexpr (sizeof(TypeExpW) >= 4)
        {
            idx = static_cast<int32_t>(reinterpret_cast<int64_t&>(cmp) & Mask64);
            auto val64 = reinterpret_cast<int64_t&>(cmp) & ~Mask64;
            val = static_cast<float>(reinterpret_cast<double&>(val64));
        }
        else
        {
            idx = reinterpret_cast<int32_t&>(cmp) & Mask32;
            auto val32 = reinterpret_cast<int32_t&>(cmp) >> 16;
            val = TypeExpW::bitcast(reinterpret_cast<uint16_t&>(val32));
        }
    }

    __host__ __device__ TopKRedType() = default;

    __host__ __device__ TopKRedType(TypeExpW val, int32_t idx)
        : compVal(makeCmpVal(val, idx))
    {
    }

    __host__ __device__ operator TypeCmp() const noexcept
    {
        return compVal;
    }

    __device__ inline TypeCmp reduce(cg::thread_block_tile<WarpSize> const& warp)
    {
#if defined(TLLM_GEN_ENABLE_FAST_REDUX)
        static constexpr bool UseCg = false;
#else
        static constexpr bool UseCg = true;
#endif
        if constexpr (UseCg || sizeof(TypeExpW) >= 4)
        {
            return cg::reduce(warp, compVal, cg::greater<TypeCmp>{});
        }
        else
        {
            float result;
            asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(compVal));
            return result;
        }
    }
};
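
// Usage sketch for the reduction type above (editor's illustration, not called by the kernels
// below): the score is packed into the high bits of the wider comparison type and the source index
// into its low 16 bits, so a single warp-wide max-reduction yields both the best score and its
// origin.
//
//     TopKRedType<float> packed{score, laneIdx};  // pack (score, lane) into one comparable value
//     auto best = packed.reduce(warp);            // max over all 32 lanes
//     float bestScore;
//     int32_t bestLane;
//     TopKRedType<float>::unpack(bestScore, bestLane, best);
//
// This is how reduceTopK below recovers argmax information from a plain max-reduction.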

////////////////////////////////////////////////////////////////////////////////////////////////////

static __device__ inline float tanh_fast(float x)
{
    float res;
    asm volatile("{ tanh.approx.f32 %0, %1; }\n" : "=f"(res) : "f"(x));
    return res;
}

static __device__ inline float sigmoid_fast(float x)
{
    return 0.5f * tanh_fast(0.5f * x) + 0.5f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static __device__ inline float sigmoid_accurate(float x)
{
    return 0.5f * tanhf(0.5f * x) + 0.5f;
}
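
// Editor's note: both helpers rely on the identity sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5;
// sigmoid_fast uses the approximate hardware tanh instruction, while sigmoid_accurate uses tanhf.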

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int K_, bool Enable_>
struct TopKIdx
{
    // by default, empty
};

template <int K_>
struct TopKIdx<K_, true>
{
    static constexpr int K = K_;
    int32_t val[K];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int K, typename Type>
__device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type value, int32_t idx, Type minValue)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < WarpSize, "Top K must have K < WarpSize");
    using RedType = TopKRedType<Type>;
    RedType topK{value, idx};
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < K; ++kk)
    {
        topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK;
        // get the next largest value
        packedMax = topK.reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
};
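
// Usage sketch (editor's illustration): every lane contributes one candidate, and after the call
// all lanes hold the same K largest values and their indices, e.g. for K == 2:
//
//     float out[2];
//     int32_t outIdx[2];
//     reduceTopK(warp, out, outIdx, myScore, myIdx, /* minValue */ -1.f);
//     // out[0] >= out[1] are the two largest scores across the warp, outIdx their indices.
//
// Each of the K iterations knocks out the previous winner by replacing its contribution with
// minValue and re-reducing, so the cost is K warp reductions.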

////////////////////////////////////////////////////////////////////////////////////////////////////

#define TOPK_SWAP(I, J)                                                                                                \
    {                                                                                                                  \
        auto pairMin = min(topK[I].compVal, topK[J].compVal);                                                          \
        auto pairMax = max(topK[I].compVal, topK[J].compVal);                                                          \
        topK[I].compVal = pairMax;                                                                                     \
        topK[J].compVal = pairMin;                                                                                     \
    }

template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type (&value)[N], int32_t (&idx)[N], Type minValue)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < WarpSize, "Top K must have K < WarpSize");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N <= K, "Top K must have N <= K");
    using RedType = TopKRedType<Type>;
    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
        topK[nn] = RedType{value[nn], idx[nn]};
    if constexpr (!IsSorted)
    {
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);

        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);

        TOPK_SWAP(1, 2);
    }
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < K; ++kk)
    {
        bool update = kk > 0 && packedMax == topK[0].compVal;
#pragma unroll
        for (int nn = 0; nn < N; ++nn)
        {
            topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
        }
        // get the next largest value
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
};

#undef TOPK_SWAP
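
// Editor's note on the array variant above: the TOPK_SWAP sequence is a small sorting network that
// orders the per-thread candidates in descending order before the warp loop, and its swap indices
// assume N == 4 (the NumTopGroups case used by routingMainKernel below). Sketch of that call, with
// four candidate scores and indices per lane:
//
//     float topScores[NumTopExperts];
//     int32_t topExperts[NumTopExperts];
//     reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
//         /* minValue */ invalidScoreFloat);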

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void routingKernelGemm(KernelParams params)
{
    // naive Gemm, to be replaced by performant kernel
    using Type = typename KernelParams::Type;
    using TypeExpW = typename KernelParams::TypeExpW;
    // each thread has space for the dot product of each expert here
    extern __shared__ char __attribute((aligned(128))) smemBase[];
    auto* smemDotPartial = reinterpret_cast<float*>(smemBase);
    static constexpr int SmemStride = NumThreadsGemm + 1;

    auto tokenOff = int64_t{blockIdx.x} * int64_t{params.mHiddenDim};

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // immediately trigger the secondary kernel when using PDL
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif

    // dot product for all experts
    // entire block must go into this loop
    for (int32_t dd = threadIdx.x; dd < params.mHiddenDim; dd += NumThreadsGemm)
    {
        Type act = params.mPtrIn[tokenOff + dd];

        for (int32_t expertIdx = 0; expertIdx < params.mNumExperts; ++expertIdx)
        {
            auto weightOff = int64_t{expertIdx} * int64_t{params.mHiddenDim};
            TypeExpW weight = params.mPtrRoutingWeights[weightOff + dd];
            auto val = float{act} * float{weight};
            if (dd == threadIdx.x)
            {
                smemDotPartial[expertIdx * SmemStride + threadIdx.x] = val;
            }
            else
            {
                smemDotPartial[expertIdx * SmemStride + threadIdx.x] += val;
            }
        }
    }
    // make all partial dot products available to all threads
    __syncthreads();

    // finalize dot product and write to output
    for (int32_t expertIdx = threadIdx.x; expertIdx < params.mNumExperts; expertIdx += NumThreadsGemm)
    {
        float dot = 0.F;
        for (int32_t ii = 0; ii < NumThreadsGemm; ++ii)
        {
            dot += smemDotPartial[expertIdx * SmemStride + ii];
        }
        params.mPtrScores[int64_t{blockIdx.x} * int64_t{params.mNumExperts} + expertIdx] = dot;
    }
}
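
// Editor's note: the naive kernel above computes, for each token (one block per token), the dot
// product of its hidden state with every expert's routing weight row, i.e. scores = act * W^T,
// accumulating NumThreadsGemm partial sums per expert in shared memory before the final reduction.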

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T mulLog2(T a, T bLog2)
{
    return a << bLog2;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpLog2(T a, T bLog2)
{
    return ((a + (1 << bLog2) - 1) >> bLog2);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2)
{
    return mulLog2<T>(divUpLog2<T>(a, bLog2), bLog2);
}
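
// Editor's sanity checks (illustrative, compile-time only): with bLog2 = 3, i.e. padding to
// multiples of 8, 13 elements need 2 tiles and pad up to 16.
static_assert(divUpLog2<int32_t>(13, 3) == 2, "ceil(13 / 8) == 2");
static_assert(mulLog2<int32_t>(2, 3) == 16, "2 * 8 == 16");
static_assert(divUpMulLog2<int32_t>(13, 3) == 16, "13 padded to a multiple of 8 is 16");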

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void routingMainKernel(KernelParams params)
{
    // declare types required for reductions
    using TypeExpW = typename KernelParams::TypeExpW;

    // declare shared memory structure
    // number of experts is bounded by number of threads
    __shared__ float __attribute((aligned(128))) smemScoreSigmoid[NumThreads];
    __shared__ float __attribute((aligned(128))) smemScoreBias[NumExpertsPerGroup * NumThreads];
    // number of expert groups is bounded by number of warps
    __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];

    // needed for warp reduce
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);
    // for the final reduction of weight norm, only some lanes need to participate
    int32_t laneIdx = threadIdx.x % WarpSize;
    int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
    // warps outside the range of expert groups do not participate
    if (warpIdx >= params.mNumExpertGroups)
    {
        return;
    }

    // note that for invalid scores, we simply use a negative value:
    // they work well even with the compacted format used in topK, and
    // sigmoid / bias activated scores cannot be negative
    static constexpr float invalidScoreFloat = -1.F;
    const TypeExpW invalidScore = TypeExpW{invalidScoreFloat};

    // load bias already; each warp represents one expert group
    auto threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx;
    auto expertSelected = laneIdx < params.mNumExpertsPerGroup;
    auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
    auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    // trigger the secondary kernel when using PDL, then wait on primary
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
        cudaGridDependencySynchronize();
    }
#endif

    // get our assigned thread score; each warp represents one expert group
    float score = expertSelected ? params.mPtrScores[scoreIdx] : invalidScoreFloat;
    // get the sigmoid score
    // note that for invalid values, we simply use a negative value:
    // sigmoid scores are always strictly positive
    auto scoreSigmoid = sigmoid_accurate(score);
    // write the sigmoid score to shared for later use
    if (expertSelected)
    {
        smemScoreSigmoid[threadExpert] = scoreSigmoid;
    }
    // get the score with bias
    // note that with invalid values, because sigmoid is < 1 and bias is -1,
    // we must get a negative value, which is smaller than any valid value
    // TODO: verify bf16 scoreBias accuracy before changing it back to bf16
    // auto scoreBias = TypeExpW{scoreSigmoid + float{biasVal}}; // TypeExpW is bf16
    auto scoreBias = float{scoreSigmoid + float{biasVal}};
    if (expertSelected)
    {
        smemScoreBias[threadExpert] = scoreBias;
    }

    // registers for top group score reduction
    float topExpGroupScores[NumTopGroupScores];
    [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
    reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert,
        /* minValue */ invalidScoreFloat);

    // get the final group score and write it to shared
    if (cute::elect_one_sync())
    {
        auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
        smemGroupScores[warpIdx] = groupScore;
    }

    // make group scores available to all warps
    __syncthreads();

    float topGroups[NumTopGroups]; // params.mNumLimitedGroups
    int32_t topGroupIdx[NumTopGroups];
    float expertScoreGroup[NumTopGroups];
    int32_t expertIdxGroup[NumTopGroups];
    float topScores[NumTopExperts]; // params.mTopK
    int32_t topExperts[NumTopExperts];
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
    if (warpIdx == 0)
    {
        // a single warp performs the selection of top groups, and goes on to select the final experts
        float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : float{};

        reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
            /* minValue */ invalidScoreFloat);

        // final expert selection: get relevant indexes and scores from shared

#pragma unroll
        for (int ii = 0; ii < NumTopGroups; ++ii)
        { // params.mNumLimitedGroups
            auto groupIdx = topGroupIdx[ii];
            expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx;
            expertScoreGroup[ii] = expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
        }

        reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
            /* minValue */ invalidScoreFloat);

        // determine our lane's expert index and write to output
        int32_t expertIdx = 0;
#pragma unroll
        for (int ii = 0; ii < NumTopExperts; ++ii)
        { // params.mTopK
            expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx;
        }
        // determine whether our expert is local to this GPU
        auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;

        // write expert idx out already
        auto idxTopK = blockIdx.x * NumTopExperts + laneIdx; // params.mTopK
        if (laneIdx < NumTopExperts && params.mPtrExpertIdx != nullptr)
        { // params.mTopK
            params.mPtrExpertIdx[idxTopK] = expertIdx;
        }
        float scoreNorm = laneIdx < NumTopExperts ? smemScoreSigmoid[expertIdx] : 0.F;
        auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
        auto finalScore = TypeExpW{scoreNorm * params.mRouteScale / redNorm};
        if (laneIdx < NumTopExperts && params.mPtrExpertWeights != nullptr)
        { // params.mTopK
            params.mPtrExpertWeights[idxTopK] = finalScore;
        }
        if (laneIdx < NumTopExperts && params.mPtrExpertWeightsFull != nullptr && isLocalExpert)
        { // params.mTopK
            auto idxWeightsFull = localExpertIdx * gridDim.x + blockIdx.x;
            params.mPtrExpertWeightsFull[idxWeightsFull] = finalScore;
        }
    }
}
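
// Editor's summary of the kernel above (descriptive only): per-expert scores are passed through an
// accurate sigmoid and biased with the routing bias; each warp owns one expert group and reduces
// its top-2 biased scores into a group score; warp 0 then selects the top NumTopGroups groups and
// the top NumTopExperts experts among them, and writes normalized weights
//     w_i = sigmoid(score_i) * mRouteScale / sum_j sigmoid(score_j)
// where the sum runs over the selected experts.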

////////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
    routingIndicesClusterKernel(KernelParams params)
{
    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
    // needed for the exclusive sum of token offsets
    using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    __shared__ typename Scan::TempStorage tempStorage;
    // Number of threads in the cluster.
    static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster;
    // If the number of tokens is bounded by 16384, then the total number of indexes
    // is bounded by 16384 * TopK.
    // TODO: if we only use this kernel up to 1024 tokens, we could use 1024 here.
    static constexpr int MaxExpandedIdxPerThread
        = (16384 * NumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster;

    // Initialize cluster.
    uint32_t const clusterBlockRank = blockIdx.x;
    uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);

    auto expandedIdxSize = params.mNumTokens * NumTopExperts;

    // pre-fill the counts with 0
    smemExpertCount[threadIdx.x] = 0;
    __syncthreads();

    // then wait on primary grid
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }

    // each thread has some number of "expanded indexes" assigned to it
    // for each of these, we keep the associated expert and offset within expert in registers
    int32_t expertIndexes[MaxExpandedIdxPerThread];
    int32_t expertOffsets[MaxExpandedIdxPerThread];
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
    // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
    // time, and branch between a fast path without bound checks and a slow path with bound checks.
    int constexpr IterStride = 4;
    static_assert(MaxExpandedIdxPerThread % IterStride == 0);

    // Define a lambda to avoid code duplication in both branches.
    auto loopBody = [&](int ii, int expandedIdx)
    {
        int32_t expertIdx = params.mPtrExpertIdx[expandedIdx];
        expertIndexes[ii] = expertIdx;
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0;
    };

#pragma unroll
    for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
    {
        // Whether it's safe to do multiple iterations without bound checks.
        bool const takeFastPath = (ii0 + IterStride) * NumThreadsPerCluster <= expandedIdxSize;
        if (takeFastPath)
        {
#pragma unroll
            for (int32_t jj = 0; jj < IterStride; jj++)
            {
                int const ii = ii0 + jj;
                auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
                loopBody(ii, expandedIdx);
            }
        }
        else
        {
            bool doBreak = false;
#pragma unroll
            for (int32_t jj = 0; jj < IterStride; jj++)
            {
                int const ii = ii0 + jj;
                auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
                if (expandedIdx >= expandedIdxSize)
                {
                    doBreak = true;
                    break;
                }
                loopBody(ii, expandedIdx);
            }
            if (doBreak)
            {
                break;
            }
        }
    }

    // Make local histogram (token counts per expert) available to all threads in the cluster.
    cg::cluster_group::sync();

    //
    // Each thread now represents one expert
    //

    // Get the histogram bin from each rank for this expert.
    int32_t expertCounts[NumBlocksPerCluster];
#pragma unroll
    for (int rank = 0; rank < NumBlocksPerCluster; rank++)
    {
        int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
        expertCounts[rank] = remoteSmem[threadIdx.x];
    }

    // Compute an exclusive prefix sum of the block-local count.
    // Each block only needs the count up to its rank, and the total count.
    int32_t count = 0;
    int32_t blockExpertOffset = 0;
#pragma unroll
    for (int rank = 0; rank < NumBlocksPerCluster; rank++)
    {
        if (rank == clusterBlockRank)
        {
            blockExpertOffset = count;
        }
        count += expertCounts[rank];
    }

    // Arrive: we do not access distributed shared memory after this point.
    __cluster_barrier_arrive();

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when smemExpertCount is computed
    // so we do not need to take it into account here.
    const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    // Strided loop to share this work between blocks.
    int32_t tokensPerTile = params.mAllToAllRouteAct ? params.mNumTokens : count;
    for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster)
    {
        const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
        params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
        params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
            mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + tokensPerTile);
    }

    // get the padded offset associated with this expert
    const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
    const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);

    // write out padded count
    if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
    {
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    // write expert offsets to shared
    smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;

    // make expert offsets available to all threads
    __syncthreads();

    // Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
    // Note: I observed a perf benefit to doing this before the final loop so the compiler can
    // implement break with EXIT.
    __cluster_barrier_wait();

    // trigger the secondary kernel when using PDL
    // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
    // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
    // TODO: this is not sufficient to ensure visibility in the next kernel!

    // TODO: disable PDL for now to avoid race condition in FC1
    if constexpr (KernelParams::UsePdl)
    {
        // cudaTriggerProgrammaticLaunchCompletion();
    }

    // each thread has the same "expanded indexes" assigned to it as above
    // at this point, we know the final offsets of experts and the offsets within
    // experts, which allows writing the final index values
#pragma unroll
    for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
    {
        auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
        if (expandedIdx >= expandedIdxSize)
        {
            break;
        }
        auto expertIdx = expertIndexes[ii];
        // check whether this expert is local to our GPU at all
        auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        auto tokenIdx = expandedIdx / NumTopExperts;
        auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
        if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
        {
            params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
        }
        if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
        {
            params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
        }
    }
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params)
{
    assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __launch_bounds__(NumThreads) routingIndicesCoopKernel(KernelParams params)
{
    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
    // needed for the exclusive sum of token offsets
    using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    __shared__ typename Scan::TempStorage tempStorage;
    // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory.
    static constexpr int MaxExpandedIdxPerThread = 64;

    // Initialize grid.
    cg::grid_group grid = cg::this_grid();
    // Note: the following is more efficient than grid.block_index() because we don't use y and z.
    uint32_t const gridBlockIdx = blockIdx.x;
    uint32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x;
    uint32_t const numBlocks = gridDim.x;
    uint32_t const numThreadsPerGrid = numBlocks * NumThreads;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);

    auto expandedIdxSize = params.mNumTokens * NumTopExperts;

    // pre-fill the counts with 0
    smemExpertCount[threadIdx.x] = 0;
    __syncthreads();

    // then wait on primary grid
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }

    // each thread has some number of "expanded indexes" assigned to it
    // for each of these, we keep the associated expert and offset within expert in registers
    int32_t expertIndexes[MaxExpandedIdxPerThread];
    int32_t expertOffsets[MaxExpandedIdxPerThread];
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
    // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
    // time, and branch between a fast path without bound checks and a slow path with bound checks.
    int constexpr IterStride = 4;
    static_assert(MaxExpandedIdxPerThread % IterStride == 0);

    // Define a lambda to avoid code duplication in both branches.
    auto loopBody = [&](int ii, int expandedIdx)
    {
        int32_t expertIdx = params.mPtrExpertIdx[expandedIdx];
        expertIndexes[ii] = expertIdx;
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0;
    };

#pragma unroll
    for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
    {
        // Whether it's safe to do multiple iterations without bound checks.
        bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize;
        if (takeFastPath)
        {
#pragma unroll
            for (int32_t jj = 0; jj < IterStride; jj++)
            {
                int const ii = ii0 + jj;
                auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
                loopBody(ii, expandedIdx);
            }
        }
        else
        {
            bool doBreak = false;
#pragma unroll
            for (int32_t jj = 0; jj < IterStride; jj++)
            {
                int const ii = ii0 + jj;
                auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
                if (expandedIdx >= expandedIdxSize)
                {
                    doBreak = true;
                    break;
                }
                loopBody(ii, expandedIdx);
            }
            if (doBreak)
            {
                break;
            }
        }
    }

    // Make histogram (token counts per expert) available to all threads in the block.
    __syncthreads();

    //
    // Each thread now represents one expert
    //

    // Add the local bin count to the common bin count and get a per-CTA offset.
    int32_t const localExpertCount = smemExpertCount[threadIdx.x];
    int32_t const blockExpertOffset = atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);

    // Sync to wait for completion of the histogram reduction.
    grid.sync();

    // Get total count for this expert.
    int32_t count = params.mPtrExpertCounts[threadIdx.x];

    // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency.

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when smemExpertCount is computed
    // so we do not need to take it into account here.
    const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    // Strided loop to share this work between blocks.
    int32_t tokensPerTile = params.mAllToAllRouteAct ? params.mNumTokens : count;
    for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks)
    {
        const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
        params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
        params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
            mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + tokensPerTile);
    }

    // get the padded offset associated with this expert
    const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
    const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);

    // write out padded count
    if (gridBlockIdx == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
    {
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    // write expert offsets to shared
    smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;

    // make expert offsets available to all threads
    __syncthreads();

    // trigger the secondary kernel when using PDL
    // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
    // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
    // TODO: this is not sufficient to ensure visibility in the next kernel!
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }

    // each thread has the same "expanded indexes" assigned to it as above
    // at this point, we know the final offsets of experts and the offsets within
    // experts, which allows writing the final index values
#pragma unroll
    for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
    {
        auto expandedIdx = static_cast<int32_t>(gridThreadIdx) + ii * numThreadsPerGrid;
        if (expandedIdx >= expandedIdxSize)
        {
            break;
        }
        auto expertIdx = expertIndexes[ii];
        // check whether this expert is local to our GPU at all
        auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        auto tokenIdx = expandedIdx / NumTopExperts;
        auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
        if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
        {
            params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
        }
        if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
        {
            params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
        }
    }
}
#else
__global__ void routingIndicesCoopKernel(KernelParams params)
{
    assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Two-step approach (if the number of tokens exceeds the limits of what the cluster / cooperative
// launch variants can handle): in order to minimize the amount of data to exchange through global
// memory, we will compute the local histograms in smem twice: the first kernel will get us the total
// number of tokens per expert. The second kernel will use the smem and L2 atomics to get the
// corresponding element and tile offsets.
//
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
// inefficient if we have one CTA per token doing a single global atomic.

template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __launch_bounds__(NumThreads) routingIndicesHistogramKernel(KernelParams params)
{
    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];

    // For unrolling.
    uint32_t constexpr NumEltsPerThread = 8;

    // Pre-fill the counts with 0
    smemExpertCount[threadIdx.x] = 0;
    __syncthreads();

    // Wait on primary grid and trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
        cudaTriggerProgrammaticLaunchCompletion();
    }

    uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
    uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

    uint32_t const gridBlockOffset = blockIdx.x * NumThreads;
    uint32_t const gridStride = gridDim.x * NumThreads;

    // Define a lambda to avoid code duplication in branches.
    auto loopBody = [&](int expandedIdx)
    {
        int32_t expertIdx = params.mPtrExpertIdx[expandedIdx];
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        if (isLocalExpert)
        {
            atomicAdd(&smemExpertCount[expertIdx], 1);
        }
    };

    // Grid-stride loop.
    for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
         expandedIdx0 += gridStride * NumEltsPerThread)
    {
        // Fast path if bound checks aren't necessary
        if (expandedIdx0 + NumEltsPerThread * NumThreads <= expandedIdxSize)
        {
#pragma unroll
            for (uint32_t ii = 0; ii < NumEltsPerThread; ii++)
            {
                uint32_t expandedIdx = expandedIdx0 + ii * NumThreads + threadIdx.x;
                loopBody(expandedIdx);
            }
        }
        else
        {
            for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
                 expandedIdx += NumThreads)
            {
                loopBody(expandedIdx);
            }
        }
    }
    __syncthreads();

    //
    // Each thread now represents one expert
    //

    // Reduce histograms with atomics.
    int32_t const localExpertCount = smemExpertCount[threadIdx.x];
    atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
}
#else
__global__ void routingIndicesHistogramKernel(KernelParams params)
{
    assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __launch_bounds__(NumThreads) routingIndicesOffsetsKernel(KernelParams params)
{
    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
    __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[NumThreads];
    // needed for the exclusive sum of token offsets
    using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    __shared__ typename Scan::TempStorage tempStorage;
    static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
    static constexpr int MaxExpandedIdxPerBlock = NumThreads * MaxExpandedIdxPerThread;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);

    uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
    uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }

    // The expert offsets are common to all tiles of all blocks.
    // Load the histogram, scan it and write offsets to shared memory.
    // Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
    // the scan, with PDL?

    // Each thread represents one expert. Get total count for this expert.
    int32_t count = params.mPtrExpertCounts[threadIdx.x];

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when the histogram is computed
    // so we do not need to take it into account here.
    const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    // Get the padded offset associated with this expert
    const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
    const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);

    // Write expert offsets to shared
    smemExpertOffset[threadIdx.x] = offset;
    // Sync to make expert offsets available to all threads.
    __syncthreads();

    // The first block writes out padded count
    if (blockIdx.x == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
    {
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    // Strided loop to share this work between blocks.
    int32_t tokensPerTile = params.mAllToAllRouteAct ? params.mNumTokens : count;
    for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x)
    {
        const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
        params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
        params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
            mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + tokensPerTile);
    }

    //
    // Now loop on indices and compute offsets.
    //

    // Grid-stride loop on 1D "tiles" of input indices.
    for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x)
    {
        if (tileIdx > 0)
        {
            // Sync for safe reuse of smem buffers.
            __syncthreads();
        }

        // Pre-fill the counts with 0
        smemExpertCount[threadIdx.x] = 0;
        __syncthreads();

        // each thread has some number of "expanded indexes" assigned to it
        // for each of these, we keep the associated expert and offset within expert in registers
        int32_t expertIndexes[MaxExpandedIdxPerThread];
        int32_t expertOffsets[MaxExpandedIdxPerThread];
        auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

        // Define a lambda to avoid code duplication in branches.
        auto loopBody = [&](int ii, int expandedIdx)
        {
            int32_t expertIdx = params.mPtrExpertIdx[expandedIdx];
            expertIndexes[ii] = expertIdx;
            // check whether this expert is local to our GPU at all and ignore if not
            auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0;
        };

        // For all tiles but the last, all indices are in bounds.
        if (tileIdx < numTiles - 1)
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreads + threadIdx.x;
                loopBody(ii, expandedIdx);
            }
        }
        else
        {
            // For the last tile, we need to exit the loop when out of bounds.
            // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
            // time, and branch between a fast path without bound checks and a slow path with bound checks
            int constexpr IterStride = 4;
            static_assert(MaxExpandedIdxPerThread % IterStride == 0);

#pragma unroll
            for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
            {
                // Whether it's safe to do multiple iterations without bound checks.
                bool const takeFastPath
                    = tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreads <= expandedIdxSize;
                if (takeFastPath)
                {
#pragma unroll
                    for (int32_t jj = 0; jj < IterStride; jj++)
                    {
                        int const ii = ii0 + jj;
                        auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreads + threadIdx.x;
                        loopBody(ii, expandedIdx);
                    }
                }
                else
                {
                    bool doBreak = false;
#pragma unroll
                    for (int32_t jj = 0; jj < IterStride; jj++)
                    {
                        int const ii = ii0 + jj;
                        auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreads + threadIdx.x;
                        if (expandedIdx >= expandedIdxSize)
                        {
                            doBreak = true;
                            break;
                        }
                        loopBody(ii, expandedIdx);
                    }
                    if (doBreak)
                    {
                        break;
                    }
                }
            }
        }

        // Make local histogram (token counts per expert) available to all threads in the block.
        __syncthreads();

        // Each thread now represents one expert

        // Add the local bin count to the common bin count and get a per-CTA offset. We use the second
        // half of the histogram buffer for this histogram, because the first half already holds the
        // reduced histogram from the previous kernel.
        int32_t const localExpertCount = smemExpertCount[threadIdx.x];
        int32_t const tileExpertOffset
            = atomicAdd(&params.mPtrExpertCounts[NumThreads + threadIdx.x], localExpertCount);

        // Make per-expert tile offsets available to all threads in the block.
        smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
        __syncthreads();

        // Add tile offset and element offset and write to global memory.
        auto storeLoopBody = [&](int ii, int expandedIdx)
        {
            int32_t expertIdx = expertIndexes[ii];
            // check whether this expert is local to our GPU at all
            auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            auto tokenIdx = expandedIdx / NumTopExperts;
            auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
            if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
            {
                params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
            }
            if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
            {
                params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
            }
        };
        // Bound checks only in last tile.
        if (tileIdx < numTiles - 1)
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreads + threadIdx.x;
                storeLoopBody(ii, expandedIdx);
            }
        }
        else
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreads + threadIdx.x;
                if (expandedIdx >= expandedIdxSize)
                {
                    break;
                }
                storeLoopBody(ii, expandedIdx);
            }
        }
    }

    // Trigger secondary kernel.
    // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
    // dependency sync.
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
}
#else
__global__ void routingIndicesOffsetsKernel(KernelParams params)
{
    assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

void run(Data const& data, void* stream)
{
    TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr
            || data.mPtrExpertWeightsFull != nullptr || data.mPtrExpertWeights != nullptr,
        "Routing kernel requires at least one output parameter");
    if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
        TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr && data.mPtrPermutedIdxSize,
            "If permuted index is required, `mPtrExpertIdx` is also required");
    TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
    TLLM_CHECK_WITH_INFO(
        data.mNumLimitedGroups == NumTopGroups, "Routing kernel expects %d groups (for now)", NumTopGroups);
    TLLM_CHECK_WITH_INFO(
        data.mTopK == NumTopExperts, "Routing kernel expects %d topK experts (for now)", NumTopExperts);
    TLLM_CHECK_WITH_INFO(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK);
    TLLM_CHECK_WITH_INFO(data.mTopK * data.mNumLimitedGroups <= WarpSize,
        "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK,
        data.mNumLimitedGroups);
    TLLM_CHECK_WITH_INFO(data.mNumExperts >= NumTopExperts, "Routing kernel expects %d to be at most #experts %d",
        NumTopExperts, data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumThreads, "Routing kernel expects #experts %d <= #threads %d",
        data.mNumExperts, NumThreads);
    TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= NumWarps, "Routing kernel expects #experts groups %d <= #warps %d",
        data.mNumExpertGroups, NumWarps);
    TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0,
        "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts,
        data.mNumExpertGroups);
    TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
        "Routing kernel expects #experts per group <= warp size, got %d", data.mNumExperts / data.mNumExpertGroups);
    TLLM_CHECK_WITH_INFO(
        data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
    TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= NumExpertsPerGroup,
        "Routing kernel expects number of experts per group <= %d, got %d", NumExpertsPerGroup,
        data.mNumExperts / data.mNumExpertGroups);
    int const numBlocks = data.mNumTokens;

    if (data.mPtrExpertWeightsFull != nullptr)
    {
        auto localExpertExtent = data.mNumLocalExperts << data.mLocalExpertsStrideLog2;
        // note: we set a value of 0 here, s.t. even if the routing happens,
        // it will be ignored / not given any weight
        TLLM_CUDA_CHECK(cudaMemsetAsync(
            data.mPtrExpertWeightsFull, 0, localExpertExtent * data.mNumTokens * sizeof(float), (cudaStream_t) stream));
    }

    /* disable memset(-1) for permuted_idx_to_token_idx for performance
    if (data.mPtrPermutedIdxToTokenIdx != nullptr)
    {
        // need to set all values to -1 before running the kernel
        auto maxPermutedSize
            = data.mNumTokens * data.mTopK + (data.mNumExperts << data.mPaddingLog2) - data.mNumExperts;
        // note that a value of -1 per byte works for any size of signed integer
        // to set each full value to the logical value -1
        TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrPermutedIdxToTokenIdx, -1,
            static_cast<size_t>(maxPermutedSize) * sizeof(int32_t), (cudaStream_t) stream));
    }
    */

    bool const useSingleCluster = data.mNumTokens <= 1024;
    if (!useSingleCluster)
    {
        // Reset the global histograms (not used in single-cluster code path).
        // Cover both for the cooperative and two-kernel code paths.
        TLLM_CUDA_CHECK(cudaMemsetAsync(
            data.mPtrExpertCounts, 0, static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
    }

    // Number of blocks we can use in the cooperative kernel
    // The number of blocks must be:
    // >= ⌈(numTokens * NumTopExperts) / (MaxExpandedIdxPerThread * NumThreads)⌉
    // <= numSms, assuming an occupancy of 1 block/SM
    //
    // If too small for the given numTokens, fall back to the less performant two-step method.
    //
    // The upper bound is a strict requirement. The number of blocks should be determined by querying
    // the device properties, or conservatively low.
    // /!\ The following number is not portable!! (but works on H100 and B200)
    int const numBlocksCoop = 128;

    // Maximum number of tokens supported by the kernel using a cooperative launch.
    int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / NumTopExperts;
    LAUNCH_EXPW_ONLY(data,
        /*coopLaunch=*/false, routingMainKernel, numBlocks, NumThreads,
        /*smemSize=*/0, // No dynamic smem
        stream);

    if (data.mPtrPermutedIdxSize != nullptr)
    {
        if (useSingleCluster)
        {
            LAUNCH_EXPW_ONLY(data,
                /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
                /*smemSize=*/0, // No dynamic smem
                stream);
        }
        else if (data.mNumTokens <= maxTokensCoop)
        {
            LAUNCH_EXPW_ONLY(data,
                /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, NumThreads,
                /*smemSize=*/0, // No dynamic smem
                stream);
        }
        else
        {
            const uint32_t expandedIdxSize = data.mNumTokens * NumTopExperts;

            const uint32_t histogramEltsPerBlock = 8 * NumThreads;
            const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreads;

            // Limit grid size (both kernels use a grid-stride loop).
            const uint32_t maxNumBlocks = 1024;

            int const numBlocksHistogram
                = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
            int const numBlocksOffsets
                = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

            LAUNCH_EXPW_ONLY(data,
                /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreads,
                /*smemSize=*/0, // No dynamic smem
                stream);
            LAUNCH_EXPW_ONLY(data,
                /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreads,
                /*smemSize=*/0, // No dynamic smem
                stream);
        }
    }
}
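
// Editor's note on the dispatch above (descriptive sketch): routingMainKernel always runs with one
// block per token; the index-building step then takes one of three paths based on the token count:
//
//     if (numTokens <= 1024)               -> routingIndicesClusterKernel (single 8-block cluster)
//     else if (numTokens <= maxTokensCoop) -> routingIndicesCoopKernel (cooperative launch)
//     else                                 -> routingIndicesHistogramKernel + routingIndicesOffsetsKernel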
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
} // namespace routing
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
namespace routingLlama4
|
||
{
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
namespace tg = batchedGemm::trtllm::gen;
|
||
namespace cg = cooperative_groups;
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
static constexpr int NumThreads = 1024;
|
||
static constexpr int NumThreadsHist = 256;
|
||
static constexpr int NumBlocksPerCluster = 8;
|
||
static constexpr int WarpSize = 32;
|
||
static constexpr int NumWarps = NumThreads / WarpSize;
|
||
static constexpr int NumWarpsHist = NumThreadsHist / WarpSize;
|
||
static constexpr int NumTopExperts = 1;
|
||
static constexpr int MaxNumExperts = 128;
|
||
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
|
||
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
|
||
static constexpr int WarpKernelSmemStride = 33;
|
||
// with further optimization to `routingIndicesWarpKernel`, this limit may
|
||
// increase. For now, it is a good cut-off point for when the block-wise
|
||
// operations are more efficient end-to-end.
|
||
static constexpr int WarpKernelMaxNumTokens = 4;
|
||
|
||
// Performance tuning knob.
|
||
static constexpr int NumEltsPerOffsetTilePerThread = 8;
|
||
|
||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL))
|
||
#define TLLM_GEN_ENABLE_FAST_REDUX
|
||
#endif
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename TypeExpW_>
|
||
struct TopKRedType
|
||
{
|
||
using TypeExpW = TypeExpW_;
|
||
static_assert(std::is_same_v<TypeExpW, float> || std::is_same_v<TypeExpW, cutlass::bfloat16_t>,
|
||
"Top K reduction only implemented for float and Bf16");
|
||
using TypeCmp = std::conditional_t<sizeof(TypeExpW) >= 4, double, float>;
|
||
static constexpr int64_t Mask64 = 0x000000000000FFFF;
|
||
static constexpr int32_t Mask32 = 0x0000FFFF;
|
||
|
||
TypeCmp compVal;
|
||
|
||
static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0)
|
||
{
|
||
auto cmpVal = TypeCmp{val};
|
||
TypeCmp cmpValWithIdx;
|
||
if constexpr (sizeof(TypeExpW) >= 4)
|
||
{
|
||
auto cmpValIdx64 = reinterpret_cast<int64_t&>(cmpVal) | (Mask64 & int64_t{idx});
cmpValWithIdx = reinterpret_cast<TypeCmp&>(cmpValIdx64);
|
||
}
|
||
else
|
||
{
|
||
auto cmpValIdx32 = reinterpret_cast<int32_t&>(cmpVal) | (Mask32 & idx);
|
||
cmpValWithIdx = reinterpret_cast<TypeCmp&>(cmpValIdx32);
|
||
}
|
||
return cmpValWithIdx;
|
||
}
|
||
|
||
static __host__ __device__ inline void unpack(TypeExpW& val, int32_t& idx, TypeCmp cmp)
|
||
{
|
||
if constexpr (sizeof(TypeExpW) >= 4)
|
||
{
|
||
idx = static_cast<int32_t>(reinterpret_cast<int64_t&>(cmp) & Mask64);
|
||
auto val64 = reinterpret_cast<int64_t&>(cmp) & ~Mask64;
|
||
val = static_cast<float>(reinterpret_cast<double&>(val64));
|
||
}
|
||
else
|
||
{
|
||
idx = reinterpret_cast<int32_t&>(cmp) & Mask32;
|
||
auto val32 = reinterpret_cast<int32_t&>(cmp) >> 16;
|
||
val = TypeExpW::bitcast(reinterpret_cast<uint16_t&>(val32));
|
||
}
|
||
}
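
// Note on the packing scheme: `makeCmpVal` keeps the score bits in the high part of
// the comparator and stores the candidate index in its low 16 bits, so a single
// warp-wide max reduction recovers both the best score and its expert index.
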
__host__ __device__ TopKRedType() = default;
|
||
|
||
__host__ __device__ TopKRedType(TypeExpW val, int32_t idx)
|
||
: compVal(makeCmpVal(val, idx))
|
||
{
|
||
}
|
||
|
||
__host__ __device__ operator TypeCmp() const noexcept
|
||
{
|
||
return compVal;
|
||
}
|
||
|
||
__device__ inline TypeCmp reduce(cg::thread_block_tile<WarpSize> const& warp)
|
||
{
|
||
#if defined(TLLM_GEN_ENABLE_FAST_REDUX)
|
||
static constexpr bool UseCg = false;
|
||
#else
|
||
static constexpr bool UseCg = true;
|
||
#endif
|
||
if constexpr (UseCg || sizeof(TypeExpW) >= 4)
|
||
{
|
||
return cg::reduce(warp, compVal, cg::greater<TypeCmp>{});
|
||
}
|
||
else
|
||
{
|
||
float result;
|
||
asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(compVal));
|
||
return result;
|
||
}
|
||
}
|
||
};
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
static __device__ inline float sigmoid_accurate(float x)
|
||
{
|
||
return 0.5f * tanhf(0.5f * x) + 0.5f;
|
||
}
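
// Identity used above: sigmoid(x) == 0.5 * tanh(0.5 * x) + 0.5 == 1 / (1 + exp(-x));
// the tanh form presumably gives better saturation behavior for large |x| than
// evaluating expf directly.
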
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <int K_, bool Enable_>
|
||
struct TopKIdx
|
||
{
|
||
// by default, empty
|
||
};
|
||
|
||
template <int K_>
|
||
struct TopKIdx<K_, true>
|
||
{
|
||
static constexpr int K = K_;
|
||
int32_t val[K];
|
||
};
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <int K, typename Type>
|
||
__device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
|
||
Type value, int32_t idx, Type minValue)
|
||
{
|
||
static_assert(K > 0, "Top K must have K > 0");
|
||
static_assert(K < WarpSize, "Top K must have K < WarpSize");
|
||
using RedType = TopKRedType<Type>;
|
||
RedType topK{value, idx};
|
||
typename RedType::TypeCmp packedMax{};
|
||
#pragma unroll
|
||
for (int kk = 0; kk < K; ++kk)
|
||
{
|
||
topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK;
|
||
// get the next largest value
|
||
packedMax = topK.reduce(warp);
|
||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||
}
|
||
};
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
#define TOPK_SWAP(I, J) \
|
||
{ \
|
||
auto pairMin = min(topK[I].compVal, topK[J].compVal); \
|
||
auto pairMax = max(topK[I].compVal, topK[J].compVal); \
|
||
topK[I].compVal = pairMax; \
|
||
topK[J].compVal = pairMin; \
|
||
}
|
||
|
||
template <int K, typename Type, int N, bool IsSorted = false>
|
||
__device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
|
||
Type (&value)[N], int32_t (&idx)[N], Type minValue)
|
||
{
|
||
static_assert(K > 0, "Top K must have K > 0");
|
||
static_assert(K < WarpSize, "Top K must have K < WarpSize");
|
||
static_assert(N > 0, "Top K must have N > 0");
static_assert(N <= K, "Top K must have N <= K");
using RedType = TopKRedType<Type>;
|
||
RedType topK[N];
|
||
#pragma unroll
|
||
for (int nn = 0; nn < N; ++nn)
|
||
topK[nn] = RedType{value[nn], idx[nn]};
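// Note: the swap network below is the standard 5-comparator sorting network for
// exactly four candidates (sorted in descending order), so the !IsSorted path
// assumes N == 4.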
if constexpr (!IsSorted)
|
||
{
|
||
TOPK_SWAP(0, 2);
|
||
TOPK_SWAP(1, 3);
|
||
|
||
TOPK_SWAP(0, 1);
|
||
TOPK_SWAP(2, 3);
|
||
|
||
TOPK_SWAP(1, 2);
|
||
}
|
||
typename RedType::TypeCmp packedMax{};
|
||
#pragma unroll
|
||
for (int kk = 0; kk < K; ++kk)
|
||
{
|
||
bool update = kk > 0 && packedMax == topK[0].compVal;
|
||
#pragma unroll
|
||
for (int nn = 0; nn < N; ++nn)
|
||
{
|
||
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
|
||
}
|
||
// get the next largest value
|
||
packedMax = topK[0].reduce(warp);
|
||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||
}
|
||
};
|
||
|
||
#undef TOPK_SWAP
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename T>
|
||
__host__ __device__ constexpr T mulLog2(T a, T bLog2)
|
||
{
|
||
return a << bLog2;
|
||
}
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename T>
|
||
__host__ __device__ constexpr T divUpLog2(T a, T bLog2)
|
||
{
|
||
return ((a + (1 << bLog2) - 1) >> bLog2);
|
||
}
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename T>
|
||
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2)
|
||
{
|
||
return mulLog2<T>(divUpLog2<T>(a, bLog2), bLog2);
|
||
}
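
// For example, with a padding of 2^3 == 8: divUpLog2(13, 3) == 2 tiles,
// mulLog2(2, 3) == 16, and thus divUpMulLog2(13, 3) == 16, i.e. 13 rounded up
// to the next multiple of 8.
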
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
__host__ __device__ constexpr int32_t getBits(int32_t value, int idx)
|
||
{
|
||
int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000;
|
||
return (value & mask) >> (idx * 8);
|
||
}
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <bool IsZero = false>
|
||
__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx)
|
||
{
|
||
if constexpr (!IsZero)
|
||
{
|
||
int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;
|
||
value &= mask;
|
||
}
|
||
value |= (newBits << (idx * 8));
|
||
}
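
// Example of the byte packing used by the warp kernel below: starting from
// value == 0, setBits</* IsZero= */ true>(value, 1, 2) yields 0x00010000, and
// getBits(0x00010000, 2) returns 1, i.e. each int32_t holds four 8-bit counters.
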
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename KernelParams>
|
||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||
__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params)
|
||
{
|
||
// types used in this kernel
|
||
using TypeExpW = typename KernelParams::TypeExpW;
|
||
using TypePacked = PackedScoreIdx<TypeExpW>;
|
||
// use the default cub warp-scan, with shfl
|
||
using Scan = cub::WarpScan<int32_t>;
|
||
__shared__ typename Scan::TempStorage tempStorage;
|
||
|
||
// each thread encodes 4 experts in one `int32_t`. The assumption is that
// we don't have more than 127 tokens, but `WarpKernelMaxNumTokens` must be
// smaller than that, because other approaches are already more efficient
// well before reaching 127 tokens.
static constexpr int ExpertsPerThread = sizeof(int32_t);
|
||
static_assert(WarpKernelMaxNumTokens <= 127);
|
||
// this is a full table of which token is routed to which expert.
|
||
// the assumption here is that there are no more than 128 experts.
|
||
// we use a stride of 33 instead of 32 to avoid shared memory bank conflicts.
|
||
__shared__ int32_t __attribute((aligned(128)))
|
||
smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride];
|
||
static_assert(WarpKernelSmemStride == WarpSize + 1);
|
||
static_assert(MaxNumExperts / sizeof(int32_t) <= WarpSize);
|
||
|
||
// values needed for the top-1 reduction, if required
|
||
TypeExpW minScore = TypeExpW{-INFINITY};
|
||
auto block = cg::this_thread_block();
|
||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||
|
||
#pragma unroll
|
||
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
|
||
{
|
||
// reset full shared memory field to 0
|
||
smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0;
|
||
}
|
||
__syncwarp();
|
||
|
||
// then wait on primary grid
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaGridDependencySynchronize();
|
||
}
|
||
|
||
if (params.mPtrScores != nullptr)
|
||
{
|
||
// if we use `mPtrScores` as input, we need to perform the top-1 reduction:
// for each token, we load the scores, then use `reduceTopK` for this.
// each thread works on 4 experts, so a local reduction is done before
// the warp-wide reduction.
for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx)
|
||
{
|
||
auto scoreOffset = tokenIdx * params.mNumExperts;
|
||
// local reduction to get the best score for our 4 experts
|
||
TypeExpW maxScore = minScore;
|
||
int32_t maxExpertIdx{-1};
|
||
#pragma unroll
|
||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||
{
|
||
auto expertIdx = ii * WarpSize + threadIdx.x;
|
||
auto newScore = expertIdx < params.mNumExperts ? params.mPtrScores[scoreOffset + expertIdx] : minScore;
|
||
// note: use `>=` s.t. highest index always wins, just like in `reduceTopK`
|
||
maxExpertIdx = newScore >= maxScore ? expertIdx : maxExpertIdx;
|
||
maxScore = newScore >= maxScore ? newScore : maxScore;
|
||
}
|
||
int32_t warpMaxExpertIdx[NumTopExperts];
|
||
TypeExpW warpMaxScore[NumTopExperts];
|
||
// warp-wide reduction to get the best score for all experts
|
||
reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
|
||
if (cute::elect_one_sync())
|
||
{
|
||
// one thread updates the count linking token to chosen expert
|
||
auto expertTokenCount = 0;
|
||
setBits</* IsZero= */ true>(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread);
|
||
smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] = expertTokenCount;
|
||
// we also compute the final score here and write it out if required
|
||
auto finalScore = TypeExpW{sigmoid_accurate(float{warpMaxScore[0]})};
|
||
if (params.mPtrExpertWeights != nullptr)
|
||
{
|
||
params.mPtrExpertWeights[tokenIdx] = finalScore;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
else
|
||
{
|
||
// if we do not have `mPtrScores` as input, we expect that `mPtrExpertIdx`
// contains the top-1 packed score and index already.
// Each thread represents a token here, and we extract the relevant score
|
||
// The assumption is that the #tokens is limited by warp-size
|
||
static_assert(WarpKernelMaxNumTokens <= WarpSize);
|
||
TypePacked scoreIdx = threadIdx.x < params.mNumTokens ? params.mPtrExpertIdx[threadIdx.x] : TypePacked{};
|
||
int32_t expertTokenCount = 0;
|
||
setBits</* IsZero= */ true>(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread);
|
||
if (threadIdx.x < params.mNumTokens)
|
||
{
|
||
smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount;
|
||
}
|
||
// we also compute the final score here and write it out if required
|
||
auto finalScore = TypeExpW{sigmoid_accurate(float{scoreIdx.score})};
|
||
if (params.mPtrExpertWeights != nullptr && threadIdx.x < params.mNumTokens)
|
||
{
|
||
params.mPtrExpertWeights[threadIdx.x] = finalScore;
|
||
}
|
||
}
|
||
|
||
// make the full table available to all threads
|
||
__syncwarp();
|
||
|
||
// at this point, each thread keeps a count of its 4 assigned experts in
|
||
// `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts
|
||
// in `expertOffset`.
|
||
int32_t expertCount = 0;
|
||
int32_t expertOffset[WarpKernelMaxNumTokens + 1];
|
||
#pragma unroll
|
||
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx)
|
||
{
|
||
if (tokenIdx > params.mNumTokens)
|
||
break;
|
||
// simple reduction for `expertCount`, and scan for `expertOffset`
|
||
auto expertTokenCount = tokenIdx < params.mNumTokens ? smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0;
|
||
expertOffset[tokenIdx] = expertCount;
|
||
expertCount += expertTokenCount;
|
||
}
|
||
|
||
// at this point, we are ready for the scan across all experts to get the
|
||
// thread-wise offsets across experts
|
||
// first, we need to reduce across our 4 experts into `numCta`
|
||
int32_t numCta = 0;
|
||
#pragma unroll
|
||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||
{
|
||
auto count = getBits(expertCount, ii);
|
||
numCta += divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||
}
|
||
// second, we perform the exclusive sum across the warp
|
||
int32_t ctaOffset;
|
||
int32_t numNonExitingCtas;
|
||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||
|
||
// finally, we perform a scan across our local experts, starting with the
|
||
// warp-wide scan result (`ctaOffset`)
|
||
auto ctaOffsetExp = ctaOffset;
|
||
#pragma unroll
|
||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||
{
|
||
auto count = getBits(expertCount, ii);
|
||
auto finalNumCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
|
||
// during the scan for expert offsets, we can already write out
|
||
// both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit`
|
||
for (int cta = 0; cta < finalNumCta; ++cta)
|
||
{
|
||
params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx;
|
||
params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta]
|
||
= min(mulLog2<int32_t>(ctaOffsetExp + cta + 1, params.mPaddingLog2),
|
||
mulLog2<int32_t>(ctaOffsetExp, params.mPaddingLog2) + count);
|
||
}
|
||
ctaOffsetExp += finalNumCta;
|
||
}
|
||
|
||
// at this point, we can write out padded count from the warp-aggregate
|
||
if (cute::elect_one_sync())
|
||
{
|
||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||
}
|
||
|
||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||
// we can trigger the next kernel at this point
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaTriggerProgrammaticLaunchCompletion();
|
||
}
|
||
#endif
|
||
|
||
// at this point, all values for offsets are ready, except the final offsets
|
||
// within the padded index (`permutedIdx`)
|
||
// for this, we perform a scan similar to the one directly after the warp-scan:
|
||
// here, we keep the local offset for each of the thread's experts in a field
|
||
// of registers
|
||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||
int32_t finalExpertOffset[ExpertsPerThread];
|
||
finalExpertOffset[0] = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||
#pragma unroll
|
||
for (int ii = 1; ii < ExpertsPerThread; ++ii)
|
||
{
|
||
finalExpertOffset[ii]
|
||
= finalExpertOffset[ii - 1] + divUpMulLog2<int32_t>(getBits(expertCount, ii - 1), params.mPaddingLog2);
|
||
}
|
||
|
||
#pragma unroll
|
||
for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx)
|
||
{
|
||
// at this point, we can calculate the final index:
|
||
// we simply loop over all tokens, and all experts assigned to this thread.
|
||
// For each pair, we determine whether that token was routed to that expert
|
||
// based on whether the offset for that token changed.
|
||
// we can then easily compute the final `expertIdx` and `permutedIdx` relative
|
||
// to this token and expert, and write them out.
|
||
if (tokenIdx >= params.mNumTokens)
|
||
break;
|
||
|
||
#pragma unroll
|
||
for (int ii = 0; ii < ExpertsPerThread; ++ii)
|
||
{
|
||
// determine whether the offset for this expert and token changes
|
||
auto localOffsetToken = getBits(expertOffset[tokenIdx], ii);
|
||
auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken;
|
||
// the expert index of this expert
|
||
auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
|
||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||
// the permuted index: we add the local offset relative to this expert and token
|
||
// to the global offset from the scan for this expert
|
||
auto permutedIdx = isLocalExpert ? finalExpertOffset[ii] + localOffsetToken : int32_t{-1};
|
||
// write out `mPtrExpandedIdxToPermutedIdx` if required
|
||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted)
|
||
{
|
||
params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
|
||
}
|
||
// write out `mPtrPermutedIdxToTokenIdx` if required
|
||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted)
|
||
{
|
||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
#else
|
||
__global__ void routingIndicesWarpKernel(KernelParams params)
|
||
{
|
||
assert(false && "routingIndicesWarpKernel is only supported on SM90+ architectures");
|
||
}
|
||
#endif
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename KernelParams>
|
||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
|
||
routingIndicesClusterKernel(KernelParams params)
|
||
{
|
||
// number of experts is bounded by number of threads
|
||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
|
||
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
|
||
// number of tokens/expanded idx is bounded by total number of warps
|
||
using TypeExpW = typename KernelParams::TypeExpW;
|
||
using TypePacked = PackedScoreIdx<TypeExpW>;
|
||
__shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps];
|
||
// Needed for the exclusive sum of token offsets.
|
||
// Note: the scan might include more bins than needed, with bin counts of 0 to pad
|
||
using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||
__shared__ typename Scan::TempStorage tempStorage;
|
||
// Number of threads in the cluster.
|
||
static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster;
|
||
// same as max num tokens
|
||
static constexpr int MaxExpandedIdxPerThread
|
||
= (MaxNumTokensSingleCluster * NumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster;
|
||
|
||
uint32_t const clusterBlockRank = blockIdx.x;
|
||
uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;
|
||
|
||
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||
int32_t const laneIdx = cutlass::arch::LaneId();
|
||
|
||
auto expandedIdxSize = params.mNumTokens * NumTopExperts;
|
||
// TODO(mjoux): expand to more tokens (possibly)
|
||
auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
|
||
auto scoreOffset = warpTokenIdx * params.mNumExperts;
|
||
bool validToken = warpTokenIdx < params.mNumTokens;
|
||
TypeExpW minScore = TypeExpW{-INFINITY};
|
||
|
||
auto block = cg::this_thread_block();
|
||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||
|
||
// pre-fill the counts with 0
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
smemExpertCount[threadIdx.x] = 0;
|
||
}
|
||
__syncthreads();
|
||
|
||
// then wait on primary grid
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaGridDependencySynchronize();
|
||
}
|
||
|
||
if (params.mPtrScores != nullptr)
|
||
{
|
||
TypeExpW maxScore = minScore;
|
||
int32_t maxExpertIdx{-1};
|
||
// in this case, each warp represents a token
|
||
// we then exchange all token max scores, s.t. afterwards, each thread
|
||
// represents a token
|
||
if (validToken)
|
||
{
|
||
#pragma unroll
|
||
for (int i = 0; i < MaxNumExperts / WarpSize; ++i)
|
||
{
|
||
auto expertIdx = i * WarpSize + laneIdx;
|
||
auto newScore = expertIdx < params.mNumExperts ? params.mPtrScores[scoreOffset + expertIdx] : minScore;
|
||
// note: use `>=` s.t. highest index always wins, just like in `reduceTopK`
|
||
maxExpertIdx = newScore >= maxScore ? expertIdx : maxExpertIdx;
|
||
maxScore = newScore >= maxScore ? newScore : maxScore;
|
||
}
|
||
int32_t warpMaxExpertIdx[NumTopExperts];
|
||
TypeExpW warpMaxScore[NumTopExperts];
|
||
reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
|
||
if (cute::elect_one_sync())
|
||
{
|
||
TypePacked packedScore{warpMaxScore[0], static_cast<int16_t>(warpMaxExpertIdx[0])};
|
||
smemPackedScoreIdx[warpIdx] = packedScore;
|
||
}
|
||
}
|
||
// make packed scores available to all threads in cluster
|
||
__cluster_barrier_arrive();
|
||
__cluster_barrier_wait();
|
||
}
|
||
|
||
// each thread keeps some number of "expanded indexes" assigned to it
|
||
// note that expanded indexes simply represent tokens here.
|
||
// for each of these, we keep the associated expert and offset within expert in registers
|
||
int32_t expertIndexes[MaxExpandedIdxPerThread];
|
||
int32_t expertOffsets[MaxExpandedIdxPerThread];
|
||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
|
||
// time, and branch between a fast path without bound checks and a slow path with bound checks.
|
||
// TODO(mjoux): potentially add this back for perf tuning
|
||
// int constexpr IterStride = 4;
|
||
// static_assert(MaxExpandedIdxPerThread % IterStride == 0);
|
||
|
||
// Define a lambda to avoid code duplication in both branches.
|
||
auto loopBody = [&](int ii, int expandedIdx)
|
||
{
|
||
TypePacked scoreIdx;
|
||
if (params.mPtrScores != nullptr)
|
||
{
|
||
TypePacked const* remoteSmem
|
||
= cg::cluster_group::map_shared_rank(smemPackedScoreIdx, expandedIdx / NumWarps);
|
||
scoreIdx = remoteSmem[expandedIdx % NumWarps];
|
||
}
|
||
else
|
||
{
|
||
scoreIdx = params.mPtrExpertIdx[expandedIdx];
|
||
}
|
||
expertIndexes[ii] = scoreIdx.idx;
|
||
// check whether this expert is local to our GPU at all and ignore if not
|
||
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
|
||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
|
||
auto finalScore = TypeExpW{sigmoid_accurate(float{scoreIdx.score})};
|
||
if (params.mPtrExpertWeights != nullptr)
|
||
{
|
||
params.mPtrExpertWeights[expandedIdx] = finalScore;
|
||
}
|
||
};
|
||
|
||
if (clusterThreadIdx < expandedIdxSize)
|
||
{
|
||
loopBody(0, clusterThreadIdx);
|
||
}
|
||
|
||
// Make local histogram (token counts per expert) available to all threads in the cluster.
|
||
__cluster_barrier_arrive();
|
||
__cluster_barrier_wait();
|
||
|
||
//
|
||
// Each thread now represents one expert
|
||
//
|
||
|
||
// Total number of tokens for this expert.
|
||
int32_t count = 0;
|
||
// Per-expert offset for this block.
|
||
int32_t blockExpertOffset = 0;
|
||
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
// Get the histogram bin from each rank for this expert.
|
||
int32_t expertCounts[NumBlocksPerCluster];
|
||
#pragma unroll
|
||
for (int rank = 0; rank < NumBlocksPerCluster; rank++)
|
||
{
|
||
int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
|
||
expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0;
|
||
}
|
||
|
||
// Compute an exclusive prefix sum of the block-local count.
|
||
#pragma unroll
|
||
for (int rank = 0; rank < NumBlocksPerCluster; rank++)
|
||
{
|
||
if (rank == clusterBlockRank)
|
||
{
|
||
blockExpertOffset = count;
|
||
}
|
||
count += expertCounts[rank];
|
||
}
|
||
}
|
||
|
||
// Arrive: we do not access distributed shared memory after this point.
|
||
__cluster_barrier_arrive();
|
||
|
||
// Compute the runtime config for projections.
// Whether or not an expert is local is taken into account when smemExpertCount is computed,
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||
int32_t ctaOffset;
|
||
int32_t numNonExitingCtas;
|
||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
// Strided loop to share this work between blocks.
|
||
for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster)
|
||
{
|
||
const int32_t localExpertIdx
|
||
= (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
|
||
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
|
||
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
|
||
= min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
|
||
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
|
||
}
|
||
|
||
// get the padded offset associated with this expert
|
||
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||
|
||
// write expert offsets to shared
|
||
smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
|
||
}
|
||
|
||
// write out padded count
|
||
if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
|
||
{
|
||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||
}
|
||
|
||
// make expert offsets available to all threads
|
||
__syncthreads();
|
||
|
||
// Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
|
||
// Note: I observed a perf benefit to doing this before the final loop so the compiler can
|
||
// implement break with EXIT.
|
||
__cluster_barrier_wait();
|
||
|
||
// trigger the secondary kernel when using PDL
|
||
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
|
||
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
|
||
// TODO: this is not sufficient to ensure visibility in the next kernel!
|
||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaTriggerProgrammaticLaunchCompletion();
|
||
}
|
||
#endif
|
||
|
||
// each thread has the same "expanded indexes" assigned to it as above
|
||
// at this point, we know the final offsets of experts and the offsets within
|
||
// experts, which allows writing the final index values
|
||
|
||
#pragma unroll
|
||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
|
||
{
|
||
auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
|
||
if (expandedIdx >= expandedIdxSize)
|
||
{
|
||
break;
|
||
}
|
||
auto expertIdx = expertIndexes[ii];
|
||
// check whether this expert is local to our GPU at all
|
||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||
auto tokenIdx = expandedIdx / NumTopExperts;
|
||
auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
|
||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
|
||
{
|
||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||
}
|
||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||
{
|
||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||
}
|
||
}
|
||
}
|
||
#else
|
||
__global__ void routingIndicesClusterKernel(KernelParams params)
|
||
{
|
||
assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
|
||
}
|
||
#endif
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
// this kernel is needed in case we have scores as input for the histogram kernel
|
||
template <typename KernelParams>
|
||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
|
||
{
|
||
using TypeExpW = typename KernelParams::TypeExpW;
|
||
using TypeExpWVec = std::conditional_t<sizeof(TypeExpW) == 2, float2, float4>;
|
||
using TypePacked = PackedScoreIdx<TypeExpW>;
|
||
static constexpr int VecSize = MaxNumExperts / WarpSize;
|
||
// we assume that #experts is a multiple of 4, so VecSize must be 4.
|
||
static_assert(VecSize == 4);
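
// With VecSize == 4, each lane loads 4 consecutive scores per token in a single
// vectorized access (float2 for 16-bit score types, float4 for 32-bit ones),
// covering up to MaxNumExperts == 128 scores per warp.
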
int32_t const laneIdx = cutlass::arch::LaneId();
|
||
int32_t const warpIdx = threadIdx.x / WarpSize;
|
||
int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
|
||
int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
|
||
TypeExpW minScore = TypeExpW{-INFINITY};
|
||
auto block = cg::this_thread_block();
|
||
auto warp = cg::tiled_partition<WarpSize>(block);
|
||
|
||
// Wait on primary grid and trigger secondary kernel.
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaGridDependencySynchronize();
|
||
cudaTriggerProgrammaticLaunchCompletion();
|
||
}
|
||
|
||
// in this case, each warp represents a token, and we use a grid-stride loop
|
||
// over all warps/tokens
|
||
for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
|
||
{
|
||
TypeExpW maxScore = minScore;
|
||
int32_t maxExpertIdx{-1};
|
||
auto scoreOffset = (tokenIdx * params.mNumExperts) / VecSize + laneIdx;
|
||
|
||
TypeExpW allScores[VecSize];
|
||
auto* ptrAllScores = reinterpret_cast<TypeExpWVec const*>(params.mPtrScores);
|
||
*reinterpret_cast<TypeExpWVec*>(allScores) = ptrAllScores[scoreOffset];
|
||
|
||
#pragma unroll
|
||
for (int i = 0; i < VecSize; ++i)
|
||
{
|
||
auto expertIdx = laneIdx * VecSize + i;
|
||
auto newScore = expertIdx < params.mNumExperts ? allScores[i] : minScore;
|
||
// note: use `>=` s.t. highest index always wins, just like in `reduceTopK`
|
||
maxExpertIdx = newScore >= maxScore ? expertIdx : maxExpertIdx;
|
||
maxScore = newScore >= maxScore ? newScore : maxScore;
|
||
}
|
||
int32_t warpMaxExpertIdx[NumTopExperts];
|
||
TypeExpW warpMaxScore[NumTopExperts];
|
||
reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
|
||
if (cute::elect_one_sync())
|
||
{
|
||
TypePacked packedScore{warpMaxScore[0], static_cast<int16_t>(warpMaxExpertIdx[0])};
|
||
params.mPtrExpertIdx[tokenIdx] = packedScore;
|
||
}
|
||
}
|
||
}
|
||
#else
|
||
__global__ void routingIndicesHistogramScoresKernel(KernelParams params)
|
||
{
|
||
assert(false && "routingIndicesHistogramScoresKernel is only supported on SM90+ architectures");
|
||
}
|
||
#endif
|
||
|
||
// Two-step approach (if the number of tokens exceeds the limits of what the cluster / cooperative
// launch variants can handle): in order to minimize the amount of data to exchange through global memory,
// we will compute the local histograms in smem twice: the first kernel will get us the total number
|
||
// of tokens per expert. The second kernel will use the smem and L2 atomics to get corresponding
|
||
// element and tile offsets.
|
||
//
|
||
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
|
||
// inefficient if we have one CTA per token doing a single global atomic.
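//
// Rough flow of the two-step path implemented below (a sketch of the existing code,
// not additional behavior):
//   1) routingIndicesHistogramKernel: build per-CTA histograms of tokens per local
//      expert in smem and reduce them into mPtrExpertCounts with global atomics.
//   2) routingIndicesOffsetsKernel: rebuild per-tile histograms in smem, combine them
//      with the global counts to get per-expert and per-tile offsets, and write the
//      permuted indices plus the grouped-GEMM launch metadata.
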
template <typename KernelParams>
|
||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(KernelParams params)
|
||
{
|
||
using TypeExpW = typename KernelParams::TypeExpW;
|
||
using TypePacked = PackedScoreIdx<TypeExpW>;
|
||
// number of experts is bounded by number of threads
|
||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
|
||
|
||
// For unrolling.
|
||
uint32_t constexpr NumEltsPerThread = 8;
|
||
|
||
// Pre-fill the counts with 0
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
smemExpertCount[threadIdx.x] = 0;
|
||
}
|
||
__syncthreads();
|
||
|
||
// Wait on primary grid and trigger secondary kernel.
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaGridDependencySynchronize();
|
||
cudaTriggerProgrammaticLaunchCompletion();
|
||
}
|
||
|
||
uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
|
||
uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||
|
||
uint32_t const gridBlockOffset = blockIdx.x * NumThreadsHist;
|
||
uint32_t const gridStride = gridDim.x * NumThreadsHist;
|
||
|
||
// Define a lambda to avoid code duplication in branches.
|
||
auto loopBody = [&](int expandedIdx)
|
||
{
|
||
TypePacked scoreIdx = params.mPtrExpertIdx[expandedIdx];
|
||
// check whether this expert is local to our GPU at all and ignore if not
|
||
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
|
||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||
if (isLocalExpert)
|
||
{
|
||
atomicAdd(&smemExpertCount[scoreIdx.idx], 1);
|
||
}
|
||
auto finalScore = TypeExpW{sigmoid_accurate(float{scoreIdx.score})};
|
||
if (params.mPtrExpertWeights != nullptr)
|
||
{
|
||
params.mPtrExpertWeights[expandedIdx] = finalScore;
|
||
}
|
||
};
|
||
|
||
// Grid-stride loop.
|
||
for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
|
||
expandedIdx0 += gridStride * NumEltsPerThread)
|
||
{
|
||
// Fast path if bound checks aren't necessary
|
||
if (expandedIdx0 + NumEltsPerThread * NumThreadsHist <= expandedIdxSize)
|
||
{
|
||
#pragma unroll
|
||
for (uint32_t ii = 0; ii < NumEltsPerThread; ii++)
|
||
{
|
||
uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsHist + threadIdx.x;
|
||
loopBody(expandedIdx);
|
||
}
|
||
}
|
||
else
|
||
{
|
||
for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
|
||
expandedIdx += NumThreadsHist)
|
||
{
|
||
loopBody(expandedIdx);
|
||
}
|
||
}
|
||
}
|
||
__syncthreads();
|
||
|
||
//
|
||
// Each thread now represents one expert
|
||
//
|
||
|
||
// Reduce histograms with atomics.
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
|
||
atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount);
|
||
}
|
||
}
|
||
#else
|
||
__global__ void routingIndicesHistogramKernel(KernelParams params)
|
||
{
|
||
assert(false && "routingIndicesHistogramKernel is only supported on SM90+ architectures");
|
||
}
|
||
#endif
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename KernelParams>
|
||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params)
|
||
{
|
||
using TypeExpW = typename KernelParams::TypeExpW;
|
||
using TypePacked = PackedScoreIdx<TypeExpW>;
|
||
// number of experts is bounded by number of threads
|
||
__shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreadsHist];
|
||
__shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
|
||
__shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[NumThreadsHist];
|
||
// needed for the exclusive sum of token offsets
|
||
using Scan = cub::BlockScan<int32_t, NumThreadsHist, cub::BLOCK_SCAN_WARP_SCANS>;
|
||
__shared__ typename Scan::TempStorage tempStorage;
|
||
static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
|
||
static constexpr int MaxExpandedIdxPerBlock = NumThreadsHist * MaxExpandedIdxPerThread;
|
||
|
||
int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
|
||
|
||
uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
|
||
uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);
|
||
|
||
// Wait on primary grid.
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaGridDependencySynchronize();
|
||
}
|
||
|
||
// The expert offsets are common to all tiles of all blocks.
|
||
// Load the histogram, scan it and write offsets to shared memory.
|
||
// Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
|
||
// the scan, with PDL?
|
||
|
||
//
|
||
// Each thread represents one expert.
|
||
//
|
||
|
||
// Get total count for this expert.
|
||
int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;
|
||
|
||
// Compute the runtime config for projections.
// Whether or not an expert is local is taken into account when the histogram is computed,
// so we do not need to take it into account here.
const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
|
||
int32_t ctaOffset;
|
||
int32_t numNonExitingCtas;
|
||
Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
|
||
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
// Get the padded offset associated with this expert
|
||
const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);
|
||
|
||
// Write expert offsets to shared
|
||
smemExpertOffset[threadIdx.x] = offset;
|
||
}
|
||
|
||
// Sync to make expert offsets available to all threads.
|
||
__syncthreads();
|
||
|
||
// The first block writes out padded count
|
||
if (blockIdx.x == 0 && warpIdx == NumWarpsHist - 1 && cute::elect_one_sync())
|
||
{
|
||
const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
|
||
params.mPtrPermutedIdxSize[0] = permutedIdxSize;
|
||
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
|
||
}
|
||
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
// Strided loop to share this work between blocks.
|
||
for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x)
|
||
{
|
||
const int32_t localExpertIdx
|
||
= (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
|
||
params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
|
||
params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
|
||
= min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
|
||
mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
|
||
}
|
||
}
|
||
|
||
//
|
||
// Now loop on indices and compute offsets.
|
||
//
|
||
|
||
// Grid-stride loop on 1D "tiles" of input indices.
|
||
for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x)
|
||
{
|
||
if (tileIdx > 0)
|
||
{
|
||
// Sync for safe reuse of smem buffers.
|
||
__syncthreads();
|
||
}
|
||
|
||
// Pre-fill the counts with 0
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
smemExpertCount[threadIdx.x] = 0;
|
||
}
|
||
__syncthreads();
|
||
|
||
// each thread has some number of "expanded indexes" assigned to it
// for each of these, we keep the associated expert and offset within expert in registers
|
||
int32_t expertIndexes[MaxExpandedIdxPerThread];
|
||
int32_t expertOffsets[MaxExpandedIdxPerThread];
|
||
auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
|
||
|
||
// Define a lambda to avoid code duplication in branches.
|
||
auto loopBody = [&](int ii, int expandedIdx)
|
||
{
|
||
TypePacked scoreIdx = params.mPtrExpertIdx[expandedIdx];
|
||
expertIndexes[ii] = scoreIdx.idx;
|
||
// check whether this expert is local to our GPU at all and ignore if not
|
||
auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
|
||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||
expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
|
||
};
|
||
|
||
// For all tiles but the last, all indices are in bounds.
|
||
if (tileIdx < numTiles - 1)
|
||
{
|
||
#pragma unroll
|
||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
|
||
{
|
||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||
loopBody(ii, expandedIdx);
|
||
}
|
||
}
|
||
else
|
||
{
|
||
// For the last tile, we need to exit the loop when out of bounds.
|
||
// In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
|
||
// time, and branch between a fast path without bound checks and a slow path with bound checks
|
||
int constexpr IterStride = 4;
|
||
static_assert(MaxExpandedIdxPerThread % IterStride == 0);
|
||
|
||
#pragma unroll
|
||
for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
|
||
{
|
||
// Whether it's safe to do multiple iterations without bound checks.
|
||
bool const takeFastPath
|
||
= tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsHist <= expandedIdxSize;
|
||
if (takeFastPath)
|
||
{
|
||
#pragma unroll
|
||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||
{
|
||
int const ii = ii0 + jj;
|
||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||
loopBody(ii, expandedIdx);
|
||
}
|
||
}
|
||
else
|
||
{
|
||
bool doBreak = false;
|
||
#pragma unroll
|
||
for (int32_t jj = 0; jj < IterStride; jj++)
|
||
{
|
||
int const ii = ii0 + jj;
|
||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||
if (expandedIdx >= expandedIdxSize)
|
||
{
|
||
doBreak = true;
|
||
break;
|
||
}
|
||
loopBody(ii, expandedIdx);
|
||
}
|
||
if (doBreak)
|
||
{
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Make local histogram (token counts per expert) available to all threads in the block.
|
||
__syncthreads();
|
||
|
||
//
|
||
// Each thread now represents one expert
|
||
//
|
||
|
||
if (threadIdx.x < params.mNumExperts)
|
||
{
|
||
// Add the local bin count to the common bin count and get a per-CTA offset. We use the second
|
||
// half of the histogram buffer for this histogram, because the first half already holds the
|
||
// reduced histogram from the previous kernel.
|
||
int32_t const localExpertCount = smemExpertCount[threadIdx.x];
|
||
int32_t const tileExpertOffset
|
||
= atomicAdd(¶ms.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount);
|
||
|
||
// Make per-expert tile offsets available to all threads in the block.
|
||
smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
|
||
}
|
||
__syncthreads();
|
||
|
||
// Add tile offset and element offset and write to global memory.
|
||
auto storeLoopBody = [&](int ii, int expandedIdx)
|
||
{
|
||
int32_t expertIdx = expertIndexes[ii];
|
||
// check whether this expert is local to our GPU at all
|
||
auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
|
||
auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
|
||
&& (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
|
||
auto tokenIdx = expandedIdx / NumTopExperts;
|
||
auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
|
||
if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
|
||
{
|
||
params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
|
||
}
|
||
if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
|
||
{
|
||
params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
|
||
}
|
||
};
|
||
// Bound checks only in last tile.
|
||
if (tileIdx < numTiles - 1)
|
||
{
|
||
#pragma unroll
|
||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
|
||
{
|
||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||
storeLoopBody(ii, expandedIdx);
|
||
}
|
||
}
|
||
else
|
||
{
|
||
#pragma unroll
|
||
for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
|
||
{
|
||
auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
|
||
if (expandedIdx >= expandedIdxSize)
|
||
{
|
||
break;
|
||
}
|
||
storeLoopBody(ii, expandedIdx);
|
||
}
|
||
}
|
||
}
|
||
|
||
// Trigger secondary kernel.
|
||
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
|
||
// dependency sync.
|
||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||
if constexpr (KernelParams::UsePdl)
|
||
{
|
||
cudaTriggerProgrammaticLaunchCompletion();
|
||
}
|
||
#endif
|
||
}
|
||
#else
|
||
__global__ void routingIndicesOffsetsKernel(KernelParams params)
|
||
{
|
||
assert(false && "routingIndicesOffsetsKernel is only supported on SM90+ architectures");
|
||
}
|
||
#endif
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
void run(Data const& data, void* stream)
|
||
{
|
||
TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
|
||
"Routing kernel requires at least one input parameter");
|
||
TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
|
||
&& data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
|
||
"Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
|
||
TLLM_CHECK_WITH_INFO(
|
||
data.mTopK == NumTopExperts, "Routing kernel expects %d topK experts (for now)", NumTopExperts);
|
||
TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
|
||
"Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
|
||
static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
|
||
static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
|
||
TLLM_CHECK_WITH_INFO(
|
||
data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
|
||
TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);
|
||
|
||
if (data.mPtrPermutedIdxToTokenIdx != nullptr)
|
||
{
|
||
// need to set all values to -1 before running the kernel
|
||
auto maxPermutedSize
|
||
= data.mNumTokens * data.mTopK + (data.mNumExperts << data.mPaddingLog2) - data.mNumExperts;
|
||
// note that setting every byte to -1 (0xFF) yields the logical value -1
// for any width of signed integer
TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrPermutedIdxToTokenIdx, -1,
|
||
static_cast<size_t>(maxPermutedSize) * sizeof(int32_t), (cudaStream_t) stream));
|
||
}
|
||
|
||
bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens)
|
||
|| data.mNumTokens < WarpKernelMaxNumTokens;
|
||
bool const useSingleCluster
|
||
= data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
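
// Dispatch summary (matching the branches below): a single warp for very small token
// counts, a single thread-block cluster when everything fits within one cluster, and
// otherwise the histogram-based path (top-1 scores kernel if needed, then the
// histogram and offsets kernels with grid-stride loops).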
if (!useSingleCluster)
|
||
{
|
||
TLLM_CHECK_WITH_INFO(
|
||
data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
|
||
TLLM_CHECK_WITH_INFO(
|
||
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
|
||
// Reset the global histograms (not used in single-cluster code path).
|
||
TLLM_CUDA_CHECK(cudaMemsetAsync(data.mPtrExpertCounts, 0,
|
||
static_cast<size_t>(2 * data.mNumExperts) * sizeof(int32_t), (cudaStream_t) stream));
|
||
}
|
||
|
||
if (useSingleWarp)
|
||
{
|
||
LAUNCH_EXPW_ONLY(data,
|
||
/*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize,
|
||
/*smemSize=*/0, // No dynamic smem
|
||
stream);
|
||
}
|
||
else if (useSingleCluster)
|
||
{
|
||
LAUNCH_EXPW_ONLY(data,
|
||
/*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
|
||
/*smemSize=*/0, // No dynamic smem
|
||
stream);
|
||
}
|
||
else
|
||
{
|
||
const uint32_t expandedIdxSize = data.mNumTokens * NumTopExperts;
|
||
|
||
const uint32_t histogramEltsPerBlock = 8 * NumThreadsHist;
|
||
const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;
|
||
|
||
// Limit grid size (all kernels use a grid-stride loop).
|
||
const uint32_t maxNumBlocks = 1024;
|
||
|
||
int const numBlocksHistogram
|
||
= std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
|
||
int const numBlocksOffsets
|
||
= std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);
|
||
|
||
if (data.mPtrScores != nullptr)
|
||
{
|
||
LAUNCH_EXPW_ONLY(data,
|
||
/*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, NumThreadsHist,
|
||
/*smemSize=*/0, // No dynamic smem
|
||
stream);
|
||
}
|
||
LAUNCH_EXPW_ONLY(data,
|
||
/*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
|
||
/*smemSize=*/0, // No dynamic smem
|
||
stream);
|
||
LAUNCH_EXPW_ONLY(data,
|
||
/*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
|
||
/*smemSize=*/0, // No dynamic smem
|
||
stream);
|
||
}
|
||
}
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
} // namespace routingLlama4
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
namespace routingQwen3
|
||
{
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
namespace cg = cooperative_groups;
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
static constexpr int NumThreads = 1024;
|
||
static constexpr int NumThreadsHist = 256;
|
||
static constexpr int NumBlocksPerCluster = 8;
|
||
static constexpr int WarpSize = 32;
|
||
static constexpr int NumWarps = NumThreads / WarpSize;
|
||
static constexpr int NumWarpsHist = NumThreadsHist / WarpSize;
|
||
static constexpr int NumTopExperts = 8;
|
||
static constexpr int MaxNumExperts = 128;
|
||
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
|
||
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
|
||
|
||
// Performance tuning knob.
|
||
static constexpr int NumEltsPerOffsetTilePerThread = 8;
|
||
|
||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL))
|
||
#define TLLM_GEN_ENABLE_FAST_REDUX
|
||
#endif
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename TypeExpW_>
|
||
struct TopKRedType
|
||
{
|
||
using TypeExpW = TypeExpW_;
|
||
static_assert(std::is_same_v<TypeExpW, float> || std::is_same_v<TypeExpW, cutlass::bfloat16_t>,
|
||
"Top K reduction only implemented for float and Bf16");
|
||
using TypeCmp = std::conditional_t<sizeof(TypeExpW) >= 4, double, float>;
|
||
static constexpr int64_t Mask64 = 0x000000000000FFFF;
|
||
static constexpr int32_t Mask32 = 0x0000FFFF;
|
||
|
||
TypeCmp compVal;
|
||
|
||
static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0)
|
||
{
|
||
auto cmpVal = TypeCmp{val};
|
||
TypeCmp cmpValWithIdx;
|
||
if constexpr (sizeof(TypeExpW) >= 4)
|
||
{
|
||
auto cmpValIdx64 = reinterpret_cast<int64_t&>(cmpVal) | (Mask64 & int64_t{idx});
cmpValWithIdx = reinterpret_cast<TypeCmp&>(cmpValIdx64);
|
||
}
|
||
else
|
||
{
|
||
auto cmpValIdx32 = reinterpret_cast<int32_t&>(cmpVal) | (Mask32 & idx);
|
||
cmpValWithIdx = reinterpret_cast<TypeCmp&>(cmpValIdx32);
|
||
}
|
||
return cmpValWithIdx;
|
||
}
|
||
|
||
static __host__ __device__ inline void unpack(TypeExpW& val, int32_t& idx, TypeCmp cmp)
|
||
{
|
||
if constexpr (sizeof(TypeExpW) >= 4)
|
||
{
|
||
idx = static_cast<int32_t>(reinterpret_cast<int64_t&>(cmp) & Mask64);
|
||
auto val64 = reinterpret_cast<int64_t&>(cmp) & ~Mask64;
|
||
val = static_cast<float>(reinterpret_cast<double&>(val64));
|
||
}
|
||
else
|
||
{
|
||
idx = reinterpret_cast<int32_t&>(cmp) & Mask32;
|
||
auto val32 = reinterpret_cast<int32_t&>(cmp) >> 16;
|
||
val = TypeExpW::bitcast(reinterpret_cast<uint16_t&>(val32));
|
||
}
|
||
}
|
||
|
||
__host__ __device__ TopKRedType() = default;
|
||
|
||
__host__ __device__ TopKRedType(TypeExpW val, int32_t idx)
|
||
: compVal(makeCmpVal(val, idx))
|
||
{
|
||
}
|
||
|
||
__host__ __device__ operator TypeCmp() const noexcept
|
||
{
|
||
return compVal;
|
||
}
|
||
|
||
__device__ inline TypeCmp reduce(cg::thread_block_tile<WarpSize> const& warp)
|
||
{
|
||
#if defined(TLLM_GEN_ENABLE_FAST_REDUX)
|
||
static constexpr bool UseCg = false;
|
||
#else
|
||
static constexpr bool UseCg = true;
|
||
#endif
|
||
if constexpr (UseCg || sizeof(TypeExpW) >= 4)
|
||
{
|
||
return cg::reduce(warp, compVal, cg::greater<TypeCmp>{});
|
||
}
|
||
else
|
||
{
|
||
float result;
|
||
asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(compVal));
|
||
return result;
|
||
}
|
||
}
|
||
};
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <int K_, bool Enable_>
|
||
struct TopKIdx
|
||
{
|
||
// by default, empty
|
||
};
|
||
|
||
template <int K_>
|
||
struct TopKIdx<K_, true>
|
||
{
|
||
static constexpr int K = K_;
|
||
int32_t val[K];
|
||
};
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <int K, typename Type>
|
||
__device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
|
||
Type value, int32_t idx, Type minValue)
|
||
{
|
||
static_assert(K > 0, "Top K must have K > 0");
|
||
static_assert(K < WarpSize, "Top K must have K < WarpSize");
|
||
using RedType = TopKRedType<Type>;
|
||
RedType topK{value, idx};
|
||
typename RedType::TypeCmp packedMax{};
|
||
#pragma unroll
|
||
for (int kk = 0; kk < K; ++kk)
|
||
{
|
||
topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK;
|
||
// get the next largest value
|
||
packedMax = topK.reduce(warp);
|
||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||
}
|
||
};
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
#define TOPK_SWAP(I, J) \
|
||
{ \
|
||
auto pairMin = min(topK[I].compVal, topK[J].compVal); \
|
||
auto pairMax = max(topK[I].compVal, topK[J].compVal); \
|
||
topK[I].compVal = pairMax; \
|
||
topK[J].compVal = pairMin; \
|
||
}
|
||
|
||
template <int K, typename Type, int N, bool IsSorted = false>
|
||
__device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
|
||
Type (&value)[N], int32_t (&idx)[N], Type minValue)
|
||
{
|
||
static_assert(K > 0, "Top K must have K > 0");
|
||
static_assert(K < WarpSize, "Top K must have K < WarpSize");
|
||
static_assert(N > 0, "Top K must have N > 0");
// static_assert(N <= K, "Top K must have N < K");
|
||
using RedType = TopKRedType<Type>;
|
||
RedType topK[N];
|
||
#pragma unroll
|
||
for (int nn = 0; nn < N; ++nn)
|
||
{
|
||
topK[nn] = RedType{value[nn], idx[nn]};
|
||
}
|
||
|
||
if constexpr (!IsSorted)
|
||
{
|
||
TOPK_SWAP(0, 2);
|
||
TOPK_SWAP(1, 3);
|
||
|
||
TOPK_SWAP(0, 1);
|
||
TOPK_SWAP(2, 3);
|
||
|
||
TOPK_SWAP(1, 2);
|
||
}
|
||
typename RedType::TypeCmp packedMax{};
|
||
#pragma unroll
|
||
for (int kk = 0; kk < K; ++kk)
|
||
{
|
||
bool update = kk > 0 && packedMax == topK[0].compVal;
|
||
#pragma unroll
|
||
for (int nn = 0; nn < N; ++nn)
|
||
{
|
||
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
|
||
}
|
||
// get the next largest value
|
||
packedMax = topK[0].reduce(warp);
|
||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||
}
|
||
};
|
||
|
||
#undef TOPK_SWAP
|
||
|
||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
||
template <typename T>
__host__ __device__ constexpr T mulLog2(T a, T bLog2)
{
    return a << bLog2;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpLog2(T a, T bLog2)
{
    return ((a + (1 << bLog2) - 1) >> bLog2);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2)
{
    return mulLog2<T>(divUpLog2<T>(a, bLog2), bLog2);
}

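// Example: with bLog2 == 3 (i.e. multiples of 8), divUpLog2(10, 3) == 2 and mulLog2(2, 3) == 16,
// so divUpMulLog2(10, 3) == 16: 10 rounded up to the next multiple of 8.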
////////////////////////////////////////////////////////////////////////////////////////////////////

__host__ __device__ constexpr int32_t getBits(int32_t value, int idx)
{
    int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000;
    return (value & mask) >> (idx * 8);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool IsZero = false>
__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx)
{
    if constexpr (!IsZero)
    {
        int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF;
        value &= mask;
    }
    value |= (newBits << (idx * 8));
}
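// Example: getBits(0x00C0FFEE, 1) == 0xFF, and setBits(value, 0xAB, 2) overwrites byte 2 of value
// with 0xAB (setBits<true> skips the masking step and assumes that byte is already zero).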

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TypeExpW, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WarpSize> const& warp, TypeExpW (&scores)[VecSize])
{
    TypeExpW maxScore = TypeExpW{-INFINITY};
    TypeExpW sumScore = TypeExpW{0.f};

    // Get the max score for each token
    for (int i = 0; i < VecSize; ++i)
    {
        maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<TypeExpW>());

    // Get the summation of scores for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        scores[i] = static_cast<TypeExpW>(exp(scores[i] - maxScore));
        sumScore += scores[i];
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<TypeExpW>());

    // Normalize the scores
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        scores[i] = static_cast<TypeExpW>(scores[i] / sumScore);
    }
}
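// Note: this overload is warp-cooperative: each of the WarpSize lanes holds VecSize scores, so the
// softmax is computed over all VecSize * WarpSize logits of a token (max and sum are reduced
// across the warp).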

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TypeExpW>
__device__ TypeExpW calcSoftmax(
    cg::thread_block_tile<WarpSize> const& warp, TypeExpW score, int32_t laneIdx, int32_t NumTopExperts)
{
    TypeExpW maxScore = TypeExpW{-INFINITY};
    if (laneIdx < NumTopExperts)
    {
        maxScore = score >= maxScore ? score : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<TypeExpW>());

    float sumScore = float{0.f};
    float newScore;
    // Get the summation of scores for each token
    if (laneIdx < NumTopExperts)
    {
        newScore = static_cast<float>(score) - static_cast<float>(maxScore);
        newScore = static_cast<float>(exp(newScore));
        sumScore += newScore;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    if (laneIdx < NumTopExperts)
    {
        score = static_cast<TypeExpW>(newScore / sumScore);
    }

    return score;
}
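// Note: this overload normalizes only the values held by the first NumTopExperts lanes (softmax
// over the selected top-k scores); lanes beyond NumTopExperts return their input unchanged.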

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams, bool DoSoftmaxBeforeTopK = false>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
    routingIndicesClusterKernel(KernelParams params)
{
    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
    // number of tokens/expanded idx is bounded by total number of warps
    using TypeExpW = typename KernelParams::TypeExpW;

    using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, TypeExpW>;
    using TypePacked = PackedScoreIdx<BaseType>;

    __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * NumTopExperts];
    // Needed for the exclusive sum of token offsets.
    // Note: the scan might include more bins than needed, with bin counts of 0 to pad
    using Scan = cub::BlockScan<int32_t, NumThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    __shared__ typename Scan::TempStorage tempStorage;
    // Number of threads in the cluster.
    static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster;
    // same as max num tokens*num top experts
    static constexpr int MaxExpandedIdxPerThread
        = (MaxNumTokensSingleCluster * NumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster;

    uint32_t const clusterBlockRank = blockIdx.x;
    uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
    int32_t const laneIdx = cutlass::arch::LaneId();

    auto expandedIdxSize = params.mNumTokens * NumTopExperts;
    auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
    auto scoreOffset = warpTokenIdx * params.mNumExperts;
    bool validToken = warpTokenIdx < params.mNumTokens;

    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

    // pre-fill the counts with 0
    if (threadIdx.x < params.mNumExperts)
    {
        smemExpertCount[threadIdx.x] = 0;
    }
    __syncthreads();

    // then wait on primary grid
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }

    // initialize the mPtrPermutedIdxToTokenIdx
    if (params.mPtrPermutedIdxToTokenIdx != nullptr)
    {
        int32_t permIdxToTokenIdxNum
            = (params.mNumTokens * NumTopExperts + (params.mNumExperts << params.mPaddingLog2) - params.mNumExperts);
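        // This is an upper bound on the padded permuted-index space: numTokens * topK entries plus
        // at most (2^mPaddingLog2 - 1) padding slots per expert.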
        for (int32_t i = clusterThreadIdx; i < permIdxToTokenIdxNum; i += NumThreadsPerCluster)
        {
            params.mPtrPermutedIdxToTokenIdx[i] = -1;
        }
        // A cluster synchronization is performed prior to setting mPtrPermutedIdxToTokenIdx at the end of the kernel.
        // Don't need to use __threadfence() here.
    }

    if (params.mPtrScores != nullptr)
    {
        // in this case, each warp represents a token
        BaseType score[MaxNumExperts / WarpSize];
        int32_t idx[MaxNumExperts / WarpSize];

        BaseType warpTopKScore[NumTopExperts];
        int32_t warpTopKExpertIdx[NumTopExperts];

        BaseType minScore = BaseType{-INFINITY};
        if (validToken)
        {
            for (int i = 0; i < MaxNumExperts / WarpSize; i++)
            {
                auto expertIdx = i * WarpSize + laneIdx;
                auto newScore = expertIdx < params.mNumExperts
                    ? static_cast<BaseType>(params.mPtrScores[scoreOffset + expertIdx])
                    : minScore;
                score[i] = newScore;
                idx[i] = expertIdx;
            }

            if constexpr (DoSoftmaxBeforeTopK)
            {
                calcSoftmax(warp, score);
            }

            // Get the top-k scores and their corresponding expert indices
            reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore);

            // Normalize the scores
            if constexpr (DoSoftmaxBeforeTopK)
            {
                float sum = float{1.f};
                if (params.mNormTopkProb)
                {
                    sum = static_cast<float>(laneIdx < NumTopExperts ? warpTopKScore[laneIdx] : 0);
                    sum = cg::reduce(warp, sum, cg::plus<float>());
                }
                if (laneIdx < NumTopExperts)
                {
                    warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
                    smemPackedScoreIdx[warpIdx * NumTopExperts + laneIdx]
                        = TypePacked{warpTopKScore[laneIdx], static_cast<int16_t>(warpTopKExpertIdx[laneIdx])};
                }
            }
            else
            {
                auto score = calcSoftmax(
                    warp, laneIdx < NumTopExperts ? warpTopKScore[laneIdx] : minScore, laneIdx, NumTopExperts);
                if (laneIdx < NumTopExperts)
                {
                    warpTopKScore[laneIdx] = score;
                    smemPackedScoreIdx[warpIdx * NumTopExperts + laneIdx]
                        = TypePacked{warpTopKScore[laneIdx], static_cast<int16_t>(warpTopKExpertIdx[laneIdx])};
                }
            }
        } // end if (validToken)

        // make packed scores available to all threads in cluster
        __cluster_barrier_arrive();
        __cluster_barrier_wait();
    }

    // each thread keeps some number of "expanded indexes" assigned to it
    // note that expanded indexes represent token/expert pairs here (tokenIdx * NumTopExperts + k)
    // for each of these, we keep the associated expert and offset within expert in registers
    int32_t expertIndexes[MaxExpandedIdxPerThread];
    int32_t expertOffsets[MaxExpandedIdxPerThread];
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
    // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
    // time, and branch between a fast path without bound checks and a slow path with bound checks.
    // TODO(mjoux): potentially add this back for perf tuning
    // int constexpr IterStride = 4;
    // static_assert(MaxExpandedIdxPerThread % IterStride == 0);

    // Define a lambda to avoid code duplication in both branches.
    auto loopBody = [&](int ii, int expandedIdx)
    {
        TypePacked scoreIdx;
        if (params.mPtrScores != nullptr)
        {
            TypePacked const* remoteSmem
                = cg::cluster_group::map_shared_rank(smemPackedScoreIdx, expandedIdx / (NumWarps * NumTopExperts));
            scoreIdx = remoteSmem[expandedIdx % (NumWarps * NumTopExperts)];
        }
        else
        {
            scoreIdx = TypePacked{static_cast<BaseType>(params.mPtrExpertIdx[expandedIdx].score),
                static_cast<int16_t>(params.mPtrExpertIdx[expandedIdx].idx)};
        }
        expertIndexes[ii] = scoreIdx.idx;
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
        if (params.mPtrExpertWeights != nullptr)
        {
            params.mPtrExpertWeights[expandedIdx] = static_cast<TypeExpW>(scoreIdx.score);
        }
    };

    if (clusterThreadIdx < expandedIdxSize)
    {
        loopBody(0, clusterThreadIdx);
    }

    // Make local histogram (token counts per expert) available to all threads in the cluster.
    __cluster_barrier_arrive();
    __cluster_barrier_wait();

    //
    // Each thread now represents one expert
    //

    // Total number of tokens for this expert.
    int32_t count = 0;
    // Per-expert offset for this block.
    int32_t blockExpertOffset = 0;

    if (threadIdx.x < params.mNumExperts)
    {
        // Get the histogram bin from each rank for this expert.
        int32_t expertCounts[NumBlocksPerCluster];
#pragma unroll
        for (int rank = 0; rank < NumBlocksPerCluster; rank++)
        {
            int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
            expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0;
        }

        // Compute an exclusive prefix sum of the block-local count.
#pragma unroll
        for (int rank = 0; rank < NumBlocksPerCluster; rank++)
        {
            if (rank == clusterBlockRank)
            {
                blockExpertOffset = count;
            }
            count += expertCounts[rank];
        }
    }

    // Arrive: we do not access distributed shared memory after this point.
    __cluster_barrier_arrive();

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when smemExpertCount is computed
    // so we do not need to take it into account here.
    const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);
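    // ctaOffset is this expert's first tile in the grouped GEMM launch config; numNonExitingCtas
    // (the scan aggregate) is the total number of tiles across all experts.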

    if (threadIdx.x < params.mNumExperts)
    {
        // Strided loop to share this work between blocks.
        for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster)
        {
            const int32_t localExpertIdx
                = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
            params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
            params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
                = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
                    mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
        }

        // get the padded offset associated with this expert
        const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);

        // write expert offsets to shared
        smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
    }

    // write out padded count
    if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
    {
        const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    // make expert offsets available to all threads
    __syncthreads();

    // Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
    // Note: I observed a perf benefit to doing this before the final loop so the compiler can
    // implement break with EXIT.
    __cluster_barrier_wait();

    // trigger the secondary kernel when using PDL
    // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
    // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
    // TODO: this is not sufficient to ensure visibility in the next kernel!
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif

    // each thread has the same "expanded indexes" assigned to it as above
    // at this point, we know the final offsets of experts and the offsets within
    // experts, which allows writing the final index values

#pragma unroll
    for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
    {
        auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
        if (expandedIdx >= expandedIdxSize)
        {
            break;
        }
        auto expertIdx = expertIndexes[ii];
        // check whether this expert is local to our GPU at all
        auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        auto tokenIdx = expandedIdx / NumTopExperts;
        auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
        if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
        {
            params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
        }
        if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
        {
            params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
        }
    }
}
#else
__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */)
{
    assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
////////////////////////////////////////////////////////////////////////////////////////////////////

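// For the large-token path when raw scores are provided: each warp computes one token's top-k
// (score, expert) pairs (softmax applied before or after the top-k selection, depending on
// DoSoftmaxBeforeTopK) and stores them packed in mPtrExpertIdx. It also resets
// mPtrPermutedIdxToTokenIdx and mPtrExpertCounts for the histogram/offsets kernels that follow.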
// this kernel is needed in case we have scores as input for the histogram kernel
template <typename KernelParams, bool DoSoftmaxBeforeTopK = true>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresKernel(KernelParams params)
{
    using TypeExpW = typename KernelParams::TypeExpW;

    using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, TypeExpW>;

    static constexpr int VecSize = MaxNumExperts / WarpSize;
    // we assume that #experts is a multiple of 4, so VecSize must be 4.
    static_assert(VecSize == 4);

    int32_t const laneIdx = cutlass::arch::LaneId();
    int32_t const warpIdx = threadIdx.x / WarpSize;
    int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
    int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
    BaseType minScore = BaseType{-INFINITY};
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WarpSize>(block);

    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaGridDependencySynchronize();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }

    // initialize the mPtrPermutedIdxToTokenIdx
    int32_t globalThreadIdx = globalWarpIdx * WarpSize + laneIdx;
    int32_t globalThreadStride = globalWarpStride * WarpSize;
    if (params.mPtrPermutedIdxToTokenIdx != nullptr)
    {
        int32_t permIdxToTokenIdxNum
            = (params.mNumTokens * NumTopExperts + (params.mNumExperts << params.mPaddingLog2) - params.mNumExperts);
        for (int32_t i = globalThreadIdx; i < permIdxToTokenIdxNum; i += globalThreadStride)
        {
            params.mPtrPermutedIdxToTokenIdx[i] = -1;
        }
    }

    // initialize the mPtrExpertCounts
    if (params.mPtrExpertCounts != nullptr)
    {
        int32_t expertCountsNum = 2 * params.mNumExperts;
        for (int32_t i = globalThreadIdx; i < expertCountsNum; i += globalThreadStride)
        {
            params.mPtrExpertCounts[i] = 0;
        }
    }

    // Trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaTriggerProgrammaticLaunchCompletion();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }

    // in this case, each warp represents a token, and we use a grid-stride loop
    // over all warps/tokens
    for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride)
    {
        auto scoreOffset = tokenIdx * params.mNumExperts;
        BaseType allScores[VecSize];
        int32_t allExpertIdx[VecSize];
        BaseType warpTopKScore[NumTopExperts];
        int32_t warpTopKExpertIdx[NumTopExperts];

        // @TODO: optimize this part with vectorized loading

#pragma unroll
        for (int i = 0; i < VecSize; ++i)
        {
            auto expertIdx = i * WarpSize + laneIdx;
            auto newScore = expertIdx < params.mNumExperts
                ? static_cast<BaseType>(params.mPtrScores[scoreOffset + expertIdx])
                : minScore;
            allScores[i] = newScore;
            allExpertIdx[i] = expertIdx;
        }

        if constexpr (DoSoftmaxBeforeTopK)
        {
            calcSoftmax(warp, allScores);
        }

        // Get the top-k scores and their corresponding expert indices
        reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, allScores, allExpertIdx, minScore);
        __syncwarp(); // @TODO: check the synchronization

        // Normalize the scores
        if constexpr (DoSoftmaxBeforeTopK)
        {
            float sum = float{1.f};
            if (params.mNormTopkProb)
            {
                sum = static_cast<float>(laneIdx < NumTopExperts ? warpTopKScore[laneIdx] : 0);
                sum = cg::reduce(warp, sum, cg::plus<float>());
            }
            if (laneIdx < NumTopExperts)
            {
                warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum;
            }
        }
        else
        {
            auto score = laneIdx < NumTopExperts ? warpTopKScore[laneIdx] : minScore;
            score = calcSoftmax(warp, score, laneIdx, NumTopExperts);
            if (laneIdx < NumTopExperts)
            {
                warpTopKScore[laneIdx] = score;
            }
        }
        for (int i = laneIdx; i < NumTopExperts; i += WarpSize)
        {
            PackedScoreIdx<TypeExpW> packedScore{
                static_cast<TypeExpW>(warpTopKScore[i]), static_cast<int16_t>(warpTopKExpertIdx[i])};
            params.mPtrExpertIdx[tokenIdx * NumTopExperts + i] = packedScore;
        }
    }
}

// Two-step approach (if number of tokens exceed limits of what cluster / cooperative launch
// variants can handle): in order to minimize the amount of data to exchange through global memory,
// we will compute the local histograms in smem twice: the first kernel will get us the total number
// of tokens per expert. The second kernel will use the smem and L2 atomics to get corresponding
// element and tile offsets.
//
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
// inefficient if we have one CTA per token doing a single global atomic.
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramKernel(KernelParams params)
{
    using TypeExpW = typename KernelParams::TypeExpW;

    using TypePacked = PackedScoreIdx<float>;
    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];

    // For unrolling.
    uint32_t constexpr NumEltsPerThread = 8;

    // Pre-fill the counts with 0
    if (threadIdx.x < params.mNumExperts)
    {
        smemExpertCount[threadIdx.x] = 0;
    }
    __syncthreads();

    // Wait on primary grid and trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaGridDependencySynchronize();
        cudaTriggerProgrammaticLaunchCompletion();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }

    uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
    uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

    uint32_t const gridBlockOffset = blockIdx.x * NumThreadsHist;
    uint32_t const gridStride = gridDim.x * NumThreadsHist;

    // Define a lambda to avoid code duplication in branches.
    auto loopBody = [&](int expandedIdx)
    {
        PackedScoreIdx<TypeExpW> scoreIdx = params.mPtrExpertIdx[expandedIdx];
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        if (isLocalExpert)
        {
            atomicAdd(&smemExpertCount[scoreIdx.idx], 1);
        }

        if (params.mPtrExpertWeights != nullptr)
        {
            params.mPtrExpertWeights[expandedIdx] = static_cast<TypeExpW>(scoreIdx.score);
        }
    };

    // Grid-stride loop.
    for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
         expandedIdx0 += gridStride * NumEltsPerThread)
    {
        // Fast path if bound checks aren't necessary
        if (expandedIdx0 + NumEltsPerThread * NumThreadsHist <= expandedIdxSize)
        {
#pragma unroll
            for (uint32_t ii = 0; ii < NumEltsPerThread; ii++)
            {
                uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsHist + threadIdx.x;
                loopBody(expandedIdx);
            }
        }
        else
        {
            for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
                 expandedIdx += NumThreadsHist)
            {
                loopBody(expandedIdx);
            }
        }
    }
    __syncthreads();

    //
    // Each thread now represents one expert
    //

    // Reduce histograms with atomics.
    if (threadIdx.x < params.mNumExperts)
    {
        int32_t const localExpertCount = smemExpertCount[threadIdx.x];
        atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

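// Second step of the two-step approach: every block scans the global expert histogram to derive
// the per-expert base offsets (block 0 also writes the grouped GEMM launch config); each tile of
// expanded indices is then re-counted in shared memory, a per-tile offset is reserved via an
// atomic on the second half of mPtrExpertCounts, and the final permuted indices are written out.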
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(KernelParams params)
{
    using TypeExpW = typename KernelParams::TypeExpW;
    using TypePacked = PackedScoreIdx<TypeExpW>;

    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreadsHist];
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreadsHist];
    __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[NumThreadsHist];
    // needed for the exclusive sum of token offsets
    using Scan = cub::BlockScan<int32_t, NumThreadsHist, cub::BLOCK_SCAN_WARP_SCANS>;
    __shared__ typename Scan::TempStorage tempStorage;
    static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
    static constexpr int MaxExpandedIdxPerBlock = NumThreadsHist * MaxExpandedIdxPerThread;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);

    uint32_t const expandedIdxSize = params.mNumTokens * NumTopExperts;
    uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaGridDependencySynchronize();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }

    // The expert offsets are common to all tiles of all blocks.
    // Load the histogram, scan it and write offsets to shared memory.
    // Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
    // the scan, with PDL?

    //
    // Each thread represents one expert.
    //

    // Get total count for this expert.
    int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when the histogram is computed
    // so we do not need to take it into account here.
    const int32_t numCta = divUpLog2<int32_t>(count, params.mPaddingLog2);
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    if (threadIdx.x < params.mNumExperts)
    {
        // Get the padded offset associated with this expert
        const int32_t offset = mulLog2<int32_t>(ctaOffset, params.mPaddingLog2);

        // Write expert offsets to shared
        smemExpertOffset[threadIdx.x] = offset;
    }

    // Sync to make expert offsets available to all threads.
    __syncthreads();

    // The first block writes out padded count
    if (blockIdx.x == 0 && warpIdx == NumWarpsHist - 1 && cute::elect_one_sync())
    {
        const int32_t permutedIdxSize = mulLog2<int32_t>(numNonExitingCtas, params.mPaddingLog2);
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    if (threadIdx.x < params.mNumExperts)
    {
        // Strided loop to share this work between blocks.
        for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x)
        {
            const int32_t localExpertIdx
                = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
            params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
            params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta]
                = min(mulLog2<int32_t>(ctaOffset + cta + 1, params.mPaddingLog2),
                    mulLog2<int32_t>(ctaOffset, params.mPaddingLog2) + count);
        }
    }

    //
    // Now loop on indices and compute offsets.
    //

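    // Each tile covers MaxExpandedIdxPerBlock = NumThreadsHist * NumEltsPerOffsetTilePerThread
    // expanded indices; tiles are distributed across the blocks of the grid.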
    // Grid-stride loop on 1D "tiles" of input indices.
    for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x)
    {
        if (tileIdx > 0)
        {
            // Sync for safe reuse of smem buffers.
            __syncthreads();
        }

        // Pre-fill the counts with 0
        if (threadIdx.x < params.mNumExperts)
        {
            smemExpertCount[threadIdx.x] = 0;
        }
        __syncthreads();

        // each thread has some number of "expanded indexes" assigned to it
        // for each of these, we keep the associated expert and offset within expert in registers
        int32_t expertIndexes[MaxExpandedIdxPerThread];
        int32_t expertOffsets[MaxExpandedIdxPerThread];
        auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

        // Define a lambda to avoid code duplication in branches.
        auto loopBody = [&](int ii, int expandedIdx)
        {
            PackedScoreIdx<TypeExpW> scoreIdx = params.mPtrExpertIdx[expandedIdx];
            expertIndexes[ii] = scoreIdx.idx;
            // check whether this expert is local to our GPU at all and ignore if not
            auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
        };

        // For all tiles but the last, all indices are in bounds.
        if (tileIdx < numTiles - 1)
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
                loopBody(ii, expandedIdx);
            }
        }
        else
        {
            // For the last tile, we need to exit the loop when out of bounds.
            // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
            // time, and branch between a fast path without bound checks and a slow path with bound checks
            int constexpr IterStride = 4;
            static_assert(MaxExpandedIdxPerThread % IterStride == 0);

#pragma unroll
            for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
            {
                // Whether it's safe to do multiple iterations without bound checks.
                bool const takeFastPath
                    = tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsHist <= expandedIdxSize;
                if (takeFastPath)
                {
#pragma unroll
                    for (int32_t jj = 0; jj < IterStride; jj++)
                    {
                        int const ii = ii0 + jj;
                        auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
                        loopBody(ii, expandedIdx);
                    }
                }
                else
                {
                    bool doBreak = false;
#pragma unroll
                    for (int32_t jj = 0; jj < IterStride; jj++)
                    {
                        int const ii = ii0 + jj;
                        auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
                        if (expandedIdx >= expandedIdxSize)
                        {
                            doBreak = true;
                            break;
                        }
                        loopBody(ii, expandedIdx);
                    }
                    if (doBreak)
                    {
                        break;
                    }
                }
            }
        }

        // Make local histogram (token counts per expert) available to all threads in the block.
        __syncthreads();

        //
        // Each thread now represents one expert
        //

        if (threadIdx.x < params.mNumExperts)
        {
            // Add the local bin count to the common bin count and get a per-CTA offset. We use the second
            // half of the histogram buffer for this histogram, because the first half already holds the
            // reduced histogram from the previous kernel.
            int32_t const localExpertCount = smemExpertCount[threadIdx.x];
            int32_t const tileExpertOffset
                = atomicAdd(&params.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount);

            // Make per-expert tile offsets available to all threads in the block.
            smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
        }
        __syncthreads();

        // Add tile offset and element offset and write to global memory.
        auto storeLoopBody = [&](int ii, int expandedIdx)
        {
            int32_t expertIdx = expertIndexes[ii];
            // check whether this expert is local to our GPU at all
            auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            auto tokenIdx = expandedIdx / NumTopExperts;
            auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
            if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
            {
                params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
            }
            if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
            {
                params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
            }
        };
        // Bound checks only in last tile.
        if (tileIdx < numTiles - 1)
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
                storeLoopBody(ii, expandedIdx);
            }
        }
        else
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsHist + threadIdx.x;
                if (expandedIdx >= expandedIdxSize)
                {
                    break;
                }
                storeLoopBody(ii, expandedIdx);
            }
        }
    }

    // Trigger secondary kernel.
    // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
    // dependency sync.
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
    if constexpr (KernelParams::UsePdl)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaTriggerProgrammaticLaunchCompletion();
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    }
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

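// Host-side launcher: uses the single-cluster kernel when the token count fits in one cluster;
// otherwise falls back to the histogram + offsets kernels, preceded by the scores kernel when raw
// scores are provided.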
void run(Data const& data, void* stream)
{
    TLLM_CHECK_WITH_INFO(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
        "Routing kernel requires at least one input parameter");
    TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr
            && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
        "Qwen3 routing kernel expects permuted idx and grouped Gemm launch config buffers");
    TLLM_CHECK_WITH_INFO(
        data.mTopK == NumTopExperts, "Routing kernel expects %d topK experts (for now)", NumTopExperts);
    TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxNumExperts,
        "Routing kernel expects #experts %d to be at most max #experts %d", data.mNumExperts, MaxNumExperts);
    static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
    static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
    TLLM_CHECK_WITH_INFO(
        data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);
    TLLM_CHECK_WITH_INFO(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got %d", data.mPaddingLog2);

    bool const useSingleCluster
        = data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster);
    if (!useSingleCluster)
    {
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertIdx != nullptr, "When #tokens is large, `mPtrExpertIdx` is a required input.");
        TLLM_CHECK_WITH_INFO(
            data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
    }

    if (useSingleCluster)
    {
        LAUNCH_EXPW_QWEN3(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads,
            /*smemSize=*/0, // No dynamic smem
            stream);
    }
    else
    {
        uint32_t const expandedIdxSize = data.mNumTokens * NumTopExperts;

        uint32_t const histogramEltsPerBlock = 8 * NumThreadsHist;
        uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;

        // Limit grid size (all kernels use a grid-stride loop).
        uint32_t const maxNumBlocks = 1024;

        int const numBlocksHistogram
            = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
        int const numBlocksOffsets
            = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

        if (data.mPtrScores != nullptr)
        {
            LAUNCH_EXPW_QWEN3(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, NumThreadsHist,
                /*smemSize=*/0, // No dynamic smem
                stream);
        }
        LAUNCH_EXPW_ONLY(data, false, routingIndicesHistogramKernel, numBlocksHistogram, NumThreadsHist,
            /*smemSize=*/0, // No dynamic smem
            stream);
        LAUNCH_EXPW_ONLY(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, NumThreadsHist,
            /*smemSize=*/0, // No dynamic smem
            stream);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace routingQwen3

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace moe::dev