TensorRT-LLMs/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
Robin Kobus ceec4924d9
refactor: batch slot management in decoder classes (#3300)
* refactor: batch slot management in decoder classes

- Changed `forwardBatchSlots` from a single `TensorPtr` to a `std::vector<TensorPtr>` in `decoderBuffers.h` and updated its initialization in `decoderBuffers.cpp`.
- Updated `batchSlots` in `iGptDecoderBatched.h` to a `std::vector<TensorPtr>` for better handling of batch sizes.
- Modified `mBatchSlotsDecoder` in `statefulGptDecoderBatched.h` to use a `std::vector<TensorPtr>` and adjusted its initialization in `statefulGptDecoderBatched.cpp`.
- Ensured proper reshaping of tensors in the setup methods to accommodate the new vector structure.

These changes enhance flexibility in managing tensor buffers across different batch sizes.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
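
With one batch-slot list per decoder step, the decoder loop reads only the list for the current step and derives the batch size from its length. This mirrors `input.batchSlots.at(step)` in `prepareForward`. The sketch below models that loop with plain containers. `StepBatchSlots`, `StepPlan` and `runDecoderSteps` are hypothetical names, not TensorRT-LLM API. Here an empty list means the decoder is not invoked for that step.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for std::vector<TensorPtr>: one list of batch slots per decoder step.
using StepBatchSlots = std::vector<std::vector<std::int32_t>>;

// Hypothetical summary of what the decoder loop would do at each step.
struct StepPlan
{
    std::vector<std::int32_t> batchSizes; // batch size derived from each step's slot list
    std::int32_t dispatchedSteps = 0;     // steps whose slot list is non-empty
};

// Walk maxDecoderSteps steps. Each step uses only its own slot list, and an empty
// list means the decoder is not invoked for that step.
StepPlan runDecoderSteps(StepBatchSlots const& batchSlots, std::int32_t maxDecoderSteps)
{
    assert(maxDecoderSteps >= 0 && static_cast<std::size_t>(maxDecoderSteps) <= batchSlots.size());
    StepPlan plan;
    for (std::int32_t step = 0; step < maxDecoderSteps; ++step)
    {
        auto const batchSize = static_cast<std::int32_t>(batchSlots.at(step).size());
        plan.batchSizes.push_back(batchSize);
        if (batchSize > 0)
        {
            ++plan.dispatchedSteps; // a real decoder would launch its kernels for these slots here
        }
    }
    return plan;
}
```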

* refactor: Setup batch slots outside of the decoder

- Refactored batch slot management to utilize `makeBatchSlots`, enhancing clarity and functionality in batch processing.
- Introduced `DecoderState` to `MakeDecodingBatchInputOutput` for improved state handling during decoding.
- Updated the `operator()` method to include `decoderState` as a parameter, facilitating better integration with the decoding process.
- Modified related tests to accommodate changes in batch slot handling and ensure proper functionality.

These updates improve the overall structure and efficiency of the decoding process in the batch manager.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Enhance decoder input structure with maxDecodingEngineTokens

- Updated the `Input` class in `iGptDecoderBatched.h` to include a new parameter `maxDecodingEngineTokens` for better control over decoding limits.
- Modified the `MakeDecodingBatchInputOutput` algorithm to compute the maximum number of decoding tokens based on active slots.
- Adjusted the `GptDecoderBatched` class to utilize the new `maxDecodingEngineTokens` parameter, improving clarity in token management during decoding.
- Updated Python bindings to reflect changes in the `Input` class constructor.
- Enhanced tests to ensure proper handling of the new parameter.

These changes improve the flexibility and efficiency of the decoding process in the batch manager.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Streamline decoder input creation and batch slot management

- Introduced a new function `createDecoderInputs` to encapsulate the logic for creating decoder inputs, improving code organization.
- Updated the `operator()` method to utilize the new `createDecoderInputs` function, simplifying the decoding input setup process.
- Removed the `maxOfActiveSlots` template function to streamline the logic for determining the maximum number of active decoding engine tokens.
- Introduced a direct calculation of `maxActiveDecodingEngineTokens` within the `createDecoderInputs` function, enhancing clarity and reducing complexity.

These changes enhance the maintainability and readability of the decoding process in the batch manager.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
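
Computing `maxActiveDecodingEngineTokens` directly amounts to a single pass over the active slots. The sketch below shows that pass on plain vectors. The function name reuses the variable name from the message above, but its signature is hypothetical. The floor of 1 ensures that a batch with no active slots still yields one step, and it is an assumption of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Largest engine-token count among the active slots. numDecodingEngineTokens is indexed
// by batch slot; slots not listed in activeSlots are ignored.
// ASSUMPTION: the result is at least 1 even when activeSlots is empty.
std::int32_t maxActiveDecodingEngineTokens(
    std::vector<std::int32_t> const& activeSlots, std::vector<std::int32_t> const& numDecodingEngineTokens)
{
    std::int32_t maxTokens = 1;
    for (auto const slot : activeSlots)
    {
        assert(slot >= 0);
        maxTokens = std::max(maxTokens, numDecodingEngineTokens.at(slot));
    }
    return maxTokens;
}
```

Inactive slots never contribute, even if they hold a larger stale count (slot 2 in the test data).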

* refactor: Update logits handling in decoder batch

- Modified the `decoder_batch::Input` to accept a vector of vectors for logits, enhancing flexibility in tensor management.
- Adjusted the `createDecoderInputs` function to accommodate the new logits structure, ensuring proper batch processing.
- Updated Python bindings to reflect changes in the `Input` class constructor, maintaining compatibility with existing interfaces.
- Refactored the `GptDecoderBatched` and `StatefulGptDecoderBatched` classes to utilize the updated logits structure, improving clarity in tensor slicing and batch size management.
- Enhanced tests to validate the new input structure and ensure correct functionality across various decoding scenarios.

These changes streamline the decoding process and improve the overall maintainability of the codebase.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
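
With logits passed as a vector of vectors, the outer index is the decoder step. The inner list holds one logits slice per request in that step's batch, in the same order as that step's batch slots (`prepareForward` reads `input.logits.at(step)` next to `input.batchSlots.at(step)`). The sketch below uses hypothetical `Logits` and `sliceLogitsPerStep` names. It also assumes that row `step` of a slot's logits belongs to decoder step `step`. It copies rows for simplicity, where the decoder would take non-owning tensor views.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-slot logits: one row (vocab-sized vector) per engine token.
using Logits = std::vector<std::vector<float>>;

// For each step, collect row `step` of the logits of every slot scheduled at that step.
// The inner order follows stepBatchSlots[step], so logits and batch slots stay aligned.
std::vector<std::vector<std::vector<float>>> sliceLogitsPerStep(
    std::vector<std::vector<std::int32_t>> const& stepBatchSlots, std::vector<Logits> const& logitsBySlot)
{
    std::vector<std::vector<std::vector<float>>> logitsVec(stepBatchSlots.size());
    for (std::size_t step = 0; step < stepBatchSlots.size(); ++step)
    {
        for (auto const slot : stepBatchSlots[step])
        {
            // Copying keeps the sketch simple; the decoder would use a view instead.
            logitsVec[step].push_back(logitsBySlot.at(slot).at(step));
        }
    }
    return logitsVec;
}
```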

* refactor: Rename maxDecodingEngineTokens to maxDecoderSteps

- Updated the `Input` class in `iGptDecoderBatched.h` to rename `maxDecodingEngineTokens` to `maxDecoderSteps` for improved clarity.
- Adjusted the `createDecoderInputs` function to reflect the new naming, ensuring consistency in the decoding process.
- Modified the `GptDecoderBatched` class to utilize `maxDecoderSteps` in its logic, enhancing readability and maintainability.
- Updated Python bindings to expose the renamed parameter, maintaining compatibility with existing interfaces.

These changes enhance the clarity of the decoding parameters and improve the overall structure of the codebase.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: remove usage of `active` vector from prepareForward

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Removed the `active` vector from `decoder_batch::Input`

- Removed the `active` vector from the `Input` class constructor in `iGptDecoderBatched.h`, streamlining the input handling for decoding.
- Updated the `createDecoderInputs` function and related tests to reflect the changes in the `Input` class, ensuring compatibility and maintaining functionality.
- Adjusted Python bindings to accommodate the new constructor signature, enhancing clarity in the interface.

These changes improve the maintainability and readability of the decoding process in the batch manager.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: remove usage of `active` vector from gptDecoderBatchedTest

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Unify the creation of decoder batch inputs in algorithm and tests

- Added a new static method `createDecoderBatchInputs` to streamline the creation of decoder batch inputs, enhancing clarity and maintainability.
- Updated the implementation to utilize active slots directly, simplifying the logic for managing batch slots and logits.
- Refactored the `operator()` method to leverage the new input creation function, ensuring compatibility with existing decoding processes.
- Enhanced tests to validate the new input handling approach, ensuring correct functionality across various scenarios.

These changes improve the overall structure and readability of the decoding process in the batch manager.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
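
Using the active slots directly, the per-step batch-slot lists can be filled without a boolean `active` mask. The sketch below assumes that a slot needing `n` decoder steps appears in the lists of steps `0..n-1`. `makeStepBatchSlots` and its parameters are hypothetical and are not the signature of `createDecoderBatchInputs`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Distribute active slots over decoder steps.
// ASSUMPTION: a slot needing n decoder steps is scheduled at steps 0..n-1.
// numDecoderStepsBySlot is indexed by batch slot; inactive slots are never read.
std::vector<std::vector<std::int32_t>> makeStepBatchSlots(std::vector<std::int32_t> const& activeSlots,
    std::vector<std::int32_t> const& numDecoderStepsBySlot, std::int32_t maxDecoderSteps)
{
    std::vector<std::vector<std::int32_t>> stepBatchSlots(maxDecoderSteps);
    for (auto const slot : activeSlots)
    {
        auto const numSteps = numDecoderStepsBySlot.at(slot);
        assert(numSteps <= maxDecoderSteps);
        for (std::int32_t step = 0; step < numSteps; ++step)
        {
            stepBatchSlots[step].push_back(slot);
        }
    }
    return stepBatchSlots;
}
```

Later steps only shrink, so each step's batch size is the number of slots that still have tokens to decode.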

* refactor: remove usage of active vector from createDecoderBatchInputs

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Update maxDecoderSteps calculation

- Replaced integer division with `common::ceilDiv` for calculating `maxDecoderSteps` and `numDecoderSteps`, ensuring correct handling of token counts.

These changes enhance the robustness of the decoding batch input creation process.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
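
The difference between integer division and `ceilDiv` shows up whenever the token count is not a multiple of the tokens handled per decoder step. The sketch below uses a local stand-in for `common::ceilDiv`, reimplemented here rather than included from TensorRT-LLM. It assumes that the step count is the engine-token count divided by the maximum number of tokens per decoder step. `numDecoderSteps` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Local stand-in for common::ceilDiv: a / b rounded up, for a >= 0 and b > 0.
constexpr std::int32_t ceilDiv(std::int32_t numerator, std::int32_t denominator)
{
    return (numerator + denominator - 1) / denominator;
}

// Hypothetical helper: decoder steps needed for numDecodingEngineTokens tokens when
// one decoder step handles at most maxDecodingDecoderTokens of them.
constexpr std::int32_t numDecoderSteps(std::int32_t numDecodingEngineTokens, std::int32_t maxDecodingDecoderTokens)
{
    return ceilDiv(numDecodingEngineTokens, maxDecodingDecoderTokens);
}

// Plain integer division would give 2 steps for 5 tokens at 2 per step and drop one token.
static_assert(5 / 2 == 2, "integer division truncates");
static_assert(numDecoderSteps(5, 2) == 3, "ceilDiv rounds up");
static_assert(numDecoderSteps(4, 2) == 2, "exact multiples are unchanged");
```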

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-04-13 05:05:13 +08:00

/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <utility>
#include <vector>

using namespace tensorrt_llm::runtime;

GptDecoderBatched::GptDecoderBatched(GptDecoderBatched::CudaStreamPtr stream,
    SpeculativeDecodingMode const& speculativeDecodingMode, nvinfer1::DataType dtype)
    : mRuntimeStream{std::move(stream)}
    , mBufferManager{mRuntimeStream}
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mDecoderState = std::make_shared<decoder::DecoderState>(dtype, mBufferManager);

    if (!speculativeDecodingMode.isNone())
    {
        mDecoderState->allocateSpeculativeDecodingBuffers(speculativeDecodingMode, dtype, mBufferManager);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    mDecoderState->disableLookahead(genRequests);

    std::vector<SamplingConfig> samplingConfigs;
    samplingConfigs.reserve(genRequests.size());
    auto batchSlotsRange = BufferRange<SizeType32>(*batchSlots);
    SizeType32 batchIdx = 0;
    for (auto const& llmReq : genRequests)
    {
        samplingConfigs.push_back(llmReq->mSamplingConfig);
        batchSlotsRange[batchIdx] = llmReq->mSeqSlot.value();
        batchIdx += 1;
    }
    auto const batchSize = batchIdx;
    std::optional<SamplingConfig> samplingConfig;
    if (batchSize > 0)
    {
        samplingConfig = SamplingConfig(samplingConfigs);
    }

    // Pass only the first batchSize entries; entries beyond batchSize were not written above.
    TensorPtr batchSlotsView = ITensor::slice(batchSlots, 0, batchSize);
    mDecoder->disableLookahead(samplingConfig, batchSize, batchSlotsView);

    CudaEvent event{};
    mDecoderStream->record(event);
    mRuntimeStream->wait(event);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
    SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength,
    SizeType32 maxTokensPerEngineStep, nvinfer1::DataType dtype, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_CHECK(maxBatchSize > 0);
    TLLM_CHECK(maxBeamWidth > 0);
    TLLM_CHECK(maxTokensPerEngineStep > 0);
    TLLM_CHECK(maxSequenceLength > 0);

    mDecoderState->setup(maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength,
        modelConfig, worldConfig, mBufferManager);

    mDecoderState->setupSpeculativeDecoding(
        mDecoderState->getSpeculativeDecodingMode(), maxTokensPerEngineStep, modelConfig, worldConfig, mBufferManager);

    std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModulePtr = nullptr;
    if (mDecoderState->getSpeculativeDecodingMode().predictsDraftTokens())
    {
        speculativeDecodingModulePtr = modelConfig.getSpeculativeDecodingModulePtr();
    }

    auto const device = mRuntimeStream->getDevice();
    mDecoderStream = std::make_shared<CudaStream>();
    TLLM_CHECK(mDecoderStream->getDevice() == device);

    auto const vocabSize = modelConfig.getVocabSize();
    auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
    mDecoder = IGptDecoder::create(mode, dtype, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
        maxSequenceLength, mDecoderStream, speculativeDecodingModulePtr);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::setExplicitDraftTokensInputs(decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto explicitDraftTokensInputs = DecodingInput::ExplicitDraftTokensInputs();
    TLLM_CHECK(input.explicitDraftTokensInputs.has_value());
    TLLM_CHECK(input.explicitDraftTokensLastInputs.has_value());

    explicitDraftTokensInputs.nextDraftTokens = input.explicitDraftTokensInputs->nextDraftTokens;
    explicitDraftTokensInputs.nextFlatTokens = input.explicitDraftTokensInputs->nextFlatTokens;
    explicitDraftTokensInputs.nextDraftIndices = input.explicitDraftTokensInputs->nextDraftIndices;
    explicitDraftTokensInputs.nextDraftProbs = input.explicitDraftTokensInputs->nextDraftProbs;
    explicitDraftTokensInputs.lastDraftTokens = input.explicitDraftTokensLastInputs->draftTokens;
    explicitDraftTokensInputs.lastDraftIndices = input.explicitDraftTokensLastInputs->draftIndices;
    explicitDraftTokensInputs.lastPositionIdsBase = input.explicitDraftTokensLastInputs->positionIdsBase;
    explicitDraftTokensInputs.masks = input.explicitDraftTokensInputs->masks;
    explicitDraftTokensInputs.packedPositionIds = input.explicitDraftTokensInputs->packedPositionIds;
    explicitDraftTokensInputs.bestPathLengths = input.explicitDraftTokensInputs->bestPathLengths;
    explicitDraftTokensInputs.bestPathIndices = input.explicitDraftTokensInputs->bestPathIndices;
    explicitDraftTokensInputs.nextGenerationLengths = input.explicitDraftTokensInputs->nextGenerationLengths;
    explicitDraftTokensInputs.lastGenerationLengths = input.explicitDraftTokensLastInputs->generationLengths;
    explicitDraftTokensInputs.maxGenLengthDevice = input.explicitDraftTokensInputs->maxGenToken;
    explicitDraftTokensInputs.seqSlots = input.batchSlotsRequestOrder;

    mDecoderState->getJointDecodingInput().explicitDraftTokensInputs = explicitDraftTokensInputs;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::setEagleInputs(decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(input.eagleInputs.has_value());
    TLLM_CHECK(input.eagleLastInputs.has_value());

    auto eagleInputs = DecodingInput::EagleInputs(input.eagleInputs->nextDraftTokens, input.eagleInputs->nextDraftLens,
        input.eagleInputs->nextDraftPaths, input.eagleLastInputs->draftTokens, input.eagleLastInputs->draftLens,
        input.eagleLastInputs->draftPaths, input.eagleInputs->acceptedTokens, input.eagleInputs->acceptedLens,
        input.eagleInputs->acceptedPaths, input.eagleInputs->chunkedContextNextTokens, input.batchSlotsRequestOrder);

    mDecoderState->getJointDecodingInput().eagleInputs = eagleInputs;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::forwardDispatch(decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    for (SizeType32 step = 0; step < input.maxDecoderSteps; ++step)
    {
        prepareForward(step, output, input);

        if (mDecoderState->getJointDecodingInput().batchSize > 0)
        {
            mDecoder->forwardAsync(mDecoderState->getJointDecodingOutput(), mDecoderState->getJointDecodingInput());
        }
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

CudaEvent GptDecoderBatched::forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Make the decoder stream wait for the work already enqueued on the runtime stream.
    auto eventStart = CudaEvent{};
    mRuntimeStream->record(eventStart);
    mDecoderStream->wait(eventStart.get());

    forwardDispatch(output, input);

    // Make the runtime stream wait for the decoder stream before continuing.
    CudaEvent event{};
    mDecoderStream->record(event);
    mRuntimeStream->wait(event);

    CudaEvent eventStop{};
    mRuntimeStream->record(eventStop);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return eventStop;
}

// TODO: produce new input and output
void GptDecoderBatched::prepareForward(
    SizeType32 step, decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const& jointOutputIdsShape = mDecoderState->getJointDecodingOutput().ids->getShape();
    auto const maxBeamWidth = jointOutputIdsShape.d[1];
    auto const speculativeDecodingMode = mDecoderState->getSpeculativeDecodingMode();

    auto& dInput = mDecoderState->getJointDecodingInput();
    auto& dOutput = mDecoderState->getJointDecodingOutput();

    if (maxBeamWidth > 1)
    {
        dInput.cacheIndirection = input.cacheIndirection;
        dOutput.cacheIndirection = output.cacheIndirection;
    }

    if (speculativeDecodingMode.isExplicitDraftTokens())
    {
        setExplicitDraftTokensInputs(input);
    }
    else if (speculativeDecodingMode.isEagle())
    {
        setEagleInputs(input);
    }

    dInput.batchSlots = input.batchSlots.at(step);
    dInput.batchSize = static_cast<SizeType32>(dInput.batchSlots->getSize());
    dInput.logitsVec = input.logits.at(step);

    TensorPtr finishedStepsInput = ITensor::slice(mDecoderState->getFinishedSteps(), step, 1);
    TensorPtr finishedStepsOutput
        = ITensor::slice(mDecoderState->getFinishedSteps(), std::min(input.maxDecoderSteps - 1, step + 1), 1);
    finishedStepsInput->squeeze(0);
    finishedStepsOutput->squeeze(0);
    TensorPtr newTokensStepView
        = ITensor::slice(dOutput.newTokensSteps, step, mDecoderState->getMaxDecodingDecoderTokens());

    dInput.finishReasons = finishedStepsInput;

    if (speculativeDecodingMode.isMedusa())
    {
        dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
    }

    if (speculativeDecodingMode.isDraftTokensExternal())
    {
        dInput.externalDraftTokensInputs->step = step;

        // WAR: reset the step-0 finished state of each generation request in this batch
        if (step == 0)
        {
            BufferManager manager{mDecoderStream};
            TensorPtr finishedStepsView = ITensor::slice(mDecoderState->getFinishedSteps(), 0, 1);
            finishedStepsView->squeeze(0);
            auto batchSlotsRange = BufferRange<SizeType32 const>(*dInput.batchSlots);
            for (auto batchSlot : batchSlotsRange)
            {
                // Zero only this request's slot, not the whole step-0 view.
                TensorPtr finishedSteps = ITensor::slice(finishedStepsView, batchSlot, 1);
                manager.setZero(*finishedSteps);
            }
        }
    }

    dOutput.newTokens = newTokensStepView;
    dOutput.finishReasons = finishedStepsOutput;

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void GptDecoderBatched::forward(decoder_batch::Output& output, decoder_batch::Input const& input)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    auto decoderFinishEvent = forwardAsync(output, input);
    decoderFinishEvent.synchronize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

namespace
{
std::pair<DecodingInput, DecodingOutput> prepareGatherTree(
    decoder::DecoderState const& decoderState, SizeType32 batchSlot, bool streaming, CudaStream const& stream)
{
    auto& dJointInput = decoderState.getJointDecodingInput();
    auto& dJointOutput = decoderState.getJointDecodingOutput();

    // Store the batchSlot slice of b in a, but only if b is set and non-empty along dim 0.
    auto slice = [batchSlot](auto& a, auto const& b)
    {
        if (b && b->getShape().d[0] > 0)
        {
            a = ITensor::slice(b, batchSlot, 1);
        }
    };

    // Prepare a slice of dJointInput and dJointOutput for gatherTree
    DecodingInput dInput{dJointInput};
    slice(dInput.endIds, dJointInput.endIds);
    slice(dInput.lengths, dJointInput.lengths);

    DecodingOutput dOutput{
        ITensor::slice(dJointOutput.ids, batchSlot, 1), ITensor::slice(dJointOutput.gatheredIds, batchSlot, 1)};
    dOutput.beamHypotheses = dJointOutput.beamHypotheses.slice(batchSlot, 1);
    slice(dOutput.parentIds, dJointOutput.parentIds);
    slice(dOutput.cumLogProbs, dJointOutput.cumLogProbs);
    slice(dOutput.cacheIndirection, dJointOutput.cacheIndirection);
    slice(dOutput.lengths, dJointOutput.lengths);
    slice(dOutput.finishReasons, dJointOutput.finishReasons);
    slice(dOutput.logProbs, dJointOutput.logProbs);

    dOutput.newTokens = ITensor::view(dJointOutput.newTokens);
    TLLM_CHECK(dOutput.newTokens->getShape().d[0] == 1);
    dOutput.newTokens->squeeze(0);
    dOutput.newTokens = ITensor::slice(dOutput.newTokens, batchSlot, 1);
    dOutput.logProbsTiled = dJointOutput.logProbsTiled;

    if (streaming)
    {
        // When streaming, the data in beamHypotheses must not be overwritten: the beam search kernels
        // expect ungathered data, but the gatherTree kernels write in place.
        // Gather from a copy of beamHypotheses and cumLogProbs instead.
        auto const& beamSearchBuffers = decoderState.getBeamSearchBuffers();
        tensorrt_llm::kernels::invokeCopyBeamHypotheses(dOutput.beamHypotheses, beamSearchBuffers.mOutputBeamHypotheses,
            *dOutput.cumLogProbs, *beamSearchBuffers.mCumLogProbsTmp, stream, beamSearchBuffers.mNumSMs);
        dOutput.beamHypotheses = beamSearchBuffers.mOutputBeamHypotheses;
        dOutput.cumLogProbs = beamSearchBuffers.mCumLogProbsTmp;
    }

    return {std::move(dInput), std::move(dOutput)};
}
} // namespace

// TODO call this at the end of forward if mFinished[i] changes from false to true?
CudaEvent GptDecoderBatched::finalize(decoder::DecoderState const& decoderState, SizeType32 batchSlot,
    SamplingConfig const& samplingConfig, bool streaming) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto [dInput, dOutput] = prepareGatherTree(decoderState, batchSlot, streaming, *mRuntimeStream);
    kernels::gatherTree(dOutput, dInput, samplingConfig, *mRuntimeStream);

    CudaEvent event{};
    mRuntimeStream->record(event);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return event;
}