/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Warp-cooperative top-K selection helpers for the moe::dev::routing code.
//
// NOTE(review): in this copy of the file the targets of the first three #include directives and most
// template parameter / argument lists (after `template`, `std::conditional_t`, `cub::Traits`,
// `reinterpret_cast`, `static_cast`, `cg::thread_block_tile`, `cg::greater`) are missing; they must be
// restored before this header can compile.

#pragma once

#include
#include
#include
#include "tensorrt_llm/kernels/archCondition.h"

namespace moe::dev::routing
{

namespace topk
{
////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cg = cooperative_groups;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Threads per warp; the redux.sync path below always uses the full-warp mask 0xffffffff.
static constexpr int WarpSize = 32;
static constexpr int MaxNumExpertsUnit = 128;
// NOTE(review): presumably the largest top-K requested by callers — confirm against the routing kernels.
static constexpr int MaxNumTopK = 10;

////////////////////////////////////////////////////////////////////////////////////////////////////

// TopKRedType packs one (score, index) candidate into a single unsigned integer `compVal`, so that a
// plain unsigned max-reduction across the warp returns the best candidate in one step:
//   - the upper bits hold the score bits after cub's TwiddleIn radix transform (intended to make
//     unsigned-integer order match the numeric order of TypeExpW), shifted left by `moveBits`;
//   - the low 16 bits hold (maxIdx - idx) & 0xFFFF, so among equal scores the smaller index wins.
// Indices must lie in [0, maxIdx]; anything outside that range wraps through the 0xFFFF mask.
// Per the static_assert message, the score type must be float, float16 or bfloat16.
template struct TopKRedType
{
    using TypeExpW = TypeExpW_;
    static_assert(
        std::is_same_v || std::is_same_v || std::is_same_v,
        "Top K reduction only implemented for float, float16 and bfloat16");

    // Packed comparison key; `reduce` distinguishes an 8-byte key from a narrower one.
    using TypeCmp = std::conditional_t;
    using IdxT = std::conditional_t;

    // Bit position of the score inside compVal: 32 for 4-byte score types, 16 otherwise.
    static constexpr int moveBits = (sizeof(TypeExpW) == 4) ? 32 : 16;
    // Largest index representable in the 16-bit index field.
    static constexpr int maxIdx = 65535;
    // The packed (score, index) key.
    TypeCmp compVal;

    // Builds the packed key for (val, idx).
    static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0)
    {
        auto valueBits = cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val));
        TypeCmp compactTmp = valueBits;
        // Store 65535 minus idx (not idx itself) in the low 16 bits: for equal scores the element with
        // the smaller index then yields the larger key and wins the max-reduction.
        compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx));
        return compactTmp;
    }

    // Inverse of makeCmpVal: recovers the score `value` and the index `index` from a packed key `cmp`.
    static __host__ __device__ inline void unpack(TypeExpW& value, int32_t& index, TypeCmp cmp)
    {
        // The low 16 bits store maxIdx - idx; this round-trips exactly for idx in [0, 65535].
        index = maxIdx - static_cast(cmp & 0xFFFF);

        auto compactTmp = cmp >> moveBits;
        // NOTE(review): compactTmp can be wider than UnsignedBits; reinterpreting it by reference reads
        // its lowest-addressed bytes, which hold the low-order bits only on a little-endian target.
        auto valueBits = cub::Traits::TwiddleOut(
            reinterpret_cast::UnsignedBits&>(compactTmp));
        value = reinterpret_cast(valueBits);
    }

    __host__ __device__ TopKRedType() = default;

    // Packs (val, idx) into compVal.
    __host__ __device__ TopKRedType(TypeExpW val, int32_t idx)
        : compVal(makeCmpVal(val, idx))
    {
    }

    // Implicit conversion to the packed key.
    __host__ __device__ operator TypeCmp() const noexcept
    {
        return compVal;
    }

    // Returns the maximum compVal over all lanes of `warp`; every lane receives the same value.
    // Must be reached by all 32 lanes (the PTX path uses the full mask 0xffffffff).
    __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp)
    {
        // redux.sync.max.u32 works on 32-bit operands only, so it is used only when hasFastRedux is set
        // and the key is not 8 bytes wide; otherwise fall back to the cooperative-groups reduction.
        // NOTE(review): is_major_v<10> presumably selects compute-capability major version 10 — confirm.
        static constexpr bool hasFastRedux = tensorrt_llm::kernels::arch::is_major_v<10>;
        if constexpr (!hasFastRedux || sizeof(TypeCmp) == 8)
        {
            return cg::reduce(warp, compVal, cg::greater{});
        }
        else
        {
            TypeCmp result;
            asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compVal));
            return result;
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compare-exchange on packed keys: afterwards topK[I] holds the larger compVal and topK[J] the
// smaller one. Expects an array named `topK` of TopKRedType in the enclosing scope.
#define TOPK_SWAP(I, J) \
    { \
        auto pairMin = min(topK[I].compVal, topK[J].compVal); \
        auto pairMax = max(topK[I].compVal, topK[J].compVal); \
        topK[I].compVal = pairMax; \
        topK[J].compVal = pairMin; \
    }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Sort<N, RedType>::run(topK) sorts topK[0..N) in place in descending order of compVal using a fixed
// compare-exchange network. Only N = 1..4 are specialized.
template struct Sort;

template struct Sort<1, RedType>
{
    // A single element is already sorted.
    static __device__ void run(RedType* topK) {}
};

template struct Sort<2, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
    }
};

template struct Sort<3, RedType>
{
    // Three-comparator network for 3 elements.
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
        TOPK_SWAP(1, 2);
        TOPK_SWAP(0, 1);
    }
};

template struct Sort<4, RedType>
{
    // Five-comparator network for 4 elements.
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);
        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);
        TOPK_SWAP(1, 2);
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-cooperative top-K over one candidate per lane.
//
// Each of the 32 lanes of `warp` contributes the candidate (value, idx). On return, every lane holds in
// out[0..k) / outIdx[0..k) the k largest values of the warp and their indices, in descending order with
// ties broken in favour of the smaller index, where k = min(actualK, K). Entries at positions >= k are
// left untouched.
//
// Requirements:
//   - all 32 lanes must call this with the same actualK (each round is a warp-wide reduction);
//   - idx must lie in [0, 65535] (TopKRedType packs it into 16 bits);
//   - minValue should be <= every real candidate: a lane whose candidate has been emitted holds
//     (minValue, idx) afterwards and must not outrank a live candidate.
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < WarpSize, "Top K must have K < WarpSize");
    using RedType = TopKRedType<Type>;

    // Never write out/outIdx past index K - 1, whatever the caller passes as actualK.
    int const numK = actualK < K ? actualK : K;

    RedType topK{value, idx};
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < numK; ++kk)
    {
        // The lane that owned the previous winner retires it by replacing it with minValue.
        topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK;
        // get the next largest value
        packedMax = topK.reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}

// Warp-cooperative top-K over N (1..4) candidates per lane, i.e. at most 4 * 32 = 128 candidates.
//
// Each lane first sorts its own N candidates in descending order. Every round, the warp reduces over
// each lane's current best (topK[0]); the lane that owned the winner drops it by shifting its local list
// left and refilling the tail with (minValue, idx[N - 1]). The output contract and requirements are the
// same as for the single-candidate reduceTopK: out[0..k) / outIdx[0..k) with k = min(actualK, K).
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < WarpSize, "Top K must have K < WarpSize");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N < 5, "Only support candidates number less than or equal to 128");
    using RedType = TopKRedType<Type>;

    // Never write out/outIdx past index K - 1, whatever the caller passes as actualK.
    int const numK = actualK < K ? actualK : K;

    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        topK[nn] = RedType{value[nn], idx[nn]};
    }
    Sort<N, RedType>::run(topK);

    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < numK; ++kk)
    {
        // Only the lane whose head element won the previous round shifts its sorted list.
        bool const update = kk > 0 && packedMax == topK[0].compVal;
#pragma unroll
        for (int nn = 0; nn < N - 1; ++nn)
        {
            topK[nn] = update ? topK[nn + 1] : topK[nn];
        }
        topK[N - 1] = update ? RedType{minValue, idx[N - 1]} : topK[N - 1];

        // get the next largest value
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}

// Warp-cooperative top-K over N candidates per lane (at most 16 * 32 = 512 candidates in total).
//
// For N <= 4 this forwards to reduceTopKFunc. Otherwise the candidates are processed in chunks of 4 per
// lane:
//   1. Each chunk produces its own top-k list, which every lane receives.
//   2. The chunk results are laid out across the warp as a flat sequence, where chunk c's j-th result has
//      flat position c * K + j, and flat position p is stored in lane p % 32, slot p / 32.
//   3. A final reduceTopKFunc over those slots selects the overall top-k.
// Slots that receive no result keep a padding candidate whose value is minValue. The output contract and
// requirements are the same as for the single-candidate reduceTopK.
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<WarpSize> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < WarpSize, "Top K must have K < WarpSize");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
    static_assert(
        N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");

    if constexpr (N <= 4)
    {
        reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, actualK);
    }
    else
    {
        constexpr int ChunkSize = 4;
        constexpr int numLoops = N / ChunkSize;
        // Per-lane slots needed to hold numLoops * K chunk results spread over the warp (<= 4 since K < 32).
        constexpr int numResults = (numLoops * K - 1) / WarpSize + 1;

        int const numK = actualK < K ? actualK : K;
        // Rank within the tile; unlike threadIdx.x % WarpSize this is a lane permutation for any block shape.
        int const laneIdx = static_cast<int>(warp.thread_rank());

        // Padding candidates for slots that receive no chunk result.
        Type topKBufferValue[numResults];
        int32_t topKBufferIdx[numResults];
#pragma unroll
        for (int ii = 0; ii < numResults; ++ii)
        {
            topKBufferValue[ii] = minValue;
            topKBufferIdx[ii] = ii * WarpSize - 1;
        }

#pragma unroll
        for (int loop = 0; loop < numLoops; ++loop)
        {
            Type inValue[ChunkSize];
            int32_t inIdx[ChunkSize];
#pragma unroll
            for (int ii = 0; ii < ChunkSize; ++ii)
            {
                inValue[ii] = value[loop * ChunkSize + ii];
                inIdx[ii] = idx[loop * ChunkSize + ii];
            }

            // Every lane receives this chunk's full top-numK list.
            Type topKValue[K];
            int32_t topKIdx[K];
            reduceTopKFunc<K, Type, ChunkSize>(warp, topKValue, topKIdx, inValue, inIdx, minValue, numK);

            // Keep the results whose flat position maps onto one of this lane's slots. Only offsets below
            // numK were written by reduceTopKFunc, so nothing uninitialized is read.
#pragma unroll
            for (int slot = 0; slot < numResults; ++slot)
            {
                int const offset = slot * WarpSize + laneIdx - loop * K;
                if (offset >= 0 && offset < numK)
                {
                    topKBufferValue[slot] = topKValue[offset];
                    topKBufferIdx[slot] = topKIdx[offset];
                }
            }
        }

        reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, numK);
    }
}

#undef TOPK_SWAP

} // namespace topk

} // namespace moe::dev::routing