/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#ifndef TRTLLM_MOETOPKFUNCS_CUH_H
#define TRTLLM_MOETOPKFUNCS_CUH_H

#include "tensorrt_llm/common/config.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>

#include "tensorrt_llm/kernels/archCondition.h"

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{
namespace reduce_topk
{
namespace cg = cooperative_groups;

// Number of lanes in the tile every reduction in this file operates on.
static constexpr int kWARP_SIZE = 32;
// Selects the inline-PTX `redux.sync` path in TopKRedType::reduce for 32-bit packed keys.
static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;

/// Packs a (value, index) pair into a single unsigned integer so that a plain unsigned "max" over the packed
/// keys yields the largest value and, among equal values, the smallest index.
///
/// Layout of the packed key (`TypeCmp`):
///   * upper bits : the value after cub::Traits<T>::TwiddleIn (unsigned bit pattern used for ordering),
///   * lower 16   : `kMaxIdx - idx`, so a smaller index produces a larger key.
///
/// Only the low 16 bits of `kMaxIdx - idx` are stored, so indices are only round-tripped correctly when
/// 0 <= idx <= 65535.
template <typename T_>
struct TopKRedType
{
    using T = T_;
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>
            || std::is_same_v<T, int>,
        "Top K reduction only implemented for int, float, float16 and bfloat16");

    // 4-byte values are packed into 64 bits, 2-byte values into 32 bits.
    using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
    using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;

    // Number of bits the twiddled value is shifted left to make room for the index field.
    static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
    // Largest index representable in the 16-bit index field.
    static constexpr int kMaxIdx = 65535;

    TypeCmp compValIdx;

    /// Builds the packed comparison key for (val, idx).
    static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
    {
        auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
        TypeCmp compactTmp = valueBits;
        // Use 65535 minus idx to give higher priority to elements with smaller indices.
        compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
        return compactTmp;
    }

    /// Inverse of makeCmpVal: splits a packed key back into `value` and `index`.
    static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
    {
        // Since "65535 - idx" is always smaller than 65536 and positive, we can directly use the lower 16 bits.
        index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));

        auto compactTmp = cmp >> kMoveBits;
        auto valueBits
            = cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
        value = reinterpret_cast<T&>(valueBits);
    }

    __host__ __device__ TopKRedType() = default;

    __host__ __device__ TopKRedType(T val, int32_t idx)
        : compValIdx(makeCmpVal(val, idx))
    {
    }

    __host__ __device__ operator TypeCmp() const noexcept
    {
        return compValIdx;
    }

    /// Returns the maximum packed key held by any lane of `warp`.
    /// Uses cg::reduce unless kTLLM_GEN_HAS_FAST_REDUX is set and the key is 32 bits wide, in which case the
    /// PTX `redux.sync.max.u32` instruction is issued with a full (0xffffffff) member mask.
    __device__ inline TypeCmp reduce(cg::thread_block_tile<kWARP_SIZE> const& warp)
    {
        if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
        {
            return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
        }
        else
        {
            TypeCmp result;
            asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
            return result;
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Optional storage for K indices; the primary template is intentionally empty.
template <int K_, bool Enable_>
struct TopKIdx
{
    // by default, empty
};

/// Enabled specialization: holds K int32 indices.
template <int K_>
struct TopKIdx<K_, true>
{
    static constexpr int K = K_;
    int32_t val[K];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compare-exchange on packed keys: afterwards topK[I] holds the larger key and topK[J] the smaller one.
#define TOPK_SWAP(I, J)                                                                                                \
    {                                                                                                                  \
        auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        topK[I].compValIdx = pairMax;                                                                                  \
        topK[J].compValIdx = pairMin;                                                                                  \
    }

/// Fixed-size compare-exchange networks that order N packed keys in descending order (largest at index 0).
/// Only N = 1..4 are specialized.
template <int N, typename RedType>
struct Sort;

template <typename RedType>
struct Sort<1, RedType>
{
    static __device__ void run(RedType* topK) {}
};

template <typename RedType>
struct Sort<2, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<3, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
        TOPK_SWAP(1, 2);
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<4, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);
        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);
        TOPK_SWAP(1, 2);
    }
};

/// Warp-wide top-K where every lane contributes exactly one candidate.
///
/// @param warp      32-lane tile; every lane must call this function (the reduction is warp-collective).
/// @param out       Receives the selected values, largest first. Only out[0 .. actualK) is written.
/// @param outIdx    Receives the indices matching `out`. Only outIdx[0 .. actualK) is written.
/// @param value     This lane's candidate value.
/// @param idx       This lane's candidate index (see TopKRedType for the supported range).
/// @param minValue  Replacement value for a lane whose candidate has already been selected.
/// @param actualK   Number of results to produce. The loop writes out[kk] for kk < actualK, so callers are
///                  expected to pass actualK <= K.
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    using RedType = TopKRedType<Type>;
    RedType topK{value, idx};
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct
    {
        // The lane that supplied the previous maximum retires its candidate by replacing it with minValue.
        topK = kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
        // get the next largest value
        packedMax = topK.reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}

