/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "moeTopKFuncs.cuh"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/noAuxTcKernels.h"

#include <cfloat>
#include <cooperative_groups.h>
#include <cub/cub.cuh>

namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

namespace
{

// Map the logit bits so that larger logits fall into lower-numbered bins: negative values keep their bits
// (sign bit set), positive values are bit-inverted. The bit range selected by `step` is the histogram bin.
template <int step>
static inline __device__ uint32_t extractBinIdx(float x)
{
    if constexpr (step == 0)
    {
        __half hx = __float2half(x);
        uint16_t bits = __half_as_ushort(hx);
        bits = (bits & 0x8000) ? bits : ~bits & 0x7fff;
        return bits >> 5;
    }
    else
    {
        uint32_t bits = __float_as_uint(x);
        bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
        if constexpr (step == 1)
        {
            return bits >> 21;
        }
        else if constexpr (step == 2)
        {
            return (bits >> 10) & 0x7ff;
        }
        else if constexpr (step == 3)
        {
            return bits & 0x3ff;
        }
    }
}

// Check whether the high bits of the mapped logit match the pattern accumulated from the previous steps.
template <int shift>
static inline __device__ bool isPartialMatch(float x, uint32_t pattern)
{
    if constexpr (shift == 0)
    {
        return true;
    }
    uint32_t bits = __float_as_uint(x);
    bits = (bits & 0x80000000) ? bits : ~bits & 0x7fffffff;
    return (bits ^ pattern) >> shift == 0;
}

/**
 * Map a Func over the input data, using vectorized load instructions if possible.
 *
 * @tparam T element type
 * @tparam idxT indexing type
 * @tparam Func void (T x, idxT idx)
 *
 * @param thread_rank rank of the calling thread among all participating threads
 * @param num_threads number of the threads that participate in processing
 * @param in the input data
 * @param len the number of elements to read
 * @param f the lambda taking two arguments (T x, idxT idx)
 */
template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, idxT len, Func f)
{
    constexpr int WARP_SIZE = 32;
    using WideT = float4;
    if constexpr (sizeof(T) >= sizeof(WideT))
    {
        for (idxT i = thread_rank; i < len; i += num_threads)
        {
            f(in[i], i);
        }
    }
    else
    {
        static_assert(sizeof(WideT) % sizeof(T) == 0);
        constexpr int items_per_scalar = sizeof(WideT) / sizeof(T);
        // TODO: it's UB
        union
        {
            WideT scalar;
            T array[items_per_scalar];
        } wide;
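        // The input pointer may not be aligned on a WideT boundary. The row is therefore processed in three
        // parts: a scalar prefix of skip_cnt elements up to the first aligned address, a vectorized body of
        // len_cast wide loads, and a scalar tail of fewer than items_per_scalar trailing elements.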
        int skip_cnt = (reinterpret_cast<size_t>(in) % sizeof(WideT))
            ? ((sizeof(WideT) - reinterpret_cast<size_t>(in) % sizeof(WideT)) / sizeof(T))
            : 0;
        if (skip_cnt > len)
        {
            skip_cnt = len;
        }
        WideT const* in_cast = reinterpret_cast<WideT const*>(in + skip_cnt);
        const idxT len_cast = (len - skip_cnt) / items_per_scalar;

        for (idxT i = thread_rank; i < len_cast; i += num_threads)
        {
            wide.scalar = in_cast[i];
            const idxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
            for (int j = 0; j < items_per_scalar; ++j)
            {
                f(wide.array[j], real_i + j);
            }
        }

        static_assert(WARP_SIZE >= items_per_scalar);
        // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
        // no need to use loop
        if (thread_rank < skip_cnt)
        {
            f(in[thread_rank], thread_rank);
        }
        // because len_cast = (len - skip_cnt) / items_per_scalar,
        // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
        // and so
        // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE
        // no need to use loop
        const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
        if (remain_i < len)
        {
            f(in[remain_i], remain_i);
        }
    }
}

template <int step, int kNumThreadsPerBlock, int kNumBins, int kNumFinalItems, bool multipleBlocksPerRow,
    bool mergeBlocks, typename SmemOutputType, typename SmemFinalType>
__device__ bool processHistogramStep(int const* indices, float const* logits, int rowEnd, uint32_t& logitPattern,
    int& thresholdBinIdx, SmemOutputType& smemOutput, int* smemThresholdBinIdx, int* smemFinalDstIdx,
    int* smemFinalBinSize, int* smemFoundTopKValues, SmemFinalType& smemFinal, int stride1, int rowStart, int topK)
{
    // Clear the histogram.
#pragma unroll
    for (int idx = threadIdx.x; idx < kNumBins; idx += kNumThreadsPerBlock)
    {
        smemFinal.histo.data[idx] = 0;
    }

    // Make sure the histogram is ready.
    __syncthreads();

    // Update the pattern with the bits selected by the previous step.
    constexpr auto patternShift = step < 2 ? 0 : step == 2 ? 21 : 10;
    if constexpr (step == 2)
    {
        logitPattern = static_cast<uint32_t>(thresholdBinIdx & 0x7ff) << patternShift;
    }
    else if constexpr (step == 3)
    {
        logitPattern |= static_cast<uint32_t>(thresholdBinIdx & 0x7ff) << patternShift;
    }

    auto distributeToBins = [&](float logit, int /* idx */ = 0)
    {
        if (isPartialMatch<patternShift>(logit, logitPattern))
        {
            uint32_t binIdx = extractBinIdx<step>(logit);
            atomicAdd(&smemFinal.histo.data[binIdx], 1);
        }
    };

    // Distribute the elements to the histogram bins.
    if (stride1 == 1)
    {
        vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart, rowEnd - rowStart, distributeToBins);
    }
    else
    {
        for (int idx = rowStart + threadIdx.x; idx < rowEnd; idx += kNumThreadsPerBlock)
        {
            float logit = logits[idx * stride1];
            distributeToBins(logit, idx);
        }
    }

    // Make sure the histogram is ready.
    __syncthreads();

    // Read the starting position in the smemOutput array (top-k values found by the previous steps).
    int lastValue = smemFoundTopKValues[0];

    for (int round = 0; round < kNumBins / kNumThreadsPerBlock; round++)
    {
        // Read the values from SMEM.
        int idx = threadIdx.x + kNumThreadsPerBlock * round;
        int binCount = smemFinal.histo.data[idx];

        // Make sure each thread has read its value.
        __syncthreads();

        // Compute the prefix sum.
        int prefixSum{0}, totalSum{0};
        using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
        Scan(smemFinal.histo.scan).ExclusiveSum(binCount, prefixSum, totalSum);

        // Update the histogram with the prefix sums.
        prefixSum += lastValue;
        totalSum += lastValue;
        smemFinal.histo.data[idx] = prefixSum;

        // Make sure the data is in shared memory.
        __syncthreads();
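        // Hypothetical example: with topK = 4 and bin counts {2, 1, 3, ...} (largest logits land in bin 0),
        // the exclusive prefix sums are {0, 2, 3, 6, ...}. Bin 2 becomes the threshold bin: its prefix sum (3)
        // is still below topK while the next one (6) reaches it, so bins 0 and 1 are selected entirely and
        // only part of bin 2 is needed.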
        // Find the last valid bin.
        bool foundThreshold = false;
        if (prefixSum < topK)
        {
            int nextPrefixSum = threadIdx.x == kNumThreadsPerBlock - 1 ? totalSum : smemFinal.histo.data[idx + 1];
            if (nextPrefixSum >= topK)
            {
                smemThresholdBinIdx[0] = idx;
                smemFinalBinSize[0] = nextPrefixSum - prefixSum;
                foundThreshold = true;
            }
        }

        // Early exit: if any thread found the threshold, we can skip the remaining rounds.
        if (__syncthreads_or(foundThreshold))
        {
            break;
        }
        lastValue = totalSum;
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // The threshold bin.
    thresholdBinIdx = smemThresholdBinIdx[0];

    auto processBins = [&](float logit, int idx)
    {
        if (isPartialMatch<patternShift>(logit, logitPattern))
        {
            uint32_t binIdx = extractBinIdx<step>(logit);
            if (binIdx < thresholdBinIdx)
            {
                // The element is part of the top-k selection.
                int dstIdx = atomicAdd(&smemFoundTopKValues[0], 1);
                if constexpr (mergeBlocks)
                {
                    smemOutput[dstIdx] = indices[idx];
                }
                else if constexpr (multipleBlocksPerRow)
                {
                    smemOutput[dstIdx] = idx + rowStart;
                    reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
                }
                else
                {
                    smemOutput[dstIdx] = idx;
                }
            }
            if constexpr (step < 3)
            {
                // Only fill the final items for sorting if the threshold bin fits.
                if (binIdx == thresholdBinIdx && smemFinalBinSize[0] <= kNumFinalItems)
                {
                    int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
                    smemFinal.items.logits[dstIdx] = logit;
                    if constexpr (mergeBlocks)
                    {
                        smemFinal.items.indices[dstIdx] = indices[idx];
                    }
                    else if constexpr (multipleBlocksPerRow)
                    {
                        smemFinal.items.indices[dstIdx] = idx + rowStart;
                    }
                    else
                    {
                        smemFinal.items.indices[dstIdx] = idx;
                    }
                }
            }
            else
            {
                if (binIdx == thresholdBinIdx)
                {
                    // The elements in the threshold bin share the same 32 bits at step 3.
                    int dstIdx = atomicAdd(&smemFinal.histo.data[binIdx], 1);
                    if (dstIdx < topK)
                    {
                        if constexpr (mergeBlocks)
                        {
                            smemOutput[dstIdx] = indices[idx];
                        }
                        else if constexpr (multipleBlocksPerRow)
                        {
                            smemOutput[dstIdx] = idx + rowStart;
                            reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = logit;
                        }
                        else
                        {
                            smemOutput[dstIdx] = idx;
                        }
                    }
                }
            }
        }
    };

    if (stride1 == 1)
    {
        vectorized_process(threadIdx.x, kNumThreadsPerBlock, logits + rowStart, rowEnd - rowStart, processBins);
    }
    else
    {
        for (int idx = rowStart + threadIdx.x; idx < rowEnd; idx += kNumThreadsPerBlock)
        {
            float logit = logits[idx * stride1];
            processBins(logit, idx);
        }
    }

    // Make sure the elements are in shared memory.
    __syncthreads();

    // Check if we should continue to the next step.
    return smemFinalBinSize[0] > kNumFinalItems;
}

// Follows half - 11 - 11 - 10 bit iterations.
template <int kNumThreadsPerBlock, int kNumBins, bool useRadixSort, bool multipleBlocksPerRow = false,
    bool mergeBlocks = false>
static __device__ void topKPerRowJob(int const* indices, float const* logits, int rowStart, int rowEnd,
    int* outIndices, float* outLogits, int stride1, int topK)
{
    // The number of slots for the final pass.
    static constexpr int kNumFinalItems = 2048;

    // The number of elements per thread for the final sort.
    static constexpr int kNumFinalItemsPerThread = kNumFinalItems / kNumThreadsPerBlock;

    // The class to sort the elements during the final pass.
    using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock, kNumFinalItemsPerThread, int>;
    // The temporary storage for the final sort; falls back to a small placeholder type when the
    // insertion sort is used instead.
    using FinalSortTempStorage = std::conditional_t<useRadixSort, typename FinalSort::TempStorage, int>;

    // The class to compute the exclusive prefix-sum over the histogram.
    using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;

    // The structure to store the final items (for the final pass).
    struct FinalItems
    {
        // Shared memory to store the indices for the final pass.
        int indices[kNumFinalItems];
        // Shared memory to store the logits for the final pass.
        float logits[kNumFinalItems];
    };

    struct Histogram
    {
        typename Scan::TempStorage scan;
        int data[kNumBins];
    };

    // Shared memory to compute the block sort.
    __shared__ union
    {
        FinalItems items;
        FinalSortTempStorage finalSort;
        Histogram histo;
    } smemFinal;
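    // Dynamic shared memory layout of smemOutput: the first topK int32 slots hold the selected indices.
    // When multipleBlocksPerRow is set, the host allocates 2 * topK * sizeof(int32_t) and the second half,
    // accessed through reinterpret_cast<float*>(smemOutput + topK), holds the matching logits.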
    // Shared memory to store the selected indices.
    // If we are processing using multiple blocks, we need to store the logits and indices.
    extern __shared__ int32_t smemOutput[];

    // Shared memory to store the threshold bin.
    __shared__ int smemThresholdBinIdx[1];
    // Shared memory counter to register the candidates for the final phase.
    __shared__ int smemFinalDstIdx[1];
    // Shared memory to determine if the threshold bin fits in the final items.
    __shared__ int smemFinalBinSize[1];
    // Shared memory to keep track of the top-k values found so far by the previous iterations.
    __shared__ int smemFoundTopKValues[1];

    // The length of the row.
    int rowLen = rowEnd - rowStart;

    // Shortcut if the length of the row is not larger than top-K. Indices are not sorted by their
    // corresponding logit.
    if (rowLen <= topK)
    {
        for (int rowIt = threadIdx.x; rowIt < rowLen; rowIt += kNumThreadsPerBlock)
        {
            if constexpr (multipleBlocksPerRow)
            {
                outIndices[rowIt] = rowIt + rowStart;
                outLogits[rowIt] = logits[rowIt + rowStart];
            }
            else
            {
                outIndices[rowIt] = rowIt;
            }
        }
        for (int rowIt = rowLen + threadIdx.x; rowIt < topK; rowIt += kNumThreadsPerBlock)
        {
            outIndices[rowIt] = -1;
            if constexpr (multipleBlocksPerRow)
            {
                outLogits[rowIt] = -FLT_MAX;
            }
        }
        return;
    }

    // Initialize values.
    if (threadIdx.x == 0)
    {
        smemFinalDstIdx[0] = 0;
        smemFoundTopKValues[0] = 0;
    }
    __syncthreads();

    int thresholdBinIdx = -1;
    uint32_t logitPattern = 0;

    // Step 0: Process the first 11 bits of the half representation.
    bool continueToNextStep
        = processHistogramStep<0, kNumThreadsPerBlock, kNumBins, kNumFinalItems, multipleBlocksPerRow, mergeBlocks>(
            indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, smemThresholdBinIdx, smemFinalDstIdx,
            smemFinalBinSize, smemFoundTopKValues, smemFinal, stride1, rowStart, topK);

    if (continueToNextStep)
    {
        // Step 1: Process the first 11 bits of the float representation.
        continueToNextStep = processHistogramStep<1, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
            multipleBlocksPerRow, mergeBlocks>(indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
            smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, smemFoundTopKValues, smemFinal, stride1, rowStart,
            topK);
    }

    if (continueToNextStep)
    {
        // Step 2: Process the next 11 bits.
        continueToNextStep = processHistogramStep<2, kNumThreadsPerBlock, kNumBins, kNumFinalItems,
            multipleBlocksPerRow, mergeBlocks>(indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput,
            smemThresholdBinIdx, smemFinalDstIdx, smemFinalBinSize, smemFoundTopKValues, smemFinal, stride1, rowStart,
            topK);
    }

    if (continueToNextStep)
    {
        // Step 3: Process the last 10 bits.
        processHistogramStep<3, kNumThreadsPerBlock, kNumBins, kNumFinalItems, multipleBlocksPerRow, mergeBlocks>(
            indices, logits, rowEnd, logitPattern, thresholdBinIdx, smemOutput, smemThresholdBinIdx, smemFinalDstIdx,
            smemFinalBinSize, smemFoundTopKValues, smemFinal, stride1, rowStart, topK);
    }

    if (!continueToNextStep)
    {
        // The histogram did not proceed to the final 10 bits, therefore we need to sort the final items.
        if constexpr (useRadixSort)
        {
            // Sorting with radix sort.
            // The logits of the elements to be sorted in the final pass.
            float finalLogits[kNumFinalItemsPerThread];
            // The indices of the elements to be sorted in the final pass.
            int finalIndices[kNumFinalItemsPerThread];

#pragma unroll
            for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii)
            {
                finalLogits[ii] = -FLT_MAX;
            }
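            // Slots beyond smemFinalDstIdx[0] keep the -FLT_MAX padding, so after the descending sort they
            // rank behind every real candidate and never land in the first topK output positions.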
            // Read the elements from SMEM.
#pragma unroll
            for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii)
            {
                int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
                if (srcIdx < smemFinalDstIdx[0])
                {
                    finalLogits[ii] = smemFinal.items.logits[srcIdx];
                    finalIndices[ii] = smemFinal.items.indices[srcIdx];
                }
            }

            // Make sure the shared memory has been read.
            __syncthreads();

            // Sort the elements.
            FinalSort(smemFinal.finalSort).SortDescendingBlockedToStriped(finalLogits, finalIndices);

            // Copy the data back to the shared memory storage.
            int baseIdx = smemFoundTopKValues[0];
#pragma unroll
            for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii)
            {
                int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
                int dstIdx = baseIdx + srcIdx;
                if (dstIdx < topK)
                {
                    smemOutput[dstIdx] = finalIndices[ii];
                    if constexpr (multipleBlocksPerRow)
                    {
                        reinterpret_cast<float*>(smemOutput + topK)[dstIdx] = finalLogits[ii];
                    }
                }
            }
        }
        else
        {
            // Sorting with insertion sort.
            auto baseIdx = smemFoundTopKValues[0];
            for (int i = threadIdx.x; i < smemFinalDstIdx[0]; i += kNumThreadsPerBlock)
            {
                int outIndex = 0;
                auto logit = smemFinal.items.logits[i];
                for (int j = 0; j < smemFinalDstIdx[0]; j++)
                {
                    auto otherLogit = smemFinal.items.logits[j];
                    if (logit < otherLogit || (logit == otherLogit && i < j))
                    {
                        outIndex++;
                    }
                }
                // Store if outIndex is in bounds.
                if (outIndex + baseIdx < topK)
                {
                    smemOutput[outIndex + baseIdx] = smemFinal.items.indices[i];
                    if constexpr (multipleBlocksPerRow)
                    {
                        reinterpret_cast<float*>(smemOutput + topK)[outIndex + baseIdx] = smemFinal.items.logits[i];
                    }
                }
            }
        }
        __syncthreads();
    }

    // Store to global memory.
    for (int i = threadIdx.x; i < topK; i += kNumThreadsPerBlock)
    {
        if constexpr (multipleBlocksPerRow)
        {
            outIndices[i] = smemOutput[i];
            outLogits[i] = reinterpret_cast<float*>(smemOutput + topK)[i];
        }
        else
        {
            if (stride1 == 1)
            {
                // stride1 == 1 uses vectorized_process, whose indices already skip rowStart.
                outIndices[i] = smemOutput[i];
            }
            else
            {
                outIndices[i] = smemOutput[i] - rowStart;
            }
        }
    }
}

} // namespace

template <int kNumThreadsPerBlock, bool useRadixSort>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(float const* logits,
    int const* rowStarts, int const* rowEnds, int* outIndices, int stride0, int stride1, int const topK,
    int const offsetIndex)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    // The number of bins in the histogram.
    static constexpr int kNumBins = 2048;

    // The row computed by this block.
    int rowIdx = blockIdx.x + offsetIndex;

    // The range of logits within the row.
    int rowStart = rowStarts[rowIdx];
    int rowEnd = rowEnds[rowIdx];

    // Local pointers to this block.
    outIndices += rowIdx * topK;
    logits += rowIdx * stride0;

    topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
        nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

template <int kNumThreadsPerBlock, bool useRadixSort, bool multipleBlocksPerRow = false, bool mergeBlocks = false>
static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(float const* logits,
    int const* seqLens, int* outIndices, int stride0, int stride1, int const topK, int next_n,
    float* outLogits = nullptr, int const numBlocksToMerge = 0, int const* indices = nullptr)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    // The number of bins in the histogram.
    static constexpr int kNumBins = 2048;

    // The row computed by this block.
    int rowIdx = blockIdx.x;
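    // Rows are grouped next_n per sequence: row i belongs to sequence i / next_n and corresponds to position
    // i % next_n within that group, so its valid range covers the first
    // seqLens[i / next_n] - next_n + (i % next_n) + 1 columns.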
    // The range of logits within the row.
    int rowStart = 0;
    int seq_len = seqLens[rowIdx / next_n];
    int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1;

    // Local pointers to this block.
    if constexpr (!multipleBlocksPerRow && !mergeBlocks)
    {
        outIndices += rowIdx * topK;
    }
    else if constexpr (multipleBlocksPerRow)
    {
        auto const blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
        rowStart = blockSize * blockIdx.y;         // 8192 * 1 = 8192
        rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
        outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
        outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
    }
    else if constexpr (mergeBlocks)
    {
        rowEnd = numBlocksToMerge * topK;
        indices += rowIdx * numBlocksToMerge * topK;
        outIndices += rowIdx * topK;
    }
    logits += rowIdx * stride0;

    topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, multipleBlocksPerRow, mergeBlocks>(
        indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux,
    int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
    int const stride1, int const next_n, int const topK, cudaStream_t const stream)
{
    constexpr int kSortingAlgorithmThreshold = 12288;
    constexpr int kSplitWorkThreshold = 200 * 1000;
    constexpr int kNumThreadsPerBlock = 512;

    if (numColumns < kSortingAlgorithmThreshold)
    {
        // Use insertion sort.
        auto* kernel_instance = &topKPerRowDecode<kNumThreadsPerBlock, /*useRadixSort=*/false>;
        cudaLaunchConfig_t config;
        config.gridDim = numRows;
        config.blockDim = kNumThreadsPerBlock;
        config.dynamicSmemBytes = topK * sizeof(int32_t);
        config.stream = stream;
        cudaLaunchAttribute attrs[1];
        attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
        config.numAttrs = 1;
        config.attrs = attrs;
        cudaLaunchKernelEx(
            &config, kernel_instance, logits, seqLens, indices, stride0, stride1, topK, next_n, nullptr, 0, nullptr);
    }
    else if (numColumns < kSplitWorkThreshold)
    {
        // From this threshold on, use radix sort instead.
        auto* kernel_instance = &topKPerRowDecode<kNumThreadsPerBlock, /*useRadixSort=*/true>;
        cudaLaunchConfig_t config;
        config.gridDim = numRows;
        config.blockDim = kNumThreadsPerBlock;
        config.dynamicSmemBytes = topK * sizeof(int32_t);
        config.stream = stream;
        cudaLaunchAttribute attrs[1];
        attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
        config.numAttrs = 1;
        config.attrs = attrs;
        cudaLaunchKernelEx(
            &config, kernel_instance, logits, seqLens, indices, stride0, stride1, topK, next_n, nullptr, 0, nullptr);
    }
    else
    {
        // Long sequences are run in two steps.
        constexpr auto multipleBlocksPerRowConfig = 10;
        auto* kernel_instance_part1
            = &topKPerRowDecode<kNumThreadsPerBlock, /*useRadixSort=*/true, /*multipleBlocksPerRow=*/true>;
        cudaLaunchConfig_t config_part1;
        config_part1.gridDim = dim3(numRows, multipleBlocksPerRowConfig);
        config_part1.blockDim = kNumThreadsPerBlock;
        config_part1.dynamicSmemBytes = 2 * topK * sizeof(int32_t);
        config_part1.stream = stream;
        cudaLaunchAttribute attrs[1];
        attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
        attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
        config_part1.numAttrs = 1;
        config_part1.attrs = attrs;
        cudaLaunchKernelEx(&config_part1, kernel_instance_part1, logits, seqLens, outIndicesAux, stride0, stride1,
            topK, next_n, outLogitsAux, 0, nullptr);
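        // Merge pass: each row now has multipleBlocksPerRowConfig * topK candidate (index, logit) pairs in
        // outIndicesAux / outLogitsAux. The second kernel treats them as one dense row of that length
        // (mergeBlocks path) and reduces it to the final topK, translating candidate positions back to the
        // original column indices through outIndicesAux.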
        constexpr int kNumThreadsPerBlockMerge = 1024;
        auto* kernel_instance_part2 = &topKPerRowDecode<kNumThreadsPerBlockMerge, /*useRadixSort=*/true,
            /*multipleBlocksPerRow=*/false, /*mergeBlocks=*/true>;
        cudaLaunchConfig_t config_part2;
        config_part2.gridDim = numRows;
        config_part2.blockDim = kNumThreadsPerBlockMerge;
        config_part2.dynamicSmemBytes = topK * sizeof(int32_t);
        config_part2.stream = stream;
        // Reuse the attrs array since the part1 kernel has already been launched.
        config_part2.numAttrs = 1;
        config_part2.attrs = attrs;
        cudaLaunchKernelEx(&config_part2, kernel_instance_part2, outLogitsAux, seqLens, indices,
            multipleBlocksPerRowConfig * topK, 1, topK, next_n, nullptr, multipleBlocksPerRowConfig, outIndicesAux);
    }
    sync_check_cuda_error(stream);
}

void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices,
    int const numRows, int const numColumns, int const stride0, int const stride1, int const topK,
    cudaStream_t const stream)
{
    constexpr int kSortingAlgorithmThreshold = 12288;
    constexpr int kNumThreadsPerBlock = 512;

    int numInsertionBlocks = std::min(numRows, kSortingAlgorithmThreshold);
    topKPerRowPrefill<kNumThreadsPerBlock, /*useRadixSort=*/false>
        <<<numInsertionBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
            logits, rowStarts, rowEnds, indices, stride0, stride1, topK, 0);

    if (numRows > kSortingAlgorithmThreshold)
    {
        int numRadixBlocks = numRows - kSortingAlgorithmThreshold;
        topKPerRowPrefill<kNumThreadsPerBlock, /*useRadixSort=*/true>
            <<<numRadixBlocks, kNumThreadsPerBlock, topK * sizeof(int32_t), stream>>>(
                logits, rowStarts, rowEnds, indices, stride0, stride1, topK, kSortingAlgorithmThreshold);
    }
    sync_check_cuda_error(stream);
}

} // namespace kernels

TRTLLM_NAMESPACE_END