mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* disable overlap in encoder Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * feat: invokeGatherBatch Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * feat: overlap same batch Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * chore: add enableTrtOverlap to ExecutorConfig Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * disable overlap for beam search and spec decode Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * skip overlap tests with beam search or speculative decoding Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * moveFinishedContextRequestsToGeneration and skip unfinished requests in updateRequests Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * enable overlap in GptChunkedLongContextTests Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * feat: Enable overlap in gptManagerBenchmark Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * feat: Improve early exit Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * refactor: Use OptionalRef for newOutputTokens tensor Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * feat: Add overlap scheduling support to TRTLLMDecoder - Updated TRTLLMDecoder to accept an `enable_overlap_scheduler` parameter. - Modified the decoder's internal logic to utilize the overlap scheduling feature. - Adjusted the sequence lengths handling to ensure compatibility with the new scheduling approach. - Enhanced unit tests to include cases for the overlap scheduler with the TRTLLMDecoder. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * fix: allNewTokens in PP Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --------- Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
208 lines
8.9 KiB
C++
208 lines
8.9 KiB
C++
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
|
|
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
|
|
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
|
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
|
|
#include "tensorrt_llm/common/cudaUtils.h"
|
|
#include "tensorrt_llm/common/logger.h"
|
|
#include "tensorrt_llm/runtime/bufferManager.h"
|
|
#include "tensorrt_llm/runtime/cudaStream.h"
|
|
#include "tensorrt_llm/runtime/decoderState.h"
|
|
#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
|
|
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
|
|
|
namespace tr = tensorrt_llm::runtime;
|
|
|
|
namespace tensorrt_llm::batch_manager
|
|
{
|
|
using SizeType32 = MakeDecodingBatchInputOutput::SizeType32;
|
|
using TensorPtr = MakeDecodingBatchInputOutput::TensorPtr;
|
|
|
|
std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::createDecoderBatchInputs(
|
|
std::vector<SizeType32> const& activeSlots, runtime::decoder::DecoderState const& decoderState,
|
|
std::vector<TensorPtr> const& logits, SizeType32 maxNumSequences, std::vector<TensorPtr> const& batchSlots,
|
|
TensorPtr const& cacheIndirectionInput)
|
|
{
|
|
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
|
|
|
auto const& numDecodingEngineTokens = decoderState.getNumDecodingEngineTokens();
|
|
auto const& maxDecodingEngineTokens = decoderState.getMaxDecodingEngineTokens();
|
|
auto const& maxDecodingDecoderTokens = decoderState.getMaxDecodingDecoderTokens();
|
|
auto const maxDecoderSteps = common::ceilDiv(maxDecodingEngineTokens, maxDecodingDecoderTokens);
|
|
|
|
for (SizeType32 step = 0; step < maxDecoderSteps; ++step)
|
|
{
|
|
batchSlots.at(step)->resize(maxNumSequences);
|
|
}
|
|
|
|
std::vector<SizeType32> batchIdx(maxDecoderSteps);
|
|
auto maxActiveDecoderSteps = 1;
|
|
for (auto const slot : activeSlots)
|
|
{
|
|
auto const numDecoderSteps = common::ceilDiv(numDecodingEngineTokens.at(slot), maxDecodingDecoderTokens);
|
|
maxActiveDecoderSteps = std::max(maxActiveDecoderSteps, numDecoderSteps);
|
|
for (SizeType32 step = 0; step < numDecoderSteps; ++step)
|
|
{
|
|
auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlots.at(step));
|
|
batchSlotsRange[batchIdx[step]] = slot;
|
|
batchIdx[step]++;
|
|
}
|
|
}
|
|
|
|
for (SizeType32 step = 0; step < maxDecoderSteps; ++step)
|
|
{
|
|
batchSlots.at(step)->resize(batchIdx[step]);
|
|
}
|
|
|
|
auto constexpr singleRequest = 1;
|
|
std::vector<std::vector<tr::ITensor::SharedConstPtr>> logitsVec(maxActiveDecoderSteps);
|
|
for (SizeType32 step = 0; step < maxActiveDecoderSteps; ++step)
|
|
{
|
|
auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlots.at(step));
|
|
|
|
for (auto slot : batchSlotsRange)
|
|
{
|
|
auto const& targetLogits = logits.at(slot);
|
|
TensorPtr logitsSlice = tr::ITensor::slice(targetLogits, step, singleRequest);
|
|
logitsVec.at(step).push_back(logitsSlice);
|
|
}
|
|
}
|
|
|
|
auto decodingInput = std::make_unique<tr::decoder_batch::Input>(logitsVec, maxActiveDecoderSteps);
|
|
decodingInput->batchSlots = batchSlots;
|
|
decodingInput->cacheIndirection = cacheIndirectionInput;
|
|
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
|
return decodingInput;
|
|
}
|
|
|
|
namespace
|
|
{
|
|
|
|
/// @brief Collect the sequence slots of all requests that take part in the current decoding pass.
///
/// A request contributes its slot when `isGenerationInProgressState()` or `isLastContextChunk()` is true.
/// Context requests are visited before generation requests; the result is sorted in ascending order.
///
/// @param contextRequests Context-phase requests of the current batch.
/// @param generationRequests Generation-phase requests of the current batch.
/// @return Ascending list of sequence slots of the selected requests.
/// @note `mSeqSlot.value()` is called on every selected request, so a selected request without an
///       assigned slot raises the exception of that accessor.
std::vector<SizeType32> getActiveSlots(RequestVector const& contextRequests, RequestVector const& generationRequests)
{
    std::vector<SizeType32> activeSlots;
    // Upper bound: every request of both vectors is selected.
    activeSlots.reserve(contextRequests.size() + generationRequests.size());

    // Visit each vector by reference. Iterating over a braced list {contextRequests, generationRequests}
    // would copy both vectors into a temporary std::initializer_list on every call.
    auto const collectActiveSlots = [&activeSlots](RequestVector const& requests)
    {
        for (auto const& llmReq : requests)
        {
            if (llmReq->isGenerationInProgressState() || llmReq->isLastContextChunk())
            {
                activeSlots.push_back(llmReq->mSeqSlot.value());
            }
        }
    };
    collectActiveSlots(contextRequests);
    collectActiveSlots(generationRequests);

    std::sort(activeSlots.begin(), activeSlots.end());

    return activeSlots;
}
|
|
|
|
/// @brief Stage the current sequence length of every request and apply them to `sequenceLengths`.
///
/// For each request (context requests first, then generation requests) the sequence slot and the value
/// `mPromptLen + getMaxNumGeneratedTokens()` are written to the staging buffers of `inputBuffers`;
/// generation requests additionally add 1 when `isTrtOverlap` is set. The staged slots and values are
/// then copied to their `...Device` counterparts and handed to `invokeFillBatch`.
///
/// @param contextRequests Context-phase requests of the current batch.
/// @param generationRequests Generation-phase requests of the current batch.
/// @param inputBuffers Provides the slot / fill-value staging buffers and their `...Device` counterparts.
/// @param sequenceLengths Tensor passed to `invokeFillBatch` as the fill target.
/// @param beamWidth Forwarded to `invokeFillBatch`.
/// @param isTrtOverlap When true, generation requests are staged with one extra token.
/// @param manager Buffer manager used for the two buffer copies.
/// @param stream Stream forwarded to `invokeFillBatch`.
void copySequenceLengths(RequestVector const& contextRequests, RequestVector const& generationRequests,
    DecoderInputBuffers const& inputBuffers, TensorPtr const& sequenceLengths, SizeType32 beamWidth, bool isTrtOverlap,
    runtime::BufferManager const& manager, runtime::CudaStream const& stream)
{
    auto const numRequests = contextRequests.size() + generationRequests.size();

    // Views limited to the entries used by this batch.
    auto stagedSlots = tr::ITensor::slice(inputBuffers.forwardBatchSlotsRequestOrder, 0, numRequests);
    auto stagedLengths = tr::ITensor::slice(inputBuffers.fillValues, 0, numRequests);

    auto slotsOut = tr::BufferRange<SizeType32>(*stagedSlots);
    auto lengthsOut = tr::BufferRange<SizeType32>(*stagedLengths);

    // fill buffers on host
    SizeType32 writePos{0};
    auto const stageRequest = [&slotsOut, &lengthsOut, &writePos](auto const& llmReq, SizeType32 extraTokens)
    {
        auto const sequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens() + extraTokens;
        // Position of this sequence in the decoder.
        auto const decoderSlot = llmReq->mSeqSlot.value();
        slotsOut[writePos] = decoderSlot;
        lengthsOut[writePos] = sequenceLen;
        ++writePos;
    };

    for (auto const& llmReq : contextRequests)
    {
        stageRequest(llmReq, 0);
    }
    // Generation requests count one additional token when overlap is enabled.
    auto const overlapTokens = static_cast<SizeType32>(isTrtOverlap);
    for (auto const& llmReq : generationRequests)
    {
        stageRequest(llmReq, overlapTokens);
    }

    // copy sequence lengths
    auto deviceSlots = tr::ITensor::slice(inputBuffers.forwardBatchSlotsRequestOrderDevice, 0, numRequests);
    auto deviceLengths = tr::ITensor::slice(inputBuffers.fillValuesDevice, 0, numRequests);

    manager.copy(*stagedSlots, *deviceSlots);
    manager.copy(*stagedLengths, *deviceLengths);
    // NOTE(review): presumably writes each staged length into sequenceLengths at its slot for all
    // beamWidth beams - confirm against the invokeFillBatch kernel.
    tr::kernels::invokeFillBatch(*sequenceLengths, *deviceSlots, beamWidth, *deviceLengths, stream);
}
|
|
} // namespace
|
|
|
|
/// @brief Assemble the decoder batch input and output for the current set of requests.
///
/// Steps performed:
///  1. determine the active sequence slots and create the batch input from the decoder logits,
///  2. attach the speculative decoding buffers required by the configured mode,
///  3. update the joint decoding output lengths via copySequenceLengths,
///  4. create the batch output and attach the output cache indirection.
///
/// @param contextRequests Context-phase requests of the current batch.
/// @param generationRequests Generation-phase requests of the current batch.
/// @param decoderBuffers Provides logits, draft buffers and the cache indirection tensors.
/// @param inputBuffers Provides the per-step batch slots and the staging buffers for sequence lengths.
/// @param decoderState Decoder state; its joint decoding output lengths are passed on for updating.
/// @param modelConfig Source of the speculative decoding mode.
/// @param maxNumSequences Forwarded to createDecoderBatchInputs.
/// @param beamWidth Forwarded to copySequenceLengths.
/// @param isTrtOverlap Forwarded to copySequenceLengths.
/// @param manager Buffer manager forwarded to copySequenceLengths.
/// @param stream Stream forwarded to copySequenceLengths.
/// @param fusedRuntimeBuffers Must be set when the mode is explicit draft tokens or Eagle (checked).
/// @return Tuple of the decoder batch input and the decoder batch output.
std::tuple<std::unique_ptr<tr::decoder_batch::Input>, std::unique_ptr<tr::decoder_batch::Output>>
MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests, RequestVector const& generationRequests,
    DecoderBuffers& decoderBuffers, DecoderInputBuffers const& inputBuffers,
    runtime::decoder::DecoderState& decoderState, runtime::ModelConfig const& modelConfig, SizeType32 maxNumSequences,
    SizeType32 beamWidth, bool isTrtOverlap, runtime::BufferManager const& manager, runtime::CudaStream const& stream,
    OptionalRef<RuntimeBuffers> fusedRuntimeBuffers) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const slotsToDecode = getActiveSlots(contextRequests, generationRequests);

    auto decodingInput = createDecoderBatchInputs(slotsToDecode, decoderState, decoderBuffers.logits, maxNumSequences,
        inputBuffers.forwardBatchSlots, decoderBuffers.cacheIndirectionInput);

    auto specDecodingMode = modelConfig.getSpeculativeDecodingMode();

    if (specDecodingMode.hasDraftLogits())
    {
        decodingInput->predictedDraftLogits = decoderBuffers.draftBuffers.predictedDraftLogits;
    }

    // Explicit draft tokens takes precedence; Eagle is only queried when it does not apply.
    bool const useExplicitDraftTokens = specDecodingMode.isExplicitDraftTokens();
    bool const useEagle = !useExplicitDraftTokens && specDecodingMode.isEagle();
    if (useExplicitDraftTokens || useEagle)
    {
        TLLM_CHECK(fusedRuntimeBuffers);
        // requires mCtxGenFusion == true
        decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
        if (useExplicitDraftTokens)
        {
            auto const& draftTokensBuffers = fusedRuntimeBuffers->explicitDraftTokensBuffers;
            decodingInput->explicitDraftTokensInputs = draftTokensBuffers->engineOutputs;
            decodingInput->explicitDraftTokensLastInputs = draftTokensBuffers->engineInputs;
        }
        else
        {
            auto const& eagleBuffers = fusedRuntimeBuffers->eagleBuffers;
            decodingInput->eagleInputs = eagleBuffers->engineOutputs;
            decodingInput->eagleLastInputs = eagleBuffers->engineInputs;
        }
    }

    copySequenceLengths(contextRequests, generationRequests, inputBuffers,
        decoderState.getJointDecodingOutput().lengths, beamWidth, isTrtOverlap, manager, stream);

    auto decodingOutput = std::make_unique<tr::decoder_batch::Output>();
    decodingOutput->cacheIndirection = decoderBuffers.cacheIndirectionOutput;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return {std::move(decodingInput), std::move(decodingOutput)};
}
|
|
|
|
} // namespace tensorrt_llm::batch_manager
|