/// Warp-wide top-K where every lane contributes N (<= 4) candidates.
///
/// Each lane keeps its candidates ordered largest-first. In every round the warp reduces over each lane's
/// current head; the lane(s) whose head equals the selected key then shift their list up by one and append a
/// minValue entry.
///
/// @tparam IsSorted  When true the per-lane candidates are assumed to already be in descending order and the
///                   Sort<N> pass is skipped.
/// @param warp      32-lane tile; every lane must call this function.
/// @param out       Receives the selected values, largest first. Only out[0 .. actualK) is written.
/// @param outIdx    Receives the indices matching `out`. Only outIdx[0 .. actualK) is written.
/// @param value     This lane's N candidate values.
/// @param idx       This lane's N candidate indices.
/// @param minValue  Value used for the entries appended when a candidate is consumed.
/// @param actualK   Number of results to produce; callers are expected to pass actualK <= K.
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type (&value)[N], int32_t (&idx)[N], Type minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N < 5, "Only support candidates number less than or equal to 128");
    using RedType = TopKRedType<Type>;
    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        topK[nn] = RedType{value[nn], idx[nn]};
    }

    if constexpr (!IsSorted)
    {
        Sort<N, RedType>::run(topK);
    }
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < actualK; ++kk)
    {
        // True on the lane(s) whose head was selected in the previous round.
        bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
        for (int nn = 0; nn < N; ++nn)
        {
            // Pop the head: shift every entry up by one; the last entry becomes a minValue filler.
            // topK[nn + 1] is only selected for nn < N - 1, so it never reads past the array.
            topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
        }
        // get the next largest value
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}

/// Warp-wide top-K where every lane contributes N candidates (N <= 4, or a multiple of 4 up to 16).
///
/// For N <= 4 this forwards to reduceTopKFunc. For larger N the candidates are processed in N / 4 passes of
/// four per lane; the (up to K) winners of each pass are scattered across the lanes into a small per-lane
/// buffer, and a final reduceTopKFunc over that buffer produces the result.
///
/// @param warp      32-lane tile; every lane must call this function.
/// @param out       Receives the selected values, largest first. Only out[0 .. actualK) is written.
/// @param outIdx    Receives the indices matching `out`. Only outIdx[0 .. actualK) is written.
/// @param value     This lane's N candidate values.
/// @param idx       This lane's N candidate indices.
/// @param minValue  Filler value for consumed / unused candidate slots.
/// @param actualK   Number of results to produce; callers are expected to pass actualK <= K.
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
    int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue, int actualK = K)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512");
    static_assert(
        N <= 4 || N % 4 == 0, "Only support candidates number is a multiple of 4*32=128 or less than or equal to 4");
    using RedType = TopKRedType<Type>;

    if constexpr (N <= 4)
    {
        reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue, actualK);
    }
    else
    {
        constexpr int numLoops = N / 4;
        // The numLoops * K partial results are laid out flat over (slot, lane) pairs:
        // flat position p lives in buffer slot p / kWARP_SIZE of lane p % kWARP_SIZE.
        constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;

        Type topKBufferValue[numResults];
        int32_t topKBufferIdx[numResults];
        int32_t laneIdx = threadIdx.x % kWARP_SIZE;

#pragma unroll
        for (int ii = 0; ii < numResults; ++ii)
        {
            // Slots that never receive a partial result stay as minValue fillers.
            topKBufferValue[ii] = minValue;
            topKBufferIdx[ii] = ii * kWARP_SIZE - 1; //@todo: check if this is correct
        }
        for (int loop = 0; loop < numLoops; ++loop)
        {
            int start = loop * 4;
            Type topKValue[K];
            int32_t topKIdx[K];
            Type inValue[4];
            int32_t inIdx[4];
            for (int i = 0; i < 4; ++i)
            {
                inValue[i] = value[start + i];
                inIdx[i] = idx[start + i];
            }
            reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK);

            // Scatter this pass's results to flat positions [loop * K, loop * K + K).
            // For each of this lane's slots, compute which rank of the current pass (if any) belongs there.
            // This handles passes whose K positions straddle a 32-lane boundary in any slot, and skips ranks
            // >= actualK, which reduceTopKFunc did not write.
#pragma unroll
            for (int rr = 0; rr < numResults; ++rr)
            {
                int const offset = rr * kWARP_SIZE + laneIdx - loop * K;
                if (offset >= 0 && offset < K && offset < actualK)
                {
                    topKBufferValue[rr] = topKValue[offset];
                    topKBufferIdx[rr] = topKIdx[offset];
                }
            }
        }

        reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, actualK);
    }
}

#undef TOPK_SWAP

} // namespace reduce_topk
} // namespace kernels

TRTLLM_NAMESPACE_END

#endif // TRTLLM_MOETOPKFUNCS_CUH_H