TensorRT-LLMs/cpp/tensorrt_llm/runtime/runtimeKernels.h
Robin Kobus 72057a0a64
[TRTLLM-3429] feat: Overlap scheduling in C++ runtime (#3625)
* disable overlap in encoder

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* feat: invokeGatherBatch

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* feat: overlap same batch

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* chore: add enableTrtOverlap to ExecutorConfig

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* disable overlap for beam search and spec decode

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* skip overlap tests with beam search or speculative decoding

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* moveFinishedContextRequestsToGeneration and skip unfinished requests in updateRequests

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* enable overlap in GptChunkedLongContextTests

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* feat: Enable overlap in gptManagerBenchmark

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* feat: Improve early exit

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Use OptionalRef for newOutputTokens tensor

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* feat: Add overlap scheduling support to TRTLLMDecoder

- Updated TRTLLMDecoder to accept an `enable_overlap_scheduler` parameter.
- Modified the decoder's internal logic to utilize the overlap scheduling feature.
- Adjusted sequence-length handling to ensure compatibility with the new scheduling approach.
- Enhanced unit tests to include cases for the overlap scheduler with the TRTLLMDecoder.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fix: allNewTokens in PP

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-05-06 15:06:46 +02:00

83 lines
3.8 KiB
C++

/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iTensor.h"

#include <cstddef>
#include <vector>
//! Host-side entry points for runtime helper kernels that operate on `IBuffer` / `ITensor` objects.
//!
//! Conventions visible in the signatures below:
//! - Most functions take the `CudaStream const&` to run on as their last parameter. Exceptions:
//!   `mergeLogitsFragments` takes it before `stepOffset`, and `invokeUpdateKVBlockArrayDraftTokenLocation`
//!   takes a raw `cudaStream_t`.
//!   NOTE(review): work is presumably enqueued asynchronously on that stream, so results should not be read
//!   on the host before the stream is synchronized — confirm per implementation.
//! - Non-const reference parameters are the ones a function may write; `const&` parameters are read-only inputs.
//! - These are declarations only; each signature must stay in sync with its definition in another translation unit.
namespace tensorrt_llm::runtime::kernels
{
//! Shorthand for `runtime::ITensor::SharedPtr`; used for the fragment list taken by `mergeLogitsFragments`.
using TensorPtr = runtime::ITensor::SharedPtr;
//! Fill `buffer` with `value`.
//! NOTE(review): presumably every element is set and `T` must match the buffer's data type — confirm.
template <typename T>
void invokeFill(IBuffer& buffer, T value, CudaStream const& stream);
//! Batched fill: write entries of `values` into `buffer` at locations selected by `indices` and `stride`.
//! NOTE(review): presumably `values[i]` goes to offset `indices[i] * stride` in `buffer` — confirm.
void invokeFillBatch(
IBuffer& buffer, IBuffer const& indices, std::size_t stride, IBuffer const& values, CudaStream const& stream);
//! Batched gather into `buffer` from `values`, selected by `slotIndices` with per-slot stride `slotStride`.
//! NOTE(review): presumably `buffer[i]` receives the element of `values` at `slotIndices[i] * slotStride` — confirm.
void invokeGatherBatch(IBuffer& buffer, IBuffer const& values, IBuffer const& slotIndices, std::size_t slotStride,
CudaStream const& stream);
//! Batched copy from `srcBuffer` into `dstBuffer` described by per-copy `srcOffsets`, `dstOffsets` and `sizes`.
//! NOTE(review): presumably copy i moves `sizes[i]` elements, and `maxStride` is an upper bound on the sizes used
//! for the launch configuration — confirm.
void invokeCopyBatch(IBuffer const& srcBuffer, IBuffer& dstBuffer, IBuffer const& srcOffsets, IBuffer const& dstOffsets,
IBuffer const& sizes, std::size_t maxStride, CudaStream const& stream);
//! Add `value` to `buffer` (presumably element-wise to every element — confirm).
template <typename T>
void invokeAdd(IBuffer& buffer, T value, CudaStream const& stream);
//! Reduce `input` into `output`.
//! NOTE(review): the reduction operator (presumably a sum) and the required size of `output` are not visible here.
void reduce(IBuffer& output, IBuffer const& input, CudaStream const& stream);
//! Write a transposed copy of `input` into `output`, starting at `outputOffset`.
//! NOTE(review): which dimension `outputOffset` applies to is not visible here — confirm in the implementation.
void invokeTransposeWithOutputOffset(
ITensor& output, ITensor const& input, SizeType32 outputOffset, CudaStream const& stream);
//! Inclusive prefix sum of `input` written to `output`.
//! `manager` is presumably used to allocate temporary scan storage — confirm.
void invokeInclusiveSum(IBuffer& output, IBuffer const& input, BufferManager const& manager, CudaStream const& stream);
//! Build an attention mask in `attentionMask`, using `padId` to identify padding.
//! NOTE(review): `attentionMask` is the only tensor argument, so it presumably holds token ids on entry and is
//! overwritten in place — confirm.
void invokeBuildAttentionMask(ITensor& attentionMask, TokenIdType padId, CudaStream const& stream);
//! Write an extended version of `oldMask` into `newMask`; the extension rule is not visible here.
void invokeExtendAttentionMask(ITensor& newMask, ITensor const& oldMask, CudaStream const& stream);
//! Copy `inputIds` (with per-sequence `inputLengths`) into `outputIds`, presumably padding with `padId`.
//! NOTE(review): `padId` is `SizeType32` here but `TokenIdType` in `invokeBuildAttentionMask` / `initOutputIds`;
//! changing it requires updating the definition as well.
void invokeCopyInputToOutput(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths, SizeType32 padId,
CudaStream const& stream);
//! Initialize `outputIds` from `inputIds`, `inputLengths` and `inputOffsets`, using `padId` and `endId`.
//! NOTE(review): `inputPacked` presumably selects a packed (offset-indexed) input layout instead of inputs padded
//! to `maxInputLength` — confirm.
void initOutputIds(ITensor& outputIds, ITensor const& inputIds, ITensor const& inputLengths,
ITensor const& inputOffsets, TokenIdType padId, TokenIdType endId, SizeType32 maxInputLength, bool inputPacked,
CudaStream const& stream);
//! Expand `input` into `output` for `beamWidth` beams by scattering.
//! NOTE(review): presumably each input row is placed at the first beam of its sequence — confirm vs `tileTensor`.
void scatterTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream);
//! Expand `input` into `output` for `beamWidth` beams by tiling.
//! NOTE(review): presumably each input row is replicated across all beams — confirm vs `scatterTensor`.
void tileTensor(ITensor& output, ITensor const& input, SizeType32 beamWidth, CudaStream const& stream);
//! Gather into `output` the logits from `input` at the positions selected by `lastTokenIds`.
void gatherLastTokenLogits(
ITensor& output, ITensor const& input, ITensor const& lastTokenIds, CudaStream const& stream);
//! Merge the logits tensors in `inputVector` into `output`.
//! NOTE(review): `cachePointerHost` / `cachePointerDevice` are presumably host/device staging buffers for the
//! fragment pointers, and `firstBatchSlotIdx`, `microBatchSize`, `beamWidth` and `stepOffset` select the
//! destination region in `output` — confirm. Unlike its siblings, `stream` is not the last parameter and
//! `stepOffset` is plain `int` rather than `SizeType32`.
void mergeLogitsFragments(BufferManager const& bufferManager, ITensor& output,
std::vector<TensorPtr> const& inputVector, ITensor& cachePointerDevice, ITensor& cachePointerHost,
SizeType32 firstBatchSlotIdx, SizeType32 microBatchSize, SizeType32 beamWidth, CudaStream const& stream,
int stepOffset);
//! Update KV-cache block contents for accepted draft tokens, given through `pointerArray` / `offsetArray`.
//! NOTE(review): presumably this relocates the KV entries of accepted draft tokens, selected by
//! `seqAcceptedDraftTokenOffsets` and `packedAcceptedDraftTokensIndices`, to their final positions, and rewinds by
//! `rewindDraftTokenCommonCount` or the per-sequence `rewindDraftTokenSeparateAdjustments` — confirm.
//! Takes a raw `cudaStream_t` instead of the `CudaStream const&` used by the other functions in this header.
void invokeUpdateKVBlockArrayDraftTokenLocation(ITensor const& seqAcceptedDraftTokenOffsets,
ITensor const& packedAcceptedDraftTokensIndices, ITensor const& pastKeyValueLengths, void* const* pointerArray,
::tensorrt_llm::kernels::KVCacheIndex const* offsetArray, SizeType32 layerCount, SizeType32 seqCount,
SizeType32 numKVHeads, SizeType32 sizeInBytesPerKVHead, SizeType32 rewindDraftTokenCommonCount,
SizeType32 const* rewindDraftTokenSeparateAdjustments, ITensor const& seqSlotRemapping, ITensor const& batchSlots,
SizeType32 maxKVCacheLen, SizeType32 maxBlocksPerSeq, SizeType32 tokensPerBlock, bool canUseOneMoreBlock,
cudaStream_t stream);
} // namespace tensorrt_llm::runtime::kernels