mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[nvbugs/5354884][fix] Update beam search workspace estimation to new upper bound (#5926)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
This commit is contained in:
parent
6d7874a467
commit
d475c97c82
@@ -1459,13 +1459,23 @@ template <typename T>
|
||||
size_t invokeComputeTopkLastDimWorkspaceSize(
|
||||
SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest)
|
||||
{
|
||||
using idxT = SizeType32;
|
||||
|
||||
size_t buf_size = 0;
|
||||
void* workspace = nullptr;
|
||||
T const* in = nullptr;
|
||||
T* out_val = nullptr;
|
||||
SizeType32* out_idx = nullptr;
|
||||
standalone_stable_radix_11bits<T, SizeType32, true>(
|
||||
workspace, buf_size, in, batchSize, inputLength, k, out_val, out_idx, is_largest, 0);
|
||||
idxT* out_idx = nullptr;
|
||||
|
||||
constexpr int block_dim = 512;
|
||||
constexpr bool fused_last_filter = false;
|
||||
constexpr bool sorted = true;
|
||||
|
||||
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
|
||||
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
|
||||
|
||||
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(workspace, buf_size, in, static_cast<idxT*>(nullptr),
|
||||
batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted);
|
||||
return buf_size;
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user