/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "moeTopKFuncs.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/archCondition.h"
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
#include <climits> // For INT_MAX
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <limits> // For numeric_limits
#include <math.h>

namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;

namespace tensorrt_llm::kernels
{

static constexpr int BLOCK_SIZE = 1024;
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename DataType>
__device__ DataType calcSoftmax(
    cg::thread_block_tile<WARP_SIZE> const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts)
{
    float maxScore = -INFINITY;
    if (laneIdx < NumTopExperts)
    {
        maxScore = float(score) >= maxScore ? float(score) : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<float>());

    float sumScore = 0.f;
    float newScore;
    // Get the summation of scores for each token
    if (laneIdx < NumTopExperts)
    {
        newScore = static_cast<float>(score) - static_cast<float>(maxScore);
        newScore = static_cast<float>(exp(newScore));
        sumScore += newScore;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    if (laneIdx < NumTopExperts)
    {
        score = static_cast<DataType>(newScore / sumScore);
    }
    return score;
}

template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize])
{
    // Compute in float to support half/bfloat16 inputs safely.
    float maxScore = -INFINITY;
    float sumScore = 0.f;

    // Get the max score for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        float si = static_cast<float>(scores[i]);
        maxScore = si >= maxScore ? si : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<float>());

    // Get the summation of scores for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        float si = static_cast<float>(scores[i]);
        float e = expf(si - maxScore);
        scores[i] = static_cast<DataType>(e);
        sumScore += e;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    // Normalize the scores
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        float si = static_cast<float>(scores[i]) / sumScore;
        scores[i] = static_cast<DataType>(si);
    }
}
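// Worked example (illustrative only): for warp-local scores {1, 2, 3}, the two
// routines above compute the numerically stable softmax
//   softmax_i = exp(s_i - max_j s_j) / sum_k exp(s_k - max_j s_j),
// i.e. approximately {0.0900, 0.2447, 0.6652}. Subtracting the warp-wide max
// before exponentiating keeps expf() from overflowing for large logits.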
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts,
    bool DoSoftmaxBeforeTopK>
__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
    int32_t const numTokens, int32_t const numExperts, int32_t const topK)
{
    using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, InputT>;

    uint32_t const blockRank = blockIdx.x;
    uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
    uint32_t const warpIdx = tIdx / WARP_SIZE;
    uint32_t const laneIdx = tIdx % WARP_SIZE;
    uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;

    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WARP_SIZE>(block);

    BaseType minScore = BaseType{-INFINITY};
    for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
    {
        auto scoreOffset = tokenId * numExperts;
        auto outputOffset = tokenId * topK;

        BaseType inputScore[MaxNumExperts / WARP_SIZE];
        IdxT inputIndex[MaxNumExperts / WARP_SIZE];

        BaseType warpTopKScore[MaxNumTopExperts];
        IdxT warpTopKExpertIdx[MaxNumTopExperts];

        // Load scores and indices for this warp; out-of-range slots are padded with -inf
        for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
        {
            auto expertIdx = i * WARP_SIZE + laneIdx;
            inputScore[i] = expertIdx < numExperts
                ? static_cast<BaseType>(routerLogits[scoreOffset + expertIdx])
                : minScore;
            inputIndex[i] = expertIdx;
        }

        if constexpr (DoSoftmaxBeforeTopK)
        {
            calcSoftmax(warp, inputScore);
        }

        // Reduce topK scores and indices for this warp
        reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);

        // Write out the top-k results; when softmax runs after top-k, normalize the selected scores here
        if constexpr (DoSoftmaxBeforeTopK)
        {
            if (laneIdx < topK)
            {
                topkValues[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
                topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
            }
        }
        else
        {
            auto softmaxScore = calcSoftmax(warp,
                laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore),
                laneIdx, topK);
            if (laneIdx < topK)
            {
                topkValues[outputOffset + laneIdx] = static_cast<OutputT>(softmaxScore);
                topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
            }
        }
    } // end for tokenId
}

int nextPowerOfTwo(int num)
{
    if (num <= 0)
    {
        return 1; // Handle invalid input
    }
    int power = 1;
    while (power < num)
    {
        // Check for overflow before shifting
        if (power > INT_MAX / 2)
        {
            return power;
        }
        power <<= 1;
    }
    return power;
}
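// Examples: nextPowerOfTwo(17) == 32, nextPowerOfTwo(32) == 32, nextPowerOfTwo(0) == 1.
// invokeCustomMoeRouting below uses this to round numExperts and topK up to the
// compiled kernel buckets, so each warp owns a fixed-size register tile of
// MaxNumExperts / WARP_SIZE scores per lane.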
#define CASE(MAX_NUM_EXPERTS)                                                                                          \
    case MAX_NUM_EXPERTS:                                                                                              \
        switch (maxNumTopExperts)                                                                                      \
        {                                                                                                              \
        case 1:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        case 2:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        case 4:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        case 8:                                                                                                        \
            kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8, DoSoftmaxBeforeTopK>;  \
            break;                                                                                                     \
        default: kernelInstance = nullptr; break;                                                                      \
        }                                                                                                              \
        break;

template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
void invokeCustomMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
    int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
{
    const uint32_t maxNumBlocks = 1024;
    const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);

    uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
    uint32_t maxNumTopExperts = nextPowerOfTwo(topK);

    auto* kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8, DoSoftmaxBeforeTopK>;
    switch (maxNumExperts)
    {
        CASE(32)
        CASE(64)
        CASE(96)
        CASE(128)
    default: kernelInstance = nullptr; break;
    }
    TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Cannot find a matching kernel instance.");

    dim3 renormMoeRoutingGridDim(numBlocks);
    dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
    cudaLaunchConfig_t config;
    config.gridDim = renormMoeRoutingGridDim;
    config.blockDim = renormMoeRoutingBlockDim;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices,
        static_cast<int32_t>(numTokens), static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
    sync_check_cuda_error(stream);
}

#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK)                                     \
    template void invokeCustomMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT * routerLogits,            \
        OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts,                   \
        int64_t const topK, cudaStream_t const stream);

INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true);

#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(float, __nv_bfloat16, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(half, __nv_bfloat16, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(float, __nv_bfloat16, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(half, __nv_bfloat16, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, int32_t, true);
#endif

} // namespace tensorrt_llm::kernels
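// Usage sketch (illustrative only, not compiled here; the buffer names are hypothetical).
// Routes [numTokens, numExperts] float logits into per-token top-k expert weights and
// indices, applying softmax over all expert logits before the top-k selection
// (DoSoftmaxBeforeTopK == true):
//
//   float* dLogits;       // device buffer, [numTokens, numExperts] router logits
//   float* dTopkValues;   // device buffer, [numTokens, topK] routing weights
//   int32_t* dTopkIdx;    // device buffer, [numTokens, topK] expert indices
//   tensorrt_llm::kernels::invokeCustomMoeRouting<float, float, int32_t, /*DoSoftmaxBeforeTopK=*/true>(
//       dLogits, dTopkValues, dTopkIdx, numTokens, numExperts, topK, stream);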