/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * This file contains a specialized implementation of AIR TopK
 * introduced in https://dl.acm.org/doi/pdf/10.1145/3581784.3607062 .
 * Another variant can be found in TopP sampling: cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu .
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "topkLastDim.h"

#include <cub/cub.cuh>
#include <cuda/atomic>

namespace tensorrt_llm
{
namespace kernels
{

using SizeType32 = tensorrt_llm::runtime::SizeType32;

///////////////
// AIR TopK Kernel
#if 1
namespace air_topk_stable
{
using WideT = float4;
constexpr int VECTORIZED_READ_SIZE = 16;
constexpr int WARP_SIZE = 32;
// constexpr unsigned FULL_WARP_MASK = 0xffffffff;

template <typename IdxT>
struct ComputeOffset
{
    __host__ __device__ explicit ComputeOffset(IdxT const& cols)
        : cols_(cols)
    {
    }

    __host__ __device__ IdxT operator()(IdxT const& x) const
    {
        return cols_ * x;
    }

    IdxT cols_;
};

template <int BitsPerPass>
__host__ __device__ constexpr int calc_num_buckets()
{
    return 1 << BitsPerPass;
}

/**
 * @brief Provide a ceiling division operation, i.e. ceil(a / b)
 * @tparam IntType supposed to be only integers for now!
 */
template <typename IntType>
constexpr __host__ __device__ IntType ceildiv(IntType a, IntType b)
{
    return (a + b - 1) / b;
}

/**
 * @brief Provide an alignment function, i.e. ceil(a / b) * b
 * @tparam IntType supposed to be only integers for now!
 */
template <typename IntType>
constexpr __host__ __device__ IntType alignTo(IntType a, IntType b)
{
    return ceildiv(a, b) * b;
}

template <typename T, int BitsPerPass>
__host__ __device__ constexpr int calc_num_passes()
{
    return ceildiv<int>(sizeof(T) * 8, BitsPerPass);
}

__host__ __device__ int round(int num, int round_value)
{
    return ((num - 1) / round_value + 1) * round_value;
}

/**
 * Bit 0 is the least significant (rightmost);
 * this implementation processes input from the most to the least significant bit.
 * This way, we can skip some passes in the end at the cost of having an unsorted output.
 *
 * NB: Use pass=-1 for calc_mask().
 */
template <typename T, int BitsPerPass>
__device__ constexpr int calc_start_bit(int pass)
{
    int start_bit = static_cast<int>(sizeof(T) * 8) - (pass + 1) * BitsPerPass;
    if (start_bit < 0)
    {
        start_bit = 0;
    }
    return start_bit;
}

template <typename T, int BitsPerPass>
__device__ constexpr unsigned calc_mask(int pass)
{
    static_assert(BitsPerPass <= 31);
    int num_bits = calc_start_bit<T, BitsPerPass>(pass - 1) - calc_start_bit<T, BitsPerPass>(pass);
    return (1 << num_bits) - 1;
}

/**
 * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point
 * values as well as of integers.
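 *
 * For example, cub::Traits<float>::TwiddleIn flips every bit of a negative float but only the
 * sign bit of a non-negative one, so the resulting unsigned integers compare in the same order
 * as the original floats. When select_min is false, twiddle_in() below additionally inverts all
 * bits, so that picking the numerically smallest twiddled keys selects the largest original
 * values.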
*/ template __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); if (!select_min) { bits = ~bits; } return bits; } template __device__ T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) { if (!select_min) { bits = ~bits; } bits = cub::Traits::TwiddleOut(bits); return reinterpret_cast(bits); } template __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1, "BitsPerPass is too large that the result type could not be int"); return (twiddle_in(x, select_min) >> start_bit) & mask; } template constexpr inline std::enable_if_t::value, bool> is_a_power_of_two(I val) noexcept { return ((val - 1) & val) == 0; } template __host__ __device__ IdxT calc_buf_len(IdxT len) { // When writing is skipped, only read `in`(type T). // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write // `out_buf`(T) and `out_idx_buf`(IdxT). The ratio between these cases determines // whether to skip writing and hence the buffer size. constexpr RATIO_T ratio = 2 + sizeof(IdxT) * 2 / sizeof(T); // Even such estimation is too conservative, so further decrease buf_len by 1/8 IdxT buf_len = len / (ratio * 8); // one-block kernel splits one large buffer into smaller ones, so round buf size to // 256 bytes to avoid alignment issues static_assert(is_a_power_of_two(sizeof(T))); static_assert(is_a_power_of_two(sizeof(IdxT))); constexpr IdxT aligned = 256 / std::min(sizeof(T), sizeof(IdxT)); buf_len = buf_len & (~(aligned - 1)); return buf_len; } /** * Map a Func over the input data, using vectorized load instructions if possible. * * NB: in future, we should move this to cpp/include/raft/linalg/detail/unary_op.cuh, * which currently does not support the second lambda argument (index of an element) * * @tparam T element type * @tparam IdxT indexing type * @tparam Func void (T x, IdxT idx) * * @param thread_rank rank of the calling thread among all participating threads * @param num_threads number of the threads that participate in processing * @param in the input data * @param len the number of elements to read * @param f the lambda taking two arguments (T x, IdxT idx) */ template __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, idxT len, Func f) { if constexpr (sizeof(T) >= sizeof(WideT)) { for (idxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } } else { static_assert(sizeof(WideT) % sizeof(T) == 0); constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); // TODO: it's UB union { WideT scalar; T array[items_per_scalar]; } wide; int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) ? 
((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) : 0; if (skip_cnt > len) { skip_cnt = len; } WideT const* in_cast = reinterpret_cast(in + skip_cnt); const idxT len_cast = (len - skip_cnt) / items_per_scalar; for (idxT i = thread_rank; i < len_cast; i += num_threads) { wide.scalar = in_cast[i]; const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { f(wide.array[j], real_i + j); } } static_assert(WARP_SIZE >= items_per_scalar); // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt // no need to use loop if (thread_rank < skip_cnt) { f(in[thread_rank], thread_rank); } // because len_cast = (len - skip_cnt) / items_per_scalar, // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; // and so // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE // no need to use loop const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); } } } // sync_width should >= WARP_SIZE template __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width) { const idxT stride = blockDim.x * gridDim.x; const idxT tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= sizeof(WideT)) { for (idxT i = tid; i < len; i += stride) { f(in[i], i, true); } } else { static_assert(sizeof(WideT) % sizeof(T) == 0); constexpr int items_per_scalar = sizeof(WideT) / sizeof(T); union { WideT scalar; T array[items_per_scalar]; } wide; int skip_cnt = (reinterpret_cast(in) % sizeof(WideT)) ? ((sizeof(WideT) - reinterpret_cast(in) % sizeof(WideT)) / sizeof(T)) : 0; if (skip_cnt > len) { skip_cnt = len; } WideT const* in_cast = reinterpret_cast(in + skip_cnt); const idxT len_cast = (len - skip_cnt) / items_per_scalar; const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width; for (idxT i = tid; i < len_cast_for_sync; i += stride) { bool valid = i < len_cast; if (valid) { wide.scalar = in_cast[i]; } const idxT real_i = skip_cnt + i * items_per_scalar; #pragma unroll for (int j = 0; j < items_per_scalar; ++j) { f(wide.array[j], real_i + j, valid); } } static_assert(WARP_SIZE >= items_per_scalar); // need at most one warp for skipped and remained elements, // and sync_width >= WARP_SIZE if (tid < sync_width) { bool valid = tid < skip_cnt; T value = valid ? in[tid] : T(); f(value, tid, valid); const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid; valid = remain_i < len; value = valid ? in[remain_i] : T(); f(value, remain_i, valid); } } } template struct alignas(128) Counter { // We are processing the values in multiple passes, from most significant to least // significant. In each pass, we keep the length of input (`len`) and the `k` of // current pass, and update them at the end of the pass. IdxT k; IdxT len; // `previous_len` is the length of input in previous pass. Note that `previous_len` // rather than `len` is used for the filtering step because filtering is indeed for // previous pass (see comments before `radix_kernel`). IdxT previous_len; // We determine the bits of the k_th value inside the mask processed by the pass. The // already known bits are stored in `kth_value_bits`. It's used to discriminate a // element is a result (written to `out`), a candidate for next pass (written to // `out_buf`), or not useful (discarded). The bits that are not yet processed do not // matter for this purpose. 
typename cub::Traits::UnsignedBits kth_value_bits; // Record how many elements have passed filtering. It's used to determine the position // in the `out_buf` where an element should be written. alignas(128) IdxT filter_cnt; // For a row inside a batch, we may launch multiple thread blocks. This counter is // used to determine if the current block is the last running block. If so, this block // will execute scan() and choose_bucket(). alignas(128) unsigned int finished_block_cnt; // Record how many elements have been written to the front of `out`. Elements less (if // select_min==true) than the k-th value are written from front to back. alignas(128) IdxT out_cnt; // Record how many elements have been written to the back of `out`. Elements equal to // the k-th value are written from back to front. We need to keep count of them // separately because the number of elements that <= the k-th value might exceed k. alignas(128) IdxT out_back_cnt; }; /** * Fused filtering of the current pass and building histogram for the next pass * (see steps 4 & 1 in `radix_kernel` description). */ template __device__ void filter_and_histogram(T const* in_buf, IdxT const* in_idx_buf, T* out_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, IdxT previous_len, Counter* counter, IdxT* histogram, bool select_min, int pass, bool early_stop) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; for (IdxT i = threadIdx.x; i < num_buckets; i += blockDim.x) { histogram_smem[i] = 0; } __syncthreads(); int const start_bit = calc_start_bit(pass); unsigned const mask = calc_mask(pass); if (pass == 0) { // Passed to vectorized_process, this function executes in all blocks in parallel, // i.e. the work is split along the input (both, in batches and chunks of a single // row). Later, the histograms are merged using atomicAdd. auto f = [select_min, start_bit, mask](T value, IdxT) { int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); }; vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); } else { IdxT* p_filter_cnt = &counter->filter_cnt; IdxT* p_out_cnt = &counter->out_cnt; auto const kth_value_bits = counter->kth_value_bits; int const previous_start_bit = calc_start_bit(pass - 1); // See the remark above on the distributed execution of `f` using // vectorized_process. auto f = [in_idx_buf, out_buf, out_idx_buf, out, out_idx, select_min, start_bit, mask, previous_start_bit, kth_value_bits, p_filter_cnt, p_out_cnt, early_stop](T value, IdxT i) { const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { if (early_stop) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else { if (out_buf) { IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram_smem + bucket, static_cast(1)); } } // the condition `(out_buf || early_stop)` is a little tricky: // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip // writing to `out` too. So we won't write the same value to `out` multiple // times in different passes. And if we keep skipping the writing, values will // be written in `last_filter_kernel()` at last. 
But when `early_stop` is // true, we need to write to `out` since it's the last chance. else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } }; vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); } if (early_stop) { return; } __syncthreads(); // merge histograms produced by individual blocks for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { if (histogram_smem[i] != 0) { atomicAdd(histogram + i, histogram_smem[i]); } } } /** * Replace histogram with its own prefix sum * (step 2 in `radix_kernel` description) */ template __device__ void scan(IdxT volatile* histogram) { constexpr int num_buckets = calc_num_buckets(); if constexpr (num_buckets >= BlockSize) { static_assert(num_buckets % BlockSize == 0); constexpr int items_per_thread = num_buckets / BlockSize; typedef cub::BlockLoad BlockLoad; typedef cub::BlockStore BlockStore; typedef cub::BlockScan BlockScan; __shared__ union { typename BlockLoad::TempStorage load; typename BlockScan::TempStorage scan; typename BlockStore::TempStorage store; } temp_storage; IdxT thread_data[items_per_thread]; BlockLoad(temp_storage.load).Load(histogram, thread_data); __syncthreads(); BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); __syncthreads(); BlockStore(temp_storage.store).Store(histogram, thread_data); } else { typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; IdxT thread_data = 0; if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); __syncthreads(); if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } } } /** * Calculate in which bucket the k-th value will fall * (steps 3 in `radix_kernel` description) */ template __device__ void choose_bucket(Counter* counter, IdxT const* histogram, const IdxT k, int const pass) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { IdxT prev = (i == 0) ? 0 : histogram[i - 1]; IdxT cur = histogram[i]; // one and only one thread will satisfy this condition, so counter is written by // only one thread if (prev < k && cur >= k) { counter->k = k - prev; // how many values still are there to find counter->len = cur - prev; // number of values in next pass typename cub::Traits::UnsignedBits bucket = i; int start_bit = calc_start_bit(pass); counter->kth_value_bits |= bucket << start_bit; } } } // For one-block version, last_filter() could be called when pass < num_passes - 1. 
// So `pass` could not be constexpr template __device__ void last_filter(T const* in_buf, IdxT const* in_idx_buf, T* out, IdxT* out_idx, IdxT current_len, IdxT k, Counter* counter, bool const select_min, int const pass) { auto const kth_value_bits = counter->kth_value_bits; int const start_bit = calc_start_bit(pass); // changed in choose_bucket(); need to reload const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; IdxT* p_equal = out_idx + k - num_of_kth_needed; cuda::atomic_ref ref_last(p_equal[num_of_kth_needed - 1]); for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { const T value = in_buf[i]; auto const bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; // For one-block version, `in_idx_buf` could be nullptr at pass 0. // For non one-block version, if writing has been skipped, `in_idx_buf` could // be nullptr if `in_buf` is `in` out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT new_idx = in_idx_buf ? in_idx_buf[i] : i; IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < num_of_kth_needed) { IdxT pos = k - 1 - back_pos; out[pos] = value; if constexpr (!prioritize_smaller_indice) { out_idx[pos] = new_idx; } } if constexpr (prioritize_smaller_indice) { if (new_idx < ref_last.load(cuda::memory_order_relaxed)) { for (int j = 0; j < num_of_kth_needed; j++) { IdxT pre_idx = atomicMin(&p_equal[j], new_idx); if (pre_idx > new_idx) { new_idx = pre_idx; } } } } } } } template __global__ void last_filter_kernel(T const* in, IdxT const* in_idx, T const* in_buf, IdxT const* in_idx_buf, T* out, IdxT* out_idx, IdxT len, IdxT k, Counter* counters, bool const select_min) { const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; if (previous_len == 0) { return; } const IdxT buf_len = calc_buf_len(len); if (previous_len > buf_len || in_buf == in) { in_buf = in + batch_id * len; in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; previous_len = len; } else { in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } out += batch_id * k; out_idx += batch_id * k; constexpr int pass = calc_num_passes() - 1; constexpr int start_bit = calc_start_bit(pass); auto const kth_value_bits = counter->kth_value_bits; const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; IdxT* p_out_back_cnt = &counter->out_back_cnt; IdxT* p_equal = out_idx + k - num_of_kth_needed; cuda::atomic_ref ref_last(p_equal[num_of_kth_needed - 1]); auto f = [k, select_min, kth_value_bits, num_of_kth_needed, p_out_cnt, p_out_back_cnt, in_idx_buf, out, out_idx, p_equal, ref_last](T value, IdxT i) { const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; if (bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } else if (bits == kth_value_bits) { IdxT new_idx = in_idx_buf ? 
in_idx_buf[i] : i; IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); if (back_pos < num_of_kth_needed) { IdxT pos = k - 1 - back_pos; out[pos] = value; if constexpr (!prioritize_smaller_indice) { out_idx[pos] = new_idx; } } if constexpr (prioritize_smaller_indice) { if (new_idx < ref_last.load(cuda::memory_order_relaxed)) { for (int j = 0; j < num_of_kth_needed; j++) { IdxT pre_idx = atomicMin(&p_equal[j], new_idx); if (pre_idx > new_idx) { new_idx = pre_idx; } } } } } }; vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, static_cast(blockDim.x) * gridDim.x, in_buf, previous_len, f); } /** * * It is expected to call this kernel multiple times (passes), in each pass we process a * radix, going from the most significant towards the least significant bits (MSD). * * Conceptually, each pass consists of 4 steps: * * 1. Calculate histogram * First, transform bits into a digit, the value of which is in the range * [0, 2^{BITS_PER_PASS}-1]. Then count the frequency of each digit value and the * result is a histogram. That is, histogram[i] contains the count of inputs having value * i. * * 2. Scan the histogram * Inclusive prefix sum is computed for the histogram. After this step, histogram[i] * contains the count of inputs having value <= i. * * 3. Find the bucket j of the histogram that the k-th value falls into * * 4. Filtering * Input elements whose digit value __global__ void radix_kernel(T const* in, IdxT const* in_idx, T const* in_buf, IdxT const* in_idx_buf, T* out_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, Counter* counters, IdxT* histograms, const IdxT len, const IdxT k, bool const select_min, int const pass) { const size_t batch_id = blockIdx.y; auto counter = counters + batch_id; IdxT current_k; IdxT previous_len; IdxT current_len; if (pass == 0) { current_k = k; previous_len = len; // Need to do this so setting counter->previous_len for the next pass is correct. // This value is meaningless for pass 0, but it's fine because pass 0 won't be the // last pass in this implementation so pass 0 won't hit the "if (pass == // num_passes - 1)" branch. // Maybe it's better to reload counter->previous_len and use it rather than // current_len in last_filter() current_len = len; } else { current_k = counter->k; current_len = counter->len; previous_len = counter->previous_len; } if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() // should handle correctly the case that pass=0 and early_stop=true. However, this // special case of k=len is handled in other way in select_k() so such case is not // possible here. bool const early_stop = (current_len == current_k); const IdxT buf_len = calc_buf_len(len); // "previous_len > buf_len" means previous pass skips writing buffer if (pass == 0 || pass == 1 || previous_len > buf_len) { in_buf = in + batch_id * len; in_idx_buf = in_idx ? 
(in_idx + batch_id * len) : nullptr; previous_len = len; } else { in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } // "current_len > buf_len" means current pass will skip writing buffer if (pass == 0 || current_len > buf_len) { out_buf = nullptr; out_idx_buf = nullptr; } else { out_buf += batch_id * buf_len; out_idx_buf += batch_id * buf_len; } out += batch_id * k; out_idx += batch_id * k; constexpr int num_buckets = calc_num_buckets(); auto histogram = histograms + batch_id * num_buckets; filter_and_histogram(in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, previous_len, counter, histogram, select_min, pass, early_stop); __threadfence(); bool isLastBlock = false; if (threadIdx.x == 0) { unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { // `last_filter_kernel()` requires setting previous_len counter->previous_len = 0; counter->len = 0; } return; } scan(histogram); __syncthreads(); choose_bucket(counter, histogram, current_k, pass); __syncthreads(); constexpr int num_passes = calc_num_passes(); // reset for next pass if (pass != num_passes - 1) { for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { histogram[i] = 0; } } if (threadIdx.x == 0) { // `last_filter_kernel()` requires setting previous_len even in the last pass counter->previous_len = current_len; // not necessary for the last pass, but put it here anyway counter->filter_cnt = 0; } // if constexpr (fused_last_filter) { // if (pass == num_passes - 1) { // last_filter(out_buf ? out_buf : in_buf, // out_idx_buf ? out_idx_buf : // in_idx_buf, out, out_idx, out_buf ? // current_len : len, k, counter, // select_min, // pass); // } // } if (pass == num_passes - 1) { const volatile IdxT num_of_kth_needed = counter->k; for (IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) { out_idx[k - num_of_kth_needed + i] = cuda::std::numeric_limits::max(); } __syncthreads(); if constexpr (fused_last_filter) { last_filter(out_buf ? out_buf : in_buf, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, out_buf ? current_len : len, k, counter, select_min, pass); } } } } template unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) { static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); int active_blocks; cudaOccupancyMaxActiveBlocksPerMultiprocessor( &active_blocks, radix_kernel, BlockSize, 0); active_blocks *= sm_cnt; IdxT best_num_blocks = 0; float best_tail_wave_penalty = 1.0f; const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); for (int num_waves = 1;; ++num_waves) { IdxT num_blocks = std::min(max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); num_blocks = ceildiv(len, items_per_thread * BlockSize); float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; float tail_wave_penalty = (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); // 0.15 is determined experimentally. It also ensures breaking the loop early, // e.g. 
when num_waves > 7, tail_wave_penalty will always <0.15 if (tail_wave_penalty < 0.15) { best_num_blocks = num_blocks; break; } else if (tail_wave_penalty < best_tail_wave_penalty) { best_num_blocks = num_blocks; best_tail_wave_penalty = tail_wave_penalty; } if (num_blocks == max_num_blocks) { break; } } return best_num_blocks; } template __host__ __device__ void set_buf_pointers(T const* in, IdxT const* in_idx, T* buf1, IdxT* idx_buf1, T* buf2, IdxT* idx_buf2, int pass, T const*& in_buf, IdxT const*& in_idx_buf, T*& out_buf, IdxT*& out_idx_buf) { if (pass == 0) { in_buf = in; in_idx_buf = nullptr; out_buf = nullptr; out_idx_buf = nullptr; } else if (pass == 1) { in_buf = in; in_idx_buf = in_idx; out_buf = buf1; out_idx_buf = idx_buf1; } else if (pass % 2 == 0) { in_buf = buf1; in_idx_buf = idx_buf1; out_buf = buf2; out_idx_buf = idx_buf2; } else { in_buf = buf2; in_idx_buf = idx_buf2; out_buf = buf1; out_idx_buf = idx_buf1; } } template __device__ void set_buf_pointers(T const* in, IdxT const* in_idx, char* bufs, IdxT buf_len, int pass, T const*& in_buf, IdxT const*& in_idx_buf, T*& out_buf, IdxT*& out_idx_buf) { // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 if (pass == 0) { in_buf = in; in_idx_buf = nullptr; out_buf = nullptr; out_idx_buf = nullptr; } else if (pass == 1) { in_buf = in; in_idx_buf = in_idx; out_buf = reinterpret_cast(bufs); out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); } else if (pass % 2 == 0) { in_buf = reinterpret_cast(bufs); in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); out_buf = const_cast(in_buf + buf_len); out_idx_buf = const_cast(in_idx_buf + buf_len); } else { out_buf = reinterpret_cast(bufs); out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); in_buf = out_buf + buf_len; in_idx_buf = out_idx_buf + buf_len; } } // The following a few functions are for the one-block version, which uses single thread // block for each row of a batch. 
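// The one-block path is taken by standalone_stable_radix_11bits() further below when a row is
// short enough (len <= block_dim * items_per_thread, i.e. 512 * 32 = 16384 elements) or when
// calc_grid_dim() decides that a single block per row suffices. In that case all radix passes
// run inside a single kernel launch and the histogram lives in shared memory.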
template __device__ void filter_and_histogram_for_one_block(T const* in_buf, IdxT const* in_idx_buf, T* out_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, const IdxT previous_len, Counter* counter, IdxT* histogram, bool select_min, int pass) { constexpr int num_buckets = calc_num_buckets(); for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { histogram[i] = 0; } IdxT* p_filter_cnt = &counter->filter_cnt; if (threadIdx.x == 0) { *p_filter_cnt = 0; } __syncthreads(); int const start_bit = calc_start_bit(pass); unsigned const mask = calc_mask(pass); if (pass == 0) { auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); }; vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); } else if (!out_buf) { // not use vectorized_process here because it increases #registers a lot auto const kth_value_bits = counter->kth_value_bits; int const previous_start_bit = calc_start_bit(pass - 1); for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { const T value = in_buf[i]; auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); } } } else { // not use vectorized_process here because it increases #registers a lot IdxT* p_out_cnt = &counter->out_cnt; auto const kth_value_bits = counter->kth_value_bits; int const previous_start_bit = calc_start_bit(pass - 1); for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { const T value = in_buf[i]; auto const previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) << previous_start_bit; if (previous_bits == kth_value_bits) { #if CUDART_VERSION < 12000 // Avoiding potential compiler bug in CUDA 11 volatile #endif IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); out_buf[pos] = value; out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; int bucket = calc_bucket(value, start_bit, mask, select_min); atomicAdd(histogram + bucket, static_cast(1)); } else if (previous_bits < kth_value_bits) { IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? 
in_idx_buf[i] : i; } } } } template __global__ void radix_topk_one_block_kernel(T const* in, IdxT const* in_idx, const IdxT len, const IdxT k, T* out, IdxT* out_idx, bool const select_min, char* bufs) { constexpr int num_buckets = calc_num_buckets(); __shared__ Counter counter; __shared__ IdxT histogram[num_buckets]; if (threadIdx.x == 0) { counter.k = k; counter.len = len; counter.previous_len = len; counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; } __syncthreads(); const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow in += batch_id * len; if (in_idx) { in_idx += batch_id * len; } out += batch_id * k; out_idx += batch_id * k; const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { T const* in_buf = nullptr; IdxT const* in_idx_buf = nullptr; T* out_buf = nullptr; IdxT* out_idx_buf = nullptr; set_buf_pointers(in, in_idx, bufs, buf_len, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); const IdxT current_len = counter.len; const IdxT current_k = counter.k; IdxT previous_len = counter.previous_len; if (previous_len > buf_len) { in_buf = in; in_idx_buf = in_idx; previous_len = len; } if (current_len > buf_len) { // so "out_buf==nullptr" denotes skipping writing buffer in current pass out_buf = nullptr; out_idx_buf = nullptr; } filter_and_histogram_for_one_block(in_buf, in_idx_buf, out_buf, out_idx_buf, out, out_idx, previous_len, &counter, histogram, select_min, pass); //@TODO CHECK UPDATE CODE __syncthreads(); scan(histogram); __syncthreads(); choose_bucket(&counter, histogram, current_k, pass); if (threadIdx.x == 0) { counter.previous_len = current_len; } __syncthreads(); if ((pass == num_passes - 1)) { if constexpr (prioritize_smaller_indice) { const IdxT num_of_kth_needed = counter.k; for (IdxT i = threadIdx.x; i < num_of_kth_needed; i += blockDim.x) { out_idx[k - num_of_kth_needed + i] = cuda::std::numeric_limits::max(); } __syncthreads(); } last_filter(out_buf ? out_buf : in, out_buf ? out_idx_buf : in_idx, out, out_idx, out_buf ? current_len : len, k, &counter, select_min, pass); break; } else if (counter.len == counter.k) { last_filter(out_buf ? out_buf : in, out_buf ? out_idx_buf : in_idx, out, out_idx, out_buf ? 
current_len : len, k, &counter, select_min, pass); break; } } } } // namespace air_topk_stable //} /***************Runtime API****************/ inline size_t calc_aligned_size(std::vector const& sizes) { const size_t ALIGN_BYTES = 256; const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); size_t total = 0; for (auto sz : sizes) { total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; } return total + ALIGN_BYTES - 1; } inline std::vector calc_aligned_pointers(void const* p, std::vector const& sizes) { const size_t ALIGN_BYTES = 256; const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); std::vector aligned_pointers; aligned_pointers.reserve(sizes.size()); for (auto sz : sizes) { aligned_pointers.push_back(ptr); ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; } return aligned_pointers; } template void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, IdxT const* in_idx, int batch_size, IdxT len, IdxT k, T* out, IdxT* out_idx, bool select_min, bool fused_last_filter, unsigned grid_dim, cudaStream_t stream, bool sorted = false) { static_assert(air_topk_stable::calc_num_passes() > 1); constexpr int num_buckets = air_topk_stable::calc_num_buckets(); air_topk_stable::Counter* counters = nullptr; IdxT* histograms = nullptr; T* buf1 = nullptr; IdxT* idx_buf1 = nullptr; T* buf2 = nullptr; IdxT* idx_buf2 = nullptr; void* sort_temp_storage = nullptr; size_t temp_storage_bytes = 0; size_t temp_storage_bytes_sort = 0; T* topk_out = nullptr; IdxT* topk_out_idx = nullptr; T* sort_in = nullptr; IdxT* sort_in_idx = nullptr; air_topk_stable::ComputeOffset computeoffset(k); cub::CountingInputIterator counting_iter(0); cub::TransformInputIterator, cub::CountingInputIterator> transform_iter(counting_iter, computeoffset); cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); if (sorted) { if (select_min) { cub::DeviceSegmentedSort::StableSortPairs(NULL, temp_storage_bytes_sort, out, out, out_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } else { cub::DeviceSegmentedSort::StableSortPairsDescending(NULL, temp_storage_bytes_sort, out, out, out_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } } temp_storage_bytes = max(temp_storage_bytes, temp_storage_bytes_sort); { IdxT len_candidates = air_topk_stable::calc_buf_len(len); size_t sort_buffer_size = 0; if (sorted) { sort_buffer_size = k * batch_size; } std::vector sizes = {sizeof(*counters) * batch_size, sizeof(*histograms) * num_buckets * batch_size, sizeof(*buf1) * len_candidates * batch_size, sizeof(*idx_buf1) * len_candidates * batch_size, sizeof(*buf2) * len_candidates * batch_size, sizeof(*idx_buf2) * len_candidates * batch_size, temp_storage_bytes, sizeof(*topk_out) * k * batch_size, sizeof(*topk_out_idx) * k * batch_size, sizeof(*sort_in) * sort_buffer_size, sizeof(*sort_in_idx) * sort_buffer_size}; size_t total_size = calc_aligned_size(sizes); if (!buf) { buf_size = total_size; return; } std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); counters = static_cast(aligned_pointers[0]); histograms = static_cast(aligned_pointers[1]); buf1 = static_cast(aligned_pointers[2]); idx_buf1 = static_cast(aligned_pointers[3]); buf2 = static_cast(aligned_pointers[4]); idx_buf2 = static_cast(aligned_pointers[5]); sort_temp_storage = aligned_pointers[6]; topk_out = static_cast(aligned_pointers[7]); 
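// topk_out / topk_out_idx receive the still-unsorted top-k values and indices produced by the
// radix passes; the segmented sorts further below re-order them by index (and by value when
// sorted == true) into out / out_idx.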
topk_out_idx = static_cast(aligned_pointers[8]); if (sorted) { sort_in = static_cast(aligned_pointers[9]); sort_in_idx = static_cast(aligned_pointers[10]); } cudaMemsetAsync( buf, 0, static_cast(aligned_pointers[2]) - static_cast(aligned_pointers[0]), stream); } T const* in_buf = nullptr; IdxT const* in_idx_buf = nullptr; T* out_buf = nullptr; IdxT* out_idx_buf = nullptr; dim3 blocks(grid_dim, batch_size); constexpr int num_passes = air_topk_stable::calc_num_passes(); auto kernel = air_topk_stable::radix_kernel; for (int pass = 0; pass < num_passes; ++pass) { air_topk_stable::set_buf_pointers( in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); if (fused_last_filter && pass == num_passes - 1) { kernel = air_topk_stable::radix_kernel; } kernel<<>>(in, in_idx, in_buf, in_idx_buf, out_buf, out_idx_buf, topk_out, topk_out_idx, counters, histograms, len, k, select_min, pass); } if (!fused_last_filter) { air_topk_stable::last_filter_kernel<<>>( in, in_idx, out_buf, out_idx_buf, topk_out, topk_out_idx, len, k, counters, select_min); } T* idx_sort_out = sorted ? sort_in : out; IdxT* idx_sort_out_idx = sorted ? sort_in_idx : out_idx; cub::DeviceSegmentedSort::SortPairs(sort_temp_storage, temp_storage_bytes, topk_out_idx, idx_sort_out_idx, topk_out, idx_sort_out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); if (sorted) { if (select_min) { cub::DeviceSegmentedSort::StableSortPairs(sort_temp_storage, temp_storage_bytes, sort_in, out, sort_in_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } else { cub::DeviceSegmentedSort::StableSortPairsDescending(sort_temp_storage, temp_storage_bytes, sort_in, out, sort_in_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } } } template void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T const* in, IdxT const* in_idx, int batch_size, IdxT len, IdxT k, T* out, IdxT* out_idx, bool select_min, cudaStream_t stream, bool sorted = false) { static_assert(air_topk_stable::calc_num_passes() > 1); char* bufs = nullptr; void* sort_temp_storage = nullptr; T* topk_out = nullptr; IdxT* topk_out_idx = nullptr; T* sort_in = nullptr; IdxT* sort_in_idx = nullptr; size_t temp_storage_bytes = 0; size_t temp_storage_bytes_sort = 0; const IdxT buf_len = air_topk_stable::calc_buf_len(len); air_topk_stable::ComputeOffset computeoffset(k); cub::CountingInputIterator counting_iter(0); cub::TransformInputIterator, cub::CountingInputIterator> transform_iter(counting_iter, computeoffset); cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); if (sorted) { if (select_min) { cub::DeviceSegmentedSort::StableSortPairs(NULL, temp_storage_bytes_sort, out, out, out_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } else { cub::DeviceSegmentedSort::StableSortPairsDescending(NULL, temp_storage_bytes_sort, out, out, out_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } } temp_storage_bytes = max(temp_storage_bytes, temp_storage_bytes_sort); { size_t total_size = 0; size_t sort_buffer_size = 0; if (sorted) { sort_buffer_size = k * batch_size; } std::vector sizes = {buf_len * 2 * (sizeof(T) + sizeof(IdxT)) * batch_size, temp_storage_bytes, sizeof(*topk_out) * k * batch_size, sizeof(*topk_out_idx) * k * batch_size, sizeof(*sort_in) * sort_buffer_size, 
sizeof(*sort_in_idx) * sort_buffer_size}; total_size = calc_aligned_size(sizes); if (!buf) { buf_size = total_size; return; } std::vector aligned_pointers = calc_aligned_pointers(buf, sizes); bufs = static_cast(aligned_pointers[0]); sort_temp_storage = aligned_pointers[1]; topk_out = static_cast(aligned_pointers[2]); topk_out_idx = static_cast(aligned_pointers[3]); if (sorted) { sort_in = static_cast(aligned_pointers[4]); sort_in_idx = static_cast(aligned_pointers[5]); } } air_topk_stable::radix_topk_one_block_kernel <<>>(in, in_idx, len, k, topk_out, topk_out_idx, select_min, bufs); T* idx_sort_out = sorted ? sort_in : out; IdxT* idx_sort_out_idx = sorted ? sort_in_idx : out_idx; cub::DeviceSegmentedSort::SortPairs(sort_temp_storage, temp_storage_bytes, topk_out_idx, idx_sort_out_idx, topk_out, idx_sort_out, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); if (sorted) { if (select_min) { cub::DeviceSegmentedSort::StableSortPairs(sort_temp_storage, temp_storage_bytes, sort_in, out, sort_in_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } else { cub::DeviceSegmentedSort::StableSortPairsDescending(sort_temp_storage, temp_storage_bytes, sort_in, out, sort_in_idx, out_idx, k * batch_size, batch_size, transform_iter, transform_iter + 1, stream); } } } template void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, idxT len, idxT k, T* out, idxT* out_idx, bool greater, cudaStream_t stream = 0) { constexpr int items_per_thread = 32; constexpr int block_dim = 512; constexpr bool fused_last_filter = false; if (len <= block_dim * items_per_thread) { standalone_stable_radix_topk_one_block_( buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); } else { int sm_cnt = tensorrt_llm::common::getMultiProcessorCount(); unsigned grid_dim = air_topk_stable::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { standalone_stable_radix_topk_one_block_(buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted); } else { standalone_stable_radix_topk_(buf, buf_size, in, static_cast(nullptr), batch_size, len, k, out, out_idx, !greater, fused_last_filter, grid_dim, stream, sorted); } } } #endif /////////////// template size_t invokeComputeTopkLastDimWorkspaceSize( SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest) { size_t buf_size = 0; void* workspace = nullptr; T const* in = nullptr; T* out_val = nullptr; SizeType32* out_idx = nullptr; standalone_stable_radix_11bits( workspace, buf_size, in, batchSize, inputLength, k, out_val, out_idx, is_largest, 0); return buf_size; } #define INSTANTIATE_COMPUTE_TOPK_LastDim_WORKSPACE_SIZE_DATA_TYPE(T) \ template size_t invokeComputeTopkLastDimWorkspaceSize( \ SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest) INSTANTIATE_COMPUTE_TOPK_LastDim_WORKSPACE_SIZE_DATA_TYPE(int); INSTANTIATE_COMPUTE_TOPK_LastDim_WORKSPACE_SIZE_DATA_TYPE(float); INSTANTIATE_COMPUTE_TOPK_LastDim_WORKSPACE_SIZE_DATA_TYPE(half); #ifdef ENABLE_BF16 INSTANTIATE_COMPUTE_TOPK_LastDim_WORKSPACE_SIZE_DATA_TYPE(__nv_bfloat16); #endif #undef INSTANTIATE_COMPUTE_TOPK_LastDim_WORKSPACE_SIZE_DATA_TYPE // Might need FP8 in the future. 
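// Typical call sequence for the workspace-size query above and invokeTopkLastDim() below
// (illustrative sketch only; the device buffer names and the cudaMallocAsync-based allocation
// are placeholders, not part of this file):
//
//   size_t wsBytes = invokeComputeTopkLastDimWorkspaceSize<float>(batchSize, inputLength, k, /*is_largest=*/true);
//   void* workspace = nullptr;
//   cudaMallocAsync(&workspace, wsBytes, stream);
//   invokeTopkLastDim<float>(batchSize, inputLength, k, /*is_largest=*/true,
//       dLogits, dTopkValues, dTopkIndices /* SizeType32 storage */, workspace, stream);
//   cudaFreeAsync(workspace, stream);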
/////////////// template void invokeTopkLastDim(SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest, void const* __restrict__ input, void* __restrict__ out_val, void* __restrict__ out_idx, void* workspace, cudaStream_t stream) { size_t buf_size = 0; // will be overwritten by the kernel T const* in = reinterpret_cast(input); T* out_val_ = reinterpret_cast(out_val); SizeType32* out_idx_ = reinterpret_cast(out_idx); standalone_stable_radix_11bits( workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream); } #define INSTANTIATE_TOPK_LastDim_DATA_TYPE(T) \ template void invokeTopkLastDim(SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest, \ void const* __restrict__ input, void* __restrict__ out_val, void* __restrict__ out_idx, void* workspace, \ cudaStream_t stream) INSTANTIATE_TOPK_LastDim_DATA_TYPE(int); INSTANTIATE_TOPK_LastDim_DATA_TYPE(float); INSTANTIATE_TOPK_LastDim_DATA_TYPE(half); #ifdef ENABLE_BF16 INSTANTIATE_TOPK_LastDim_DATA_TYPE(__nv_bfloat16); #endif #undef INSTANTIATE_TOPK_LastDim_DATA_TYPE } // namespace kernels } // namespace tensorrt_llm
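// Worked example of the pass schedule (assuming BitsPerPass == 11, as the name
// standalone_stable_radix_11bits suggests) for 32-bit keys such as float or int:
//   calc_num_passes<T, 11>() == ceildiv(32, 11) == 3
//   pass 0: start_bit = 32 - 1 * 11 = 21, mask = 0x7ff  (the 11 most significant bits)
//   pass 1: start_bit = 32 - 2 * 11 = 10, mask = 0x7ff  (bits 20..10)
//   pass 2: start_bit = 0 (clamped),      mask = 0x3ff  (the remaining 10 bits)
// For 16-bit keys (half, __nv_bfloat16) there are ceildiv(16, 11) == 2 passes, with start
// bits 5 and 0 and masks 0x7ff and 0x1f.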