TensorRT-LLMs/cpp/tensorrt_llm/runtime/runtimeKernels.cu
Robin Kobus 6d4b045d1f
refactor: Remove enforced sorted order of batch slots (#3502)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-07-14 17:23:02 +02:00

486 lines
22 KiB
Plaintext

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/kernels/speculativeDecoding/kvCacheUpdateKernels.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <NvInferRuntimeBase.h>
#include <cuda_runtime.h>
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm::runtime::kernels
{
namespace
{
//! Writes `value` into every element of `data[0, size)`.
//!
//! Each thread begins at its global x index and advances by the total number of x threads in the grid, so any 1-D
//! launch configuration covers the whole range.
template <typename T>
__global__ void fill(T* data, std::size_t size, T const value)
{
    std::size_t const threadsInGrid = static_cast<std::size_t>(gridDim.x) * blockDim.x;
    std::size_t pos = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    while (pos < size)
    {
        data[pos] = value;
        pos += threadsInGrid;
    }
}
//! Fills selected rows of `data` with one value per row.
//!
//! blockIdx.y selects the work item: row indices[blockIdx.y] of `data` is set to values[blockIdx.y]. Threads along x
//! stride over the `size` elements of that row.
//! @param data expected shape [indicesRange, size]
//! @param indices expected shape [gridDim.y]
//! @param size number of elements per row
//! @param values expected shape [gridDim.y]
template <typename T>
__global__ void fillBatch(T* data, std::int32_t const* indices, std::size_t size, T const* values)
{
    auto const row = indices[blockIdx.y];
    T const fillValue = values[blockIdx.y];
    auto const rowBegin = row * size;
    auto const rowEnd = rowBegin + size;
    auto const step = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    for (auto pos = rowBegin + static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < rowEnd;
         pos += step)
    {
        data[pos] = fillValue;
    }
}
//! Host launcher for fillBatch.
//!
//! For every entry i of `slotIndices`, sets row slotIndices[i] of `buffer` (rows are `slotStride` elements long) to
//! values[i]. One y-block is launched per slot index (gridDim.y is limited to 65535 on current GPUs).
//! Empty inputs are a no-op: a zero-sized grid dimension would be an invalid launch configuration.
template <typename T>
void invokeFillBatch(IBuffer& buffer, IBuffer const& slotIndices, std::size_t slotStride, IBuffer const& values,
    CudaStream const& stream)
{
    // Cast first so data-type mismatches are still reported for empty inputs.
    auto* const data = bufferCast<T>(buffer);
    auto const* const indices = bufferCast<std::int32_t>(slotIndices);
    auto const* const fillValues = bufferCast<T>(values);
    auto const numSlots = slotIndices.getSize();
    auto const size = slotStride;
    if (numSlots == 0 || size == 0)
    {
        return;
    }
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    // gridDim.x is limited to 2^31 - 1 blocks; the kernel strides over the row, so capping the grid is safe.
    std::size_t const gridMax{static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), static_cast<std::uint32_t>(numSlots)};
    fillBatch<<<gridSize, blockSize, 0, stream.get()>>>(data, indices, size, fillValues);
}
//! Gathers rows of `values` selected by `indices` into consecutive rows of `data`.
//!
//! blockIdx.y selects the output row; threads along x stride over the `size` elements of that row.
//! Row blockIdx.y of `data` receives row indices[blockIdx.y] of `values`.
//! @param data expected shape [gridDim.y, size]
//! @param values expected shape [indicesRange, size]
//! @param indices expected shape [gridDim.y]
//! @param size number of elements per row
template <typename T>
__global__ void gatherBatch(T* data, T const* values, std::int32_t const* indices, std::size_t size)
{
    auto const tidx = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    auto const stride = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    // Row indices must be scaled by the row length; without it neighbouring rows overlapped element-by-element.
    auto const dstRowStart = static_cast<std::size_t>(blockIdx.y) * size;
    auto const srcRowStart = static_cast<std::size_t>(indices[blockIdx.y]) * size;
    for (auto idx = tidx; idx < size; idx += stride)
    {
        data[dstRowStart + idx] = values[srcRowStart + idx];
    }
}
//! Host launcher for gatherBatch.
//!
//! For every entry i of `slotIndices`, copies row slotIndices[i] of `values` into row i of `buffer`; rows are
//! `slotStride` elements long. One y-block is launched per slot index (gridDim.y is limited to 65535 on current GPUs).
//! Empty inputs are a no-op: a zero-sized grid dimension would be an invalid launch configuration.
template <typename T>
void invokeGatherBatch(IBuffer& buffer, IBuffer const& values, IBuffer const& slotIndices, std::size_t slotStride,
    CudaStream const& stream)
{
    // Cast first so data-type mismatches are still reported for empty inputs.
    auto* const data = bufferCast<T>(buffer);
    auto const* const indices = bufferCast<std::int32_t>(slotIndices);
    auto const* const sparseValues = bufferCast<T>(values);
    auto const numSlots = slotIndices.getSize();
    auto const size = slotStride;
    if (numSlots == 0 || size == 0)
    {
        return;
    }
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    // gridDim.x is limited to 2^31 - 1 blocks; the kernel strides over the row, so capping the grid is safe.
    std::size_t const gridMax{static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), static_cast<std::uint32_t>(numSlots)};
    gatherBatch<<<gridSize, blockSize, 0, stream.get()>>>(data, sparseValues, indices, size);
}
//! Copies one byte range per blockIdx.y from `srcData` to `dstData` using `VecT`-wide loads and stores.
//!
//! For r = blockIdx.y, the range starts at srcOffsets[r] / dstOffsets[r] elements and spans sizes[r] elements,
//! where one element is `dataTypeSize` bytes. Threads along x stride over the range in sizeof(VecT)-byte steps.
//! Only whole VecT accesses are issued, so sizeof(VecT) is expected to divide dataTypeSize and the base pointers
//! must be aligned to sizeof(VecT) (invokeCopyBatch picks VecT from dataTypeSize).
template <typename VecT>
__global__ void copyBatch(uint8_t const* srcData, uint8_t* dstData, SizeType64 const* srcOffsets,
    SizeType64 const* dstOffsets, SizeType64 const* sizes, SizeType64 const dataTypeSize)
{
    constexpr auto VEC_ELTS = static_cast<int32_t>(sizeof(VecT));
    auto const row = blockIdx.y;
    SizeType64 const srcBegin = srcOffsets[row] * dataTypeSize;
    SizeType64 const dstBegin = dstOffsets[row] * dataTypeSize;
    SizeType64 const srcEnd = srcBegin + sizes[row] * dataTypeSize;
    SizeType64 const firstByte = (static_cast<SizeType64>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_ELTS;
    SizeType64 const byteStep = static_cast<SizeType64>(blockDim.x) * gridDim.x * VEC_ELTS;
    for (SizeType64 offset = firstByte; srcBegin + offset < srcEnd; offset += byteStep)
    {
        auto const* const src = reinterpret_cast<VecT const*>(srcData + srcBegin + offset);
        auto* const dst = reinterpret_cast<VecT*>(dstData + dstBegin + offset);
        *dst = *src;
    }
}
//! Copies each input row into beam 0 of the corresponding output beam group.
//!
//! Input row b (`inputRowSize` elements) is written to output row b * beamWidth, where output rows are
//! `outputRowSize` elements long. Only the first `inputRowSize` elements of that row are written, and rows of the
//! other beams are not touched. Threads stride over rows along y and over columns along x.
template <typename T>
__global__ void scatterTensor(T* output, T const* input, std::uint32_t const batchSize,
    std::uint32_t const inputRowSize, std::size_t const outputRowSize, std::uint32_t const beamWidth)
{
    auto const colStart = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    auto const rowStart = static_cast<std::size_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    auto const colStep = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    auto const rowStep = static_cast<std::size_t>(blockDim.y) * gridDim.y;
    for (auto row = rowStart; row < batchSize; row += rowStep)
    {
        T const* const srcRow = input + row * inputRowSize;
        // Beam 0 of the beam group that belongs to this input row.
        T* const dstRow = output + row * beamWidth * outputRowSize;
        for (auto col = colStart; col < inputRowSize; col += colStep)
        {
            dstRow[col] = srcRow[col];
        }
    }
}
//! Replicates each input row into every beam of the corresponding output beam group.
//!
//! Input row b (`inputRowSize` elements) is written to output rows b * beamWidth + k for k in [0, beamWidth), where
//! output rows are `outputRowSize` elements long. Only the first `inputRowSize` elements of each output row are
//! written. Threads stride over rows along y and over columns along x.
template <typename T>
__global__ void tileTensor(T* output, T const* input, std::uint32_t const batchSize, std::size_t const inputRowSize,
    std::size_t const outputRowSize, std::uint32_t const beamWidth)
{
    auto const colStart = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    auto const rowStart = static_cast<std::size_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    auto const colStep = static_cast<std::size_t>(blockDim.x) * gridDim.x;
    auto const rowStep = static_cast<std::size_t>(blockDim.y) * gridDim.y;
    for (auto row = rowStart; row < batchSize; row += rowStep)
    {
        for (auto col = colStart; col < inputRowSize; col += colStep)
        {
            T const element = input[row * inputRowSize + col];
            // Start at beam 0 of this row's group and hop one output row per beam.
            auto dstIdx = row * beamWidth * outputRowSize + col;
            for (std::uint32_t beam = 0; beam < beamWidth; ++beam, dstIdx += outputRowSize)
            {
                output[dstIdx] = element;
            }
        }
    }
}
//! Host launcher for scatterTensor: copies every input row into beam 0 of its beam group in `output`.
//!
//! Dimension 0 of each tensor is treated as its row count and the remaining elements as one row. Requires
//! rows(output) == beamWidth * rows(input) and rowSize(output) >= rowSize(input); otherwise throws.
//! Empty tensors are a no-op. The grid is capped at the CUDA limits (x: 2^31 - 1, y: 65535); the kernel strides over
//! rows (y) and columns (x), so a capped grid still covers the whole tensor.
template <typename T>
void invokeScatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream)
{
    auto* const outputPtr = bufferCast<T>(output);
    auto const* const inputPtr = bufferCast<T const>(input);
    auto const& inputShape = input.getShape();
    auto const nbInputRows = static_cast<std::uint32_t>(inputShape.d[0]);
    auto const& outputShape = output.getShape();
    auto const nbOutputRows = static_cast<std::uint32_t>(outputShape.d[0]);
    TLLM_CHECK_WITH_INFO(nbOutputRows == beamWidth * nbInputRows,
        common::fmtstr(
            "nbOutputRows (%u) must be beamWidth (%d) times nbInputRows (%u)", nbOutputRows, beamWidth, nbInputRows));
    if (nbInputRows == 0 || nbOutputRows == 0)
    {
        // Nothing to copy; also avoids dividing by zero when deriving the row sizes below.
        return;
    }
    auto const inputRowSize = input.getSize() / static_cast<std::size_t>(nbInputRows);
    auto const outputRowSize = output.getSize() / static_cast<std::size_t>(nbOutputRows);
    TLLM_CHECK_WITH_INFO(outputRowSize >= inputRowSize,
        common::fmtstr("output row size (%zu) must be at least input row size (%zu)", outputRowSize, inputRowSize));
    if (inputRowSize == 0)
    {
        // A zero-sized grid dimension would be an invalid launch configuration.
        return;
    }
    dim3 const blockSize{256, 1};
    std::size_t const gridx{tc::ceilDiv(inputRowSize, blockSize.x)};
    std::size_t const gridxMax{static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())};
    std::uint32_t const gridyMax{65535};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridxMax)), std::min(nbInputRows, gridyMax)};
    scatterTensor<<<gridSize, blockSize, 0, stream.get()>>>(
        outputPtr, inputPtr, nbInputRows, inputRowSize, outputRowSize, static_cast<uint32_t>(beamWidth));
}
//! Host launcher for tileTensor: replicates every input row into all `beamWidth` beams of `output`.
//!
//! Dimension 0 of each tensor is treated as its row count and the remaining elements as one row. Requires
//! rows(output) == beamWidth * rows(input) and rowSize(output) >= rowSize(input); otherwise throws.
//! Empty tensors are a no-op. The grid is capped at the CUDA limits (x: 2^31 - 1, y: 65535); the kernel strides over
//! rows (y) and columns (x), so a capped grid still covers the whole tensor.
template <typename T>
void invokeTileTensor(ITensor& output, ITensor const& input, SizeType32 const beamWidth, CudaStream const& stream)
{
    auto* const outputPtr = bufferCast<T>(output);
    auto const* const inputPtr = bufferCast<T const>(input);
    auto const& inputShape = input.getShape();
    auto const nbInputRows = static_cast<std::uint32_t>(inputShape.d[0]);
    auto const& outputShape = output.getShape();
    auto const nbOutputRows = static_cast<std::uint32_t>(outputShape.d[0]);
    TLLM_CHECK_WITH_INFO(nbOutputRows == beamWidth * nbInputRows,
        common::fmtstr(
            "nbOutputRows (%u) must be beamWidth (%d) times nbInputRows (%u)", nbOutputRows, beamWidth, nbInputRows));
    if (nbInputRows == 0 || nbOutputRows == 0)
    {
        // Nothing to copy; also avoids dividing by zero when deriving the row sizes below.
        return;
    }
    auto const inputRowSize = input.getSize() / static_cast<std::size_t>(nbInputRows);
    auto const outputRowSize = output.getSize() / static_cast<std::size_t>(nbOutputRows);
    TLLM_CHECK_WITH_INFO(outputRowSize >= inputRowSize,
        common::fmtstr("output row size (%zu) must be at least input row size (%zu)", outputRowSize, inputRowSize));
    if (inputRowSize == 0)
    {
        // A zero-sized grid dimension would be an invalid launch configuration.
        return;
    }
    dim3 const blockSize{256, 1};
    std::size_t const gridx{tc::ceilDiv(inputRowSize, blockSize.x)};
    std::size_t const gridxMax{static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())};
    std::uint32_t const gridyMax{65535};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridxMax)), std::min(nbInputRows, gridyMax)};
    tileTensor<<<gridSize, blockSize, 0, stream.get()>>>(
        outputPtr, inputPtr, nbInputRows, inputRowSize, outputRowSize, static_cast<uint32_t>(beamWidth));
}
// Merges per-step logits fragments into the generation logits tensor.
//
// Launch layout: grid = (microBatchSize * beamWidth, number of fragments), 1-D blocks of any size.
// blockIdx.x selects one (batch slot, beam) pair of the micro batch; blockIdx.y selects the fragment (step).
// Each thread block copies one `vocabSizePadded`-long logits row from
//   fragmentsVector[step] : [microBatchSize, beamWidth, vocabSizePadded]
// into
//   output                : [batchSize, beamWidth, outputLen, vocabSizePadded]
// at batch slot (firstBatchSlotIdx + relative slot) and step (step + stepOffset).
//
// Offsets are computed in 64-bit arithmetic. The flattened output index easily exceeds 2^32 elements
// (e.g. 64 slots x 2048 steps x 128000 vocab), and the former 32-bit int computation overflowed.
template <typename T>
__global__ void mergeLogitsFragmentsKernel(T* output, T** fragmentsVector, int const outputLen, int firstBatchSlotIdx,
    int beamWidth, int vocabSizePadded, int stepOffset)
{
    // Current step (index of the fragment being copied).
    int const curStep = blockIdx.y;
    // Batch slot index relative to the start of the micro batch.
    int const relativeBatchSlotIdx = blockIdx.x / beamWidth;
    // Absolute batch slot index in the output batch.
    int const absoluteBatchSlotIdx = firstBatchSlotIdx + relativeBatchSlotIdx;
    // Beam handled by this thread block.
    int const beamIdx = blockIdx.x % beamWidth;

    // Flattened (slot, beam, step) row of the output and (slot, beam) row of the input fragment.
    std::int64_t const outputRow
        = (static_cast<std::int64_t>(absoluteBatchSlotIdx) * beamWidth + beamIdx) * outputLen + curStep + stepOffset;
    std::int64_t const inputRow = static_cast<std::int64_t>(relativeBatchSlotIdx) * beamWidth + beamIdx;

    T* outputPtr = output + outputRow * vocabSizePadded;
    T const* inputPtr = fragmentsVector[curStep] + inputRow * vocabSizePadded;

    // The threads in the block collaborate to copy the logits row.
    for (int idx = threadIdx.x; idx < vocabSizePadded; idx += blockDim.x)
    {
        outputPtr[idx] = inputPtr[idx];
    }
}
//! Host launcher for mergeLogitsFragmentsKernel.
//!
//! Stores the device pointer of every fragment in `cachePointerHost`, copies that table to `cachePointerDevice` via
//! `bufferManager`, then launches one thread block per (slot, beam) pair of the micro batch and per fragment.
//! The last two dimensions of `output` are read as (outputLen, vocabSizePadded).
//! @param fragmentsVector   one tensor per step; fragment i is written to output step (i + stepOffset)
//! @param cachePointerHost  host pointer table; must hold at least fragmentsVector.size() entries
//! @param firstBatchSlotIdx batch slot in `output` of the first request of the micro batch
//! Throws if the fragments do not fit into outputLen starting at stepOffset, or if the pointer table is too small.
template <typename T>
void invokeMergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
    std::vector<TensorPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
    SizeType32 firstBatchSlotIdx, SizeType32 microBatchSize, SizeType32 beamWidth, CudaStream const& stream,
    int stepOffset)
{
    std::size_t const fragmentsVectorSize = fragmentsVector.size();
    auto const& outputShape = output.getShape();
    auto const vocabSizePadded = static_cast<SizeType32>(outputShape.d[outputShape.nbDims - 1]);
    auto const outputLen = static_cast<SizeType32>(outputShape.d[outputShape.nbDims - 2]);

    // Validate before touching any buffer. The kernel writes steps [stepOffset, stepOffset + fragmentsVectorSize) of
    // every (slot, beam) row; anything past outputLen would spill into the next beam's or slot's logits.
    TLLM_CHECK_WITH_INFO(
        static_cast<std::int64_t>(stepOffset) + static_cast<std::int64_t>(fragmentsVectorSize) <= outputLen,
        "Fragments size does not match outputLen size");
    TLLM_CHECK_WITH_INFO(cachePointerHost.getSize() >= fragmentsVectorSize,
        "cachePointerHost is too small to hold one pointer per fragment");

    auto const numRows = static_cast<unsigned int>(microBatchSize * beamWidth);
    if (fragmentsVectorSize == 0 || numRows == 0)
    {
        // Zero-sized grid dimensions are an invalid launch configuration, and there is nothing to merge.
        return;
    }

    auto* const cachePointerHostPtr = bufferCast<T*>(cachePointerHost);
    for (std::size_t i = 0; i < fragmentsVectorSize; ++i)
    {
        cachePointerHostPtr[i] = bufferCast<T>(*fragmentsVector[i]);
    }
    bufferManager.copy(cachePointerHost, cachePointerDevice);

    dim3 const blockSize(256);
    dim3 const gridSize{numRows, static_cast<unsigned int>(fragmentsVectorSize)};
    mergeLogitsFragmentsKernel<T><<<gridSize, blockSize, 0, stream.get()>>>(bufferCast<T>(output),
        bufferCast<T*>(cachePointerDevice), outputLen, firstBatchSlotIdx, beamWidth, vocabSizePadded, stepOffset);
}
} // namespace
//! Sets every element of `buffer` to `value`, asynchronously on `stream`.
//!
//! Launches 256-thread blocks. The grid is capped at the gridDim.x limit (2^31 - 1 blocks); the kernel strides, so a
//! capped grid still covers the whole buffer. An empty buffer is a no-op (a zero-sized grid is an invalid launch).
template <typename T>
void invokeFill(IBuffer& buffer, T const value, CudaStream const& stream)
{
    // Cast first so data-type mismatches are still reported for empty buffers.
    auto* const data = bufferCast<T>(buffer);
    auto const size = buffer.getSize();
    if (size == 0)
    {
        return;
    }
    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(size, blockSize.x)};
    std::size_t const gridMax{static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax))};
    fill<<<gridSize, blockSize, 0, stream.get()>>>(data, size, value);
}
// template instantiation
template void invokeFill(IBuffer&, SizeType64, CudaStream const&);
template void invokeFill(IBuffer&, std::int32_t, CudaStream const&);
template void invokeFill(IBuffer&, std::int8_t, CudaStream const&);
template void invokeFill(IBuffer&, std::uint8_t, CudaStream const&);
template void invokeFill(IBuffer&, bool, CudaStream const&);
template void invokeFill(IBuffer&, half, CudaStream const&);
template void invokeFill(IBuffer&, float, CudaStream const&);
#ifdef ENABLE_BF16
template void invokeFill(IBuffer&, __nv_bfloat16, CudaStream const&);
#endif // ENABLE_BF16
//! Type-dispatching entry point for the batched fill: picks the element type from `buffer`'s data type.
//! Supported: INT32, INT8, FLOAT; any other data type throws.
void invokeFillBatch(IBuffer& buffer, IBuffer const& slotIndices, std::size_t slotStride, IBuffer const& values,
    CudaStream const& stream)
{
    auto const dtype = buffer.getDataType();
    if (dtype == nvinfer1::DataType::kINT32)
    {
        invokeFillBatch<std::int32_t>(buffer, slotIndices, slotStride, values, stream);
    }
    else if (dtype == nvinfer1::DataType::kINT8)
    {
        invokeFillBatch<std::int8_t>(buffer, slotIndices, slotStride, values, stream);
    }
    else if (dtype == nvinfer1::DataType::kFLOAT)
    {
        invokeFillBatch<float>(buffer, slotIndices, slotStride, values, stream);
    }
    else
    {
        TLLM_THROW("data type not supported");
    }
}
//! Type-dispatching entry point for the batched gather: picks the element type from `buffer`'s data type.
//! Supported: INT32, INT8, FLOAT; any other data type throws.
void invokeGatherBatch(IBuffer& buffer, IBuffer const& values, IBuffer const& slotIndices, std::size_t slotStride,
    CudaStream const& stream)
{
    auto const dtype = buffer.getDataType();
    if (dtype == nvinfer1::DataType::kINT32)
    {
        invokeGatherBatch<std::int32_t>(buffer, values, slotIndices, slotStride, stream);
    }
    else if (dtype == nvinfer1::DataType::kINT8)
    {
        invokeGatherBatch<std::int8_t>(buffer, values, slotIndices, slotStride, stream);
    }
    else if (dtype == nvinfer1::DataType::kFLOAT)
    {
        invokeGatherBatch<float>(buffer, values, slotIndices, slotStride, stream);
    }
    else
    {
        TLLM_THROW("data type not supported");
    }
}
//! Copies `srcOffsets.getSize()` independent element ranges from `srcBuffer` to `dstBuffer` in one launch.
//!
//! Range i starts at srcOffsets[i] (source) / dstOffsets[i] (destination) and spans sizes[i] elements, all in
//! units of the source buffer's element type. `maxStride` only sizes the grid along x; the kernel strides, so ranges
//! longer than maxStride are still copied completely. The widest vector width (16/8/4/2/1 bytes) that divides the
//! element size is used for loads and stores. Empty inputs are a no-op (a zero-sized grid is an invalid launch).
void invokeCopyBatch(IBuffer const& srcBuffer, IBuffer& dstBuffer, IBuffer const& srcOffsets, IBuffer const& dstOffsets,
    IBuffer const& sizes, std::size_t maxStride, CudaStream const& stream)
{
    auto const* srcDataPtr = reinterpret_cast<uint8_t const*>(srcBuffer.data());
    auto* dstDataPtr = reinterpret_cast<uint8_t*>(dstBuffer.data());
    auto const* srcOffsetsPtr = bufferCast<SizeType64>(srcOffsets);
    auto const* dstOffsetsPtr = bufferCast<SizeType64>(dstOffsets);
    auto const* sizesPtr = bufferCast<SizeType64>(sizes);
    auto const numSlots = srcOffsets.getSize();
    auto const size = maxStride;
    if (numSlots == 0 || size == 0)
    {
        return;
    }
    auto const dataTypeSize = BufferDataType(srcBuffer.getDataType()).getSize();
    auto const copyRowSizeInBytes = size * dataTypeSize;

    // Pick the widest access that evenly divides one element, so every vector stays inside a single element run.
    auto copyBatchInvocation = copyBatch<uint8_t>;
    auto vectorSize = 1;
    if (dataTypeSize % 16 == 0)
    {
        vectorSize = 16;
        copyBatchInvocation = copyBatch<uint4>;
    }
    else if (dataTypeSize % 8 == 0)
    {
        vectorSize = 8;
        copyBatchInvocation = copyBatch<uint2>;
    }
    else if (dataTypeSize % 4 == 0)
    {
        vectorSize = 4;
        copyBatchInvocation = copyBatch<uint32_t>;
    }
    else if (dataTypeSize % 2 == 0)
    {
        vectorSize = 2;
        copyBatchInvocation = copyBatch<uint16_t>;
    }

    dim3 const blockSize{256};
    std::size_t const gridx{tc::ceilDiv(copyRowSizeInBytes / vectorSize, blockSize.x)};
    // gridDim.x is limited to 2^31 - 1 blocks; the kernel strides, so capping the grid is safe.
    std::size_t const gridMax{static_cast<std::size_t>(std::numeric_limits<std::int32_t>::max())};
    dim3 const gridSize{static_cast<std::uint32_t>(std::min(gridx, gridMax)), static_cast<std::uint32_t>(numSlots)};
    copyBatchInvocation<<<gridSize, blockSize, 0, stream.get()>>>(
        srcDataPtr, dstDataPtr, srcOffsetsPtr, dstOffsetsPtr, sizesPtr, static_cast<SizeType64>(dataTypeSize));
}
//! Type-dispatching entry point for scatter: copies each input row into beam 0 of its beam group in `output`.
//! Supported element types: INT32, FLOAT, HALF, INT8, plus BF16 / FP8 when built with ENABLE_BF16 / ENABLE_FP8.
//! BF16 is accepted for parity with tileTensor, which already supports it. Any other data type of `input` throws.
void scatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream)
{
    switch (input.getDataType())
    {
    case nvinfer1::DataType::kINT32: invokeScatterTensor<SizeType32>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kFLOAT: invokeScatterTensor<float>(output, input, beamWidth, stream); break;
    case nvinfer1::DataType::kHALF: invokeScatterTensor<half>(output, input, beamWidth, stream); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16: invokeScatterTensor<__nv_bfloat16>(output, input, beamWidth, stream); break;
#endif // ENABLE_BF16
    case nvinfer1::DataType::kINT8: invokeScatterTensor<int8_t>(output, input, beamWidth, stream); break;
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8: invokeScatterTensor<__nv_fp8_e4m3>(output, input, beamWidth, stream); break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}
//! Type-dispatching entry point for tile: replicates each input row into all `beamWidth` beams of `output`.
//! Supported element types: INT32, FLOAT, HALF, INT8, plus BF16 / FP8 when built with ENABLE_BF16 / ENABLE_FP8.
//! Any other data type of `input` throws.
void tileTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream)
{
    // `typeTag` only carries the element type used to instantiate the launcher; its value is unused.
    auto const tileAs = [&](auto typeTag) { invokeTileTensor<decltype(typeTag)>(output, input, beamWidth, stream); };
    switch (input.getDataType())
    {
    case nvinfer1::DataType::kINT32: tileAs(SizeType32{}); break;
    case nvinfer1::DataType::kFLOAT: tileAs(float{}); break;
    case nvinfer1::DataType::kHALF: tileAs(half{}); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16: tileAs(__nv_bfloat16{}); break;
#endif // ENABLE_BF16
    case nvinfer1::DataType::kINT8: tileAs(int8_t{}); break;
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8: tileAs(__nv_fp8_e4m3{}); break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}
//! Type-dispatching entry point for merging logits fragments; the element type comes from `output`.
//! Supported: FLOAT, HALF, plus BF16 / FP8 when built with ENABLE_BF16 / ENABLE_FP8. Any other data type throws.
void mergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
    std::vector<TensorPtr> const& fragmentsVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
    SizeType32 firstBatchSlotIdx, SizeType32 const microBatchSize, SizeType32 const beamWidth, CudaStream const& stream,
    int stepOffset)
{
    // Forwards every argument to the typed launcher; `typeTag` only carries the element type.
    auto const mergeAs = [&](auto typeTag)
    {
        invokeMergeLogitsFragments<decltype(typeTag)>(bufferManager, output, fragmentsVector, cachePointerDevice,
            cachePointerHost, firstBatchSlotIdx, microBatchSize, beamWidth, stream, stepOffset);
    };
    switch (output.getDataType())
    {
    case nvinfer1::DataType::kFLOAT: mergeAs(float{}); break;
    case nvinfer1::DataType::kHALF: mergeAs(half{}); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16: mergeAs(__nv_bfloat16{}); break;
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
    case nvinfer1::DataType::kFP8: mergeAs(__nv_fp8_e4m3{}); break;
#endif // ENABLE_FP8
    default: TLLM_THROW("data type not supported");
    }
}
//! Adapter that unwraps runtime tensors into raw pointers and forwards them to
//! speculative_decoding::updateKVBlockArrayDraftTokenLocation on `stream`.
//! All ITensor arguments are read as SizeType32 buffers. The argument that follows
//! rewindDraftTokenSeparateAdjustments is always passed as nullptr here.
//! NOTE(review): the meaning of that nullptr argument is defined by the callee (kvCacheUpdateKernels.h); confirm there.
void invokeUpdateKVBlockArrayDraftTokenLocation(ITensor const& seqAcceptedDraftTokenOffsets,
    ITensor const& packedAcceptedDraftTokensIndices, ITensor const& pastKeyValueLengths, void* const* pointerArray,
    ::tensorrt_llm::kernels::KVCacheIndex const* offsetArray, SizeType32 layerCount, SizeType32 seqCount,
    SizeType32 numKVHeads, SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
    SizeType32 const* rewindDraftTokenSeparateAdjustments, ITensor const& batchSlots, SizeType32 maxKVCacheLen,
    SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, bool canUseOneMoreBlock, cudaStream_t stream)
{
    namespace sd = tensorrt_llm::kernels::speculative_decoding;
    auto const* const acceptedOffsets = bufferCast<SizeType32>(seqAcceptedDraftTokenOffsets);
    auto const* const acceptedIndices = bufferCast<SizeType32>(packedAcceptedDraftTokensIndices);
    auto const* const pastLengths = bufferCast<SizeType32>(pastKeyValueLengths);
    auto const* const slots = bufferCast<SizeType32>(batchSlots);
    sd::updateKVBlockArrayDraftTokenLocation(acceptedOffsets, acceptedIndices, pastLengths, pointerArray, offsetArray,
        layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
        rewindDraftTokenSeparateAdjustments, nullptr, slots, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock,
        canUseOneMoreBlock, stream);
}
} // namespace tensorrt_llm::runtime::kernels