[nvbugs/5354884][fix] Update beam search workspace estimation to new upper bound (#5926)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
This commit is contained in:
Stefan Niebler 2025-07-18 19:54:51 +02:00 committed by GitHub
parent 6d7874a467
commit d475c97c82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1459,13 +1459,23 @@ template <typename T>
size_t invokeComputeTopkLastDimWorkspaceSize(
SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest)
{
using idxT = SizeType32;
size_t buf_size = 0;
void* workspace = nullptr;
T const* in = nullptr;
T* out_val = nullptr;
SizeType32* out_idx = nullptr;
standalone_stable_radix_11bits<T, SizeType32, true>(
workspace, buf_size, in, batchSize, inputLength, k, out_val, out_idx, is_largest, 0);
idxT* out_idx = nullptr;
constexpr int block_dim = 512;
constexpr bool fused_last_filter = false;
constexpr bool sorted = true;
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(workspace, buf_size, in, static_cast<idxT*>(nullptr),
batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted);
return buf_size;
}