/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "DevKernel.h"
#include "RoutingKernel.h"
#include "RoutingKernelTopK.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cute/arch/cluster_sm90.hpp>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace moe::dev
{

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace routing
{

namespace cg = cooperative_groups;

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int WarpSize = 32;
static constexpr int NumBlocksPerCluster = 8;
// Performance tuning knob.
static constexpr int NumEltsPerOffsetTilePerThread = 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

static __device__ inline float sigmoid_accurate(float x)
{
    return 0.5f * tanhf(0.5f * x) + 0.5f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T mulLog2(T a, T bLog2)
{
    return a << bLog2;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpLog2(T a, T bLog2)
{
    return ((a + (1 << bLog2) - 1) >> bLog2);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2)
{
    return mulLog2(divUpLog2(a, bLog2), bLog2);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T mulTileN(T a, T tileN)
{
    return a * tileN;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpTileN(T a, T tileN)
{
    return (a + tileN - 1) / tileN;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__host__ __device__ constexpr T divUpMulTileN(T a, T tileN)
{
    return divUpTileN(a, tileN) * tileN;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

__host__ __device__ constexpr int32_t getBits(int32_t value, int idx)
{
    int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000;
    return (value & mask) >> (idx * 8);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool IsZero = false>
__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx)
{
    if constexpr (!IsZero)
    {
        int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ?
            0xFF00FFFF : 0x00FFFFFF;
        value &= mask;
    }
    value |= (newBits << (idx * 8));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename DataType>
__device__ void initArr(int startIdx, int numElts, int stride, DataType* arr, DataType value)
{
    if (arr != nullptr)
    {
        for (int i = startIdx; i < numElts; i += stride)
        {
            arr[i] = value;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WarpSize> const& warp, DataType (&scores)[VecSize])
{
    // Compute in float to support half/bfloat16 inputs safely.
    float maxScore = -INFINITY;
    float sumScore = 0.f;

    // Get the max score for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        float si = static_cast<float>(scores[i]);
        maxScore = si >= maxScore ? si : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<float>());

    // Get the summation of scores for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        float si = static_cast<float>(scores[i]);
        float e = expf(si - maxScore);
        scores[i] = static_cast<DataType>(e);
        sumScore += e;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    // Normalize the scores
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        float si = static_cast<float>(scores[i]) / sumScore;
        scores[i] = static_cast<DataType>(si);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename DataType>
__device__ DataType calcSoftmax(
    cg::thread_block_tile<WarpSize> const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts)
{
    DataType maxScore = DataType{-INFINITY};
    if (laneIdx < NumTopExperts)
    {
        maxScore = score >= maxScore ? score : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());

    float sumScore = float{0.f};
    float newScore;
    // Get the summation of scores for each token
    if (laneIdx < NumTopExperts)
    {
        newScore = static_cast<float>(score) - static_cast<float>(maxScore);
        newScore = static_cast<float>(exp(newScore));
        sumScore += newScore;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    if (laneIdx < NumTopExperts)
    {
        score = static_cast<DataType>(newScore / sumScore);
    }

    return score;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams, bool LoadExpertIdxFromGlobal, int NumThreads, int NumWarps, int MaxNumTopExperts>
__device__ void routingPermutation(KernelParams params,
    PackedScoreIdx<typename KernelParams::OutputT>* smemPackedScoreIdx, int32_t const warpIdx,
    uint32_t const clusterBlockRank)
{
    using OutputT = typename KernelParams::OutputT;
    using TypePacked = PackedScoreIdx<OutputT>;

    static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
    // Number of threads in the cluster.
    static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster; // same as max num tokens
    static constexpr int MaxExpandedIdxPerThread
        = (MaxNumTokensSingleCluster * MaxNumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster;

    // Needed for the exclusive sum of token offsets.
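    // The exclusive sum below turns per-expert CTA counts into per-expert CTA offsets and also
    // yields the total number of CTAs across all experts (numNonExitingCtas).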
    // Note: the scan might include more bins than needed, with bin counts of 0 to pad
    using Scan = cub::BlockScan<int32_t, NumThreads>;
    __shared__ typename Scan::TempStorage tempStorage;

    uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x;

    auto expandedIdxSize = params.mNumTokens * params.mTopK;

    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];

    // pre-fill the counts with 0
    if (threadIdx.x < params.mNumExperts)
    {
        smemExpertCount[threadIdx.x] = 0;
    }
    __syncthreads();

    // each thread keeps some number of "expanded indexes" assigned to it
    // note that expanded indexes simply represent tokens here.
    // for each of these, we keep the associated expert and offset within expert in registers
    int32_t expertIndexes[MaxExpandedIdxPerThread];
    int32_t expertOffsets[MaxExpandedIdxPerThread];
    auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

    // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
    // time, and branch between a fast path without bound checks and a slow path with bound checks.
    // TODO(mjoux): potentially add this back for perf tuning
    // int constexpr IterStride = 4;
    // static_assert(MaxExpandedIdxPerThread % IterStride == 0);

    // Define a lambda to avoid code duplication in both branches.
    auto loopBody = [&](int ii, int expandedIdx)
    {
        TypePacked scoreIdx;
        if constexpr (LoadExpertIdxFromGlobal)
        {
            if (params.mPtrTopKIds != nullptr)
            {
                scoreIdx = TypePacked{static_cast<OutputT>(params.mPtrTopKWeights[expandedIdx]),
                    static_cast<int16_t>(params.mPtrTopKIds[expandedIdx])};
            }
            else
            {
                scoreIdx = TypePacked{static_cast<OutputT>(params.mPtrTopKPacked[expandedIdx].score),
                    static_cast<int16_t>(params.mPtrTopKPacked[expandedIdx].idx)};
            }
        }
        else
        {
            TypePacked const* remoteSmem
                = cg::cluster_group::map_shared_rank(smemPackedScoreIdx, expandedIdx / (NumWarps * params.mTopK));
            scoreIdx = remoteSmem[expandedIdx % (NumWarps * params.mTopK)];
        }
        expertIndexes[ii] = scoreIdx.idx;
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0;
        if (params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr)
        {
            params.mPtrTopKWeights[expandedIdx] = OutputT{scoreIdx.score};
        }
    };

    int constexpr IterStride = 4;
#pragma unroll
    for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
    {
        // Whether it's safe to do multiple iterations without bound checks.
        bool const takeFastPath = (ii0 + IterStride) * NumThreadsPerCluster <= expandedIdxSize;
        if (takeFastPath)
        {
#pragma unroll
            for (int32_t jj = 0; jj < IterStride; jj++)
            {
                int const ii = ii0 + jj;
                auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
                loopBody(ii, expandedIdx);
            }
        }
        else
        {
            bool doBreak = false;
#pragma unroll
            for (int32_t jj = 0; jj < IterStride; jj++)
            {
                int const ii = ii0 + jj;
                auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
                if (expandedIdx >= expandedIdxSize)
                {
                    doBreak = true;
                    break;
                }
                loopBody(ii, expandedIdx);
            }
            if (doBreak)
            {
                break;
            }
        }
    }

    // Make local histogram (token counts per expert) available to all threads in the cluster.
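    // The cluster barrier pair (arrive + wait) makes each block's smemExpertCount visible to the
    // other blocks of the cluster before they read it through distributed shared memory below.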
    __cluster_barrier_arrive();
    __cluster_barrier_wait();

    //
    // Each thread now represents one expert
    //

    // Total number of tokens for this expert.
    int32_t count = 0;
    // Per-expert offset for this block.
    int32_t blockExpertOffset = 0;

    if (threadIdx.x < params.mNumExperts)
    {
        // Get the histogram bin from each rank for this expert.
        int32_t expertCounts[NumBlocksPerCluster];
#pragma unroll
        for (int rank = 0; rank < NumBlocksPerCluster; rank++)
        {
            int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank);
            expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0;
        }

        // Compute an exclusive prefix sum of the block-local count.
#pragma unroll
        for (int rank = 0; rank < NumBlocksPerCluster; rank++)
        {
            if (rank == clusterBlockRank)
            {
                blockExpertOffset = count;
            }
            count += expertCounts[rank];
        }
    }

    // Arrive: we do not access distributed shared memory after this point.
    __cluster_barrier_arrive();

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when smemExpertCount is computed
    // so we do not need to take it into account here.
    int32_t numCta;
    if constexpr (KernelParams::isPow2)
    {
        numCta = divUpLog2(count, params.mPaddingLog2);
    }
    else
    {
        numCta = divUpTileN(count, params.mTileTokensDim);
    }
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    if (threadIdx.x < params.mNumExperts)
    {
        // Strided loop to share this work between blocks.
        for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster)
        {
            const int32_t localExpertIdx
                = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
            params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
            int32_t mnLimit1;
            int32_t mnLimit2;
            if constexpr (KernelParams::isPow2)
            {
                mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2);
                mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count;
            }
            else
            {
                mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim);
                mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count;
            }
            params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
        }

        // get the padded offset associated with this expert
        int32_t offset;
        if constexpr (KernelParams::isPow2)
        {
            offset = mulLog2(ctaOffset, params.mPaddingLog2);
        }
        else
        {
            offset = mulTileN(ctaOffset, params.mTileTokensDim);
        }

        // write expert offsets to shared
        smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;
    }

    // write out padded count
    if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync())
    {
        int32_t permutedIdxSize;
        if constexpr (KernelParams::isPow2)
        {
            permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2);
        }
        else
        {
            permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim);
        }
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    // make expert offsets available to all threads
    __syncthreads();

    // Wait: we cannot exit while other blocks may be accessing the current block's shared memory.
    // Note: I observed a perf benefit to doing this before the final loop so the compiler can
    // implement break with EXIT.
    __cluster_barrier_wait();

    // trigger the secondary kernel when using PDL
    // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
    // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
    // TODO: this is not sufficient to ensure visibility in the next kernel!
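    // Note: cudaTriggerProgrammaticLaunchCompletion() only allows the dependent grid to start; the
    // consumer must still issue cudaGridDependencySynchronize() to safely observe the writes above.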
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif

    // each thread has the same "expanded indexes" assigned to it as above
    // at this point, we know the final offsets of experts and the offsets within
    // experts, which allows writing the final index values
#pragma unroll
    for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii)
    {
        auto expandedIdx = static_cast<int32_t>(clusterThreadIdx) + ii * NumThreadsPerCluster;
        if (expandedIdx >= expandedIdxSize)
        {
            break;
        }
        auto expertIdx = expertIndexes[ii];
        // check whether this expert is local to our GPU at all
        auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        auto tokenIdx = expandedIdx / params.mTopK;
        auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
        if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
        {
            params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
        }
        if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
        {
            params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Two-step approach (if the number of tokens exceeds the limits of what the cluster / cooperative
// launch variants can handle): in order to minimize the amount of data to exchange through global
// memory, we will compute the local histograms in smem twice: the first kernel will get us the
// total number of tokens per expert. The second kernel will use the smem and L2 atomics to get
// corresponding element and tile offsets.
//
// Note: the histogram calculation could also be fused with routingMainKernel, but this might be
// inefficient if we have one CTA per token doing a single global atomic.
template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts)
    routingIndicesHistogramKernel(KernelParams params)
{
    using OutputT = typename KernelParams::OutputT;

    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts];

    // For unrolling.
    uint32_t constexpr NumEltsPerThread = 8;

    // Pre-fill the counts with 0
    if (threadIdx.x < params.mNumExperts)
    {
        smemExpertCount[threadIdx.x] = 0;
    }
    __syncthreads();

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait on primary grid and trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

    uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
    uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

    uint32_t const gridBlockOffset = blockIdx.x * KernelParams::MaxNumExperts;
    uint32_t const gridStride = gridDim.x * KernelParams::MaxNumExperts;

    // Define a lambda to avoid code duplication in branches.
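    // The lambda below resolves the expert of one expanded (token, top-k slot) index and, if that
    // expert is local to this GPU, increments its bin in the shared-memory histogram.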
    auto loopBody = [&](int expandedIdx)
    {
        PackedScoreIdx<OutputT> scoreIdx;
        int idx;
        if (params.mPtrTopKIds != nullptr)
        {
            idx = params.mPtrTopKIds[expandedIdx];
        }
        else
        {
            // If params.mPtrTopKIds != nullptr, we don't need to store the weights
            if (params.mPtrTopKWeights != nullptr)
            {
                scoreIdx = params.mPtrTopKPacked[expandedIdx];
                idx = scoreIdx.idx;
                params.mPtrTopKWeights[expandedIdx] = static_cast<OutputT>(scoreIdx.score);
            }
        }
        // check whether this expert is local to our GPU at all and ignore if not
        auto localExpertIdx = idx - params.mLocalExpertsStartIdx;
        auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
            && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
        if (isLocalExpert)
        {
            atomicAdd(&smemExpertCount[idx], 1);
        }
    };

    // Grid-stride loop.
    for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize;
         expandedIdx0 += gridStride * NumEltsPerThread)
    {
        // Fast path if bound checks aren't necessary
        if (expandedIdx0 + NumEltsPerThread * KernelParams::MaxNumExperts <= expandedIdxSize)
        {
#pragma unroll
            for (uint32_t ii = 0; ii < NumEltsPerThread; ii++)
            {
                uint32_t expandedIdx = expandedIdx0 + ii * KernelParams::MaxNumExperts + threadIdx.x;
                loopBody(expandedIdx);
            }
        }
        else
        {
            for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize;
                 expandedIdx += KernelParams::MaxNumExperts)
            {
                loopBody(expandedIdx);
            }
        }
    }
    __syncthreads();

    //
    // Each thread now represents one expert
    //

    // Reduce histograms with atomics.
    if (threadIdx.x < params.mNumExperts)
    {
        int32_t const localExpertCount = smemExpertCount[threadIdx.x];
        atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts)
    routingIndicesOffsetsKernel(KernelParams params)
{
    using OutputT = typename KernelParams::OutputT;

    // number of experts is bounded by number of threads
    __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[KernelParams::MaxNumExperts];
    __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts];
    __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[KernelParams::MaxNumExperts];

    // needed for the exclusive sum of token offsets
    using Scan = cub::BlockScan<int32_t, KernelParams::MaxNumExperts>;
    __shared__ typename Scan::TempStorage tempStorage;

    static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread;
    static constexpr int MaxExpandedIdxPerBlock = KernelParams::MaxNumExperts * MaxExpandedIdxPerThread;

    int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);

    uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK;
    uint32_t const numTiles = (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

    // The expert offsets are common to all tiles of all blocks.
    // Load the histogram, scan it and write offsets to shared memory.
    // Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for
    // the scan, with PDL?

    //
    // Each thread represents one expert.
    //

    // Get total count for this expert.
    int32_t count = (threadIdx.x < params.mNumExperts) ?
        params.mPtrExpertCounts[threadIdx.x] : 0;

    // Compute the runtime config for projections
    // Whether or not an expert is local is taken into account when the histogram is computed
    // so we do not need to take it into account here.
    int32_t numCta;
    if constexpr (KernelParams::isPow2)
    {
        numCta = divUpLog2(count, params.mPaddingLog2);
    }
    else
    {
        numCta = divUpTileN(count, params.mTileTokensDim);
    }
    int32_t ctaOffset;
    int32_t numNonExitingCtas;
    Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

    if (threadIdx.x < params.mNumExperts)
    {
        // Get the padded offset associated with this expert
        int32_t offset;
        if constexpr (KernelParams::isPow2)
        {
            offset = mulLog2(ctaOffset, params.mPaddingLog2);
        }
        else
        {
            offset = mulTileN(ctaOffset, params.mTileTokensDim);
        }

        // Write expert offsets to shared
        smemExpertOffset[threadIdx.x] = offset;
    }

    // Sync to make expert offsets available to all threads.
    __syncthreads();

    // The first block writes out padded count
    if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 && cute::elect_one_sync())
    {
        int32_t permutedIdxSize;
        if constexpr (KernelParams::isPow2)
        {
            permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2);
        }
        else
        {
            permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim);
        }
        params.mPtrPermutedIdxSize[0] = permutedIdxSize;
        params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
    }

    if (threadIdx.x < params.mNumExperts)
    {
        // Strided loop to share this work between blocks.
        for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x)
        {
            const int32_t localExpertIdx
                = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
            params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
            int32_t mnLimit1;
            int32_t mnLimit2;
            if constexpr (KernelParams::isPow2)
            {
                mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2);
                mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count;
            }
            else
            {
                mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim);
                mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count;
            }
            params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2);
        }
    }

    //
    // Now loop on indices and compute offsets.
    //

    // Grid-stride loop on 1D "tiles" of input indices.
    for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x)
    {
        if (tileIdx > 0)
        {
            // Sync for safe reuse of smem buffers.
            __syncthreads();
        }

        // Pre-fill the counts with 0
        if (threadIdx.x < params.mNumExperts)
        {
            smemExpertCount[threadIdx.x] = 0;
        }
        __syncthreads();

        // each thread keeps some number of "expanded indexes" assigned to it
        // for each of these, we keep the associated expert and offset within expert in registers
        int32_t expertIndexes[MaxExpandedIdxPerThread];
        int32_t expertOffsets[MaxExpandedIdxPerThread];
        auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

        // Define a lambda to avoid code duplication in branches.
        auto loopBody = [&](int ii, int expandedIdx)
        {
            expertIndexes[ii]
                = params.mPtrTopKIds ? params.mPtrTopKIds[expandedIdx] : params.mPtrTopKPacked[expandedIdx].idx;
            // check whether this expert is local to our GPU at all and ignore if not
            auto localExpertIdx = expertIndexes[ii] - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIndexes[ii], 1) : 0;
        };

        // For all tiles but the last, all indices are in bounds.
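        // (The last tile is handled below with bound checks and an early break once expandedIdx
        // runs past expandedIdxSize.)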
        if (tileIdx < numTiles - 1)
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
                loopBody(ii, expandedIdx);
            }
        }
        else
        {
            // For the last tile, we need to exit the loop when out of bounds.
            // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations
            // at a time, and branch between a fast path without bound checks and a slow path with
            // bound checks
            int constexpr IterStride = 4;
            static_assert(MaxExpandedIdxPerThread % IterStride == 0);

#pragma unroll
            for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride)
            {
                // Whether it's safe to do multiple iterations without bound checks.
                bool const takeFastPath
                    = tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * KernelParams::MaxNumExperts
                    <= expandedIdxSize;
                if (takeFastPath)
                {
#pragma unroll
                    for (int32_t jj = 0; jj < IterStride; jj++)
                    {
                        int const ii = ii0 + jj;
                        auto expandedIdx
                            = tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
                        loopBody(ii, expandedIdx);
                    }
                }
                else
                {
                    bool doBreak = false;
#pragma unroll
                    for (int32_t jj = 0; jj < IterStride; jj++)
                    {
                        int const ii = ii0 + jj;
                        auto expandedIdx
                            = tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
                        if (expandedIdx >= expandedIdxSize)
                        {
                            doBreak = true;
                            break;
                        }
                        loopBody(ii, expandedIdx);
                    }
                    if (doBreak)
                    {
                        break;
                    }
                }
            }
        }

        // Make local histogram (token counts per expert) available to all threads in the block.
        __syncthreads();

        //
        // Each thread now represents one expert
        //

        if (threadIdx.x < params.mNumExperts)
        {
            // Add the local bin count to the common bin count and get a per-CTA offset. We use the
            // second half of the histogram buffer for this histogram, because the first half already
            // holds the reduced histogram from the previous kernel.
            int32_t const localExpertCount = smemExpertCount[threadIdx.x];
            int32_t const tileExpertOffset
                = atomicAdd(&params.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount);

            // Make per-expert tile offsets available to all threads in the block.
            smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x];
        }
        __syncthreads();

        // Add tile offset and element offset and write to global memory.
        auto storeLoopBody = [&](int ii, int expandedIdx)
        {
            int32_t expertIdx = expertIndexes[ii];
            // check whether this expert is local to our GPU at all
            auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
            auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent
                && (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
            auto tokenIdx = expandedIdx / params.mTopK;
            auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1};
            if (params.mPtrExpandedIdxToPermutedIdx != nullptr)
            {
                params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
            }
            if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert)
            {
                params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
            }
        };

        // Bound checks only in last tile.
        if (tileIdx < numTiles - 1)
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
                storeLoopBody(ii, expandedIdx);
            }
        }
        else
        {
#pragma unroll
            for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1)
            {
                auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x;
                if (expandedIdx >= expandedIdxSize)
                {
                    break;
                }
                storeLoopBody(ii, expandedIdx);
            }
        }
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Trigger secondary kernel.
    // Note: this does not guarantee the visibility of prior writes unless the consumer executes a
    // dependency sync.
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingInitExpertCounts(KernelParams params)
{
    // initialize the mPtrExpertCounts
    int32_t expertCountsNum = 2 * params.mNumExperts;
    int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x;
    int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Wait on primary grid.
    if constexpr (KernelParams::UsePdl)
    {
        cudaGridDependencySynchronize();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

    initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Trigger secondary kernel.
    if constexpr (KernelParams::UsePdl)
    {
        cudaTriggerProgrammaticLaunchCompletion();
    }
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
}

} // namespace routing

} // namespace moe::dev
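////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative launch order for the two-step (large token count) path, as a sketch only:
//   1. routingInitExpertCounts       - zero both halves of mPtrExpertCounts (2 * numExperts bins)
//   2. routingIndicesHistogramKernel - accumulate the global per-expert token histogram
//   3. routingIndicesOffsetsKernel   - scan the histogram and write the permuted index mappings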