TensorRT-LLMs/cpp/tensorrt_llm/batch_manager/decoderBuffers.cpp
Robin Kobus f58d4698c8
chore: Clean up cpp runtime (#3505)
* chore: Remove unused tensors from DecoderBuffers

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fix: Remove unused argument from readme

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* chore: remove unused tensor

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Remove unnecessary newOutputTokens

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Remove unnecessary event in getDecoderSlotHostOutputs

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-04-14 18:00:03 +08:00

365 lines
16 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::batch_manager
{
DecoderInputBuffers::DecoderInputBuffers(
    SizeType32 maxBatchSize, SizeType32 maxDecoderSteps, BufferManager const& manager)
{
    auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
    auto const nvSizeType = TRTDataType<SizeType32>::value;

    // Host-pinned staging buffers for decoder setup and forward inputs.
    setupBatchSlots = BufferManager::pinnedPool(maxBatchSizeShape, nvSizeType);
    // Starts empty; expected to be resized to the actual number of input tokens per step.
    inputsIds = BufferManager::pinnedPool(ITensor::makeShape({0}), TRTDataType<TokenIdType>::value);
    // Host/device pairs: the pinned buffer is filled on host, the gpu buffer receives the copy.
    forwardBatchSlotsRequestOrder = BufferManager::pinnedPool(maxBatchSizeShape, nvSizeType);
    forwardBatchSlotsRequestOrderDevice = manager.gpu(maxBatchSizeShape, nvSizeType);
    fillValues = BufferManager::pinnedPool(maxBatchSizeShape, nvSizeType);
    fillValuesDevice = manager.gpu(maxBatchSizeShape, nvSizeType);

    // One batch-slots buffer per decoder step.
    forwardBatchSlots.reserve(maxDecoderSteps);
    for (SizeType32 i = 0; i < maxDecoderSteps; ++i)
    {
        forwardBatchSlots.emplace_back(BufferManager::pinnedPool(maxBatchSizeShape, nvSizeType));
    }
}
DecoderBuffers::DecoderBuffers(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
    SizeType32 maxSeqLen, SizeType32 maxTokensPerStep, BufferManager const& manager, ModelConfig const& modelConfig,
    WorldConfig const& worldConfig)
{
    // Logits are only resized on the last pipeline-parallel rank.
    if (worldConfig.isLastPipelineParallelRank())
    {
        logits.resize(maxNumSequences);
    }

    auto constexpr TRTTokenIdType = runtime::TRTDataType<runtime::TokenIdType>::value;
    auto const beamShape = ITensor::makeShape({maxNumSequences, maxBeamWidth});
    auto const cacheIndirectionShape = ITensor::makeShape({maxNumSequences, maxBeamWidth, maxAttentionWindow});

    // Device-side cache indirection tensors (input/output pair).
    cacheIndirectionInput = manager.gpu(cacheIndirectionShape, nvinfer1::DataType::kINT32);
    cacheIndirectionOutput = manager.gpu(cacheIndirectionShape, nvinfer1::DataType::kINT32);

    // Host-pinned buffers holding per-sequence decoder outputs.
    sequenceLengthsHost = BufferManager::pinned(beamShape, nvinfer1::DataType::kINT32);
    finishedSumHost = BufferManager::pinned(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
    newOutputTokensHost
        = BufferManager::pinned(ITensor::makeShape({maxTokensPerStep, maxNumSequences, maxBeamWidth}), TRTTokenIdType);
    cumLogProbsHost = BufferManager::pinned(beamShape, nvinfer1::DataType::kFLOAT);
    logProbsHost = BufferManager::pinned(
        ITensor::makeShape({maxNumSequences, maxBeamWidth, maxSeqLen}), nvinfer1::DataType::kFLOAT);
    finishReasonsHost = BufferManager::pinned(beamShape, nvinfer1::DataType::kUINT8);

    auto const specDecodingMode = modelConfig.getSpeculativeDecodingMode();

    // Draft buffers are required by any speculative mode that rewinds the KV cache,
    // carries draft logits, or predicts draft tokens.
    if (specDecodingMode.needsKVCacheRewind() || specDecodingMode.hasDraftLogits()
        || specDecodingMode.predictsDraftTokens())
    {
        draftBuffers.create(maxNumSequences, maxTokensPerStep, manager, modelConfig);
    }

    // Mode-specific auxiliary buffers (mutually exclusive variants).
    if (specDecodingMode.isExplicitDraftTokens())
    {
        explicitDraftTokensBuffers.create(maxNumSequences, manager, modelConfig, worldConfig);
    }
    else if (specDecodingMode.isLookaheadDecoding())
    {
        lookaheadBuffers.emplace(maxNumSequences, maxTokensPerStep, manager);
    }
    else if (specDecodingMode.isEagle())
    {
        eagleBuffers.create(maxNumSequences, manager, modelConfig, worldConfig);
    }
}
void DecoderBuffers::DraftBuffers::create(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep,
    BufferManager const& manager, ModelConfig const& modelConfig)
{
    auto const speculativeDecodingMode = modelConfig.getSpeculativeDecodingMode();
    auto constexpr TRTTokenIdType = runtime::TRTDataType<runtime::TokenIdType>::value;
    auto const numSequencesShape = ITensor::makeShape({maxNumSequences});

    if (speculativeDecodingMode.predictsDraftTokens())
    {
        // Up to maxTokensPerStep - 1 draft tokens per sequence.
        nextDraftTokensHost
            = BufferManager::pinned(ITensor::makeShape({maxNumSequences, maxTokensPerStep - 1}), TRTTokenIdType);
        if (speculativeDecodingMode.variableDraftLength())
        {
            // Per-sequence draft lengths for the next and previous step.
            nextDraftTokensLengthsHost = BufferManager::pinned(numSequencesShape, nvinfer1::DataType::kINT32);
            prevDraftTokensLengthsHost = BufferManager::pinned(numSequencesShape, nvinfer1::DataType::kINT32);
        }
    }

    if (speculativeDecodingMode.isMedusa())
    {
        // One logits slot per draft path position ("head") per sequence.
        auto const maxDraftPathLen = modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen();
        predictedDraftLogits.resize(maxNumSequences);
        for (auto& medusaLogitsHead : predictedDraftLogits)
        {
            medusaLogitsHead.resize(maxDraftPathLen);
        }
    }

    if (speculativeDecodingMode.needsKVCacheRewind())
    {
        auto const maxDraftPathLen = modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen();
        // maxNumSequences + 1 entries for the cumulative-sum buffer.
        acceptedLengthsCumSumDevice
            = manager.gpu(ITensor::makeShape({maxNumSequences + 1}), nvinfer1::DataType::kINT32);
        acceptedPackedPathsDevice
            = manager.gpu(ITensor::makeShape({maxNumSequences, maxDraftPathLen}), nvinfer1::DataType::kINT32);
    }
}
void DecoderBuffers::enableLookaheadDecoding(SizeType32 maxNumSequences, SizeType32 maxTokensPerStep)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // Lookahead mode uses up to maxTokensPerStep tokens per step with beam width 1.
    auto const lookaheadShape = ITensor::makeShape({maxTokensPerStep, maxNumSequences, 1});
    newOutputTokensHost->reshape(lookaheadShape);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void DecoderBuffers::disableLookaheadDecoding(SizeType32 maxNumSequences)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // Back to a single token per step with beam width 1.
    auto const defaultShape = ITensor::makeShape({1, maxNumSequences, 1});
    newOutputTokensHost->reshape(defaultShape);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
DecoderStepAsyncSend::DecoderStepAsyncSend(std::shared_ptr<mpi::MpiComm> const& commSession,
    BufferPtr const& newOutputTokensHost, BufferPtr const& finished, BufferPtr const& sequenceLengthsHost,
    BufferPtr const& cumLogProbsHost, BufferPtr const& logProbsHost, BufferPtr const& cacheIndirectionOutput,
    BufferPtr const& acceptedCumSum, BufferPtr const& packedPaths, BufferPtr const& finishReasonsHost, int const peer)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("start send outputs of DecoderBuffers to rank %d", peer);

    // Issues an async send only when the buffer is present; otherwise the
    // request handle stays null and is skipped at wait time.
    auto sendIf = [&commSession, peer](BufferPtr const& buffer, int const tagOffset)
    { return buffer ? commSession->sendAsync(*buffer, peer, kMpiTagOffset + tagOffset) : nullptr; };

    mRequest1 = commSession->sendAsync(*newOutputTokensHost, peer, kMpiTagOffset);
    mRequest2 = commSession->sendAsync(*finished, peer, kMpiTagOffset + 1);
    mRequest3 = commSession->sendAsync(*sequenceLengthsHost, peer, kMpiTagOffset + 2);
    mRequest4 = sendIf(cumLogProbsHost, 3);
    mRequest5 = sendIf(logProbsHost, 4);
    mRequest6 = sendIf(cacheIndirectionOutput, 5);
    mRequest7 = sendIf(acceptedCumSum, 6);
    mRequest8 = sendIf(packedPaths, 7);
    mRequest9 = commSession->sendAsync(*finishReasonsHost, peer, kMpiTagOffset + 8);

    // The reserved tag range must cover all nine sends.
    static_assert(kMpiTagUpperBound >= kMpiTagOffset + 9);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
DecoderStepAsyncSend::~DecoderStepAsyncSend()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Blocks until all outstanding sends complete. Requests 1-3 and 9 are
    // always issued; requests 4-8 are optional and may be null.
    auto waitIfSet = [](auto const& request)
    {
        if (request)
        {
            request->wait();
        }
    };

    mRequest1->wait();
    mRequest2->wait();
    mRequest3->wait();
    waitIfSet(mRequest4);
    waitIfSet(mRequest5);
    waitIfSet(mRequest6);
    waitIfSet(mRequest7);
    waitIfSet(mRequest8);
    mRequest9->wait();

    TLLM_LOG_DEBUG("end send outputs of DecoderBuffers");
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
DecoderSlotAsyncSend::DecoderSlotAsyncSend(std::shared_ptr<mpi::MpiComm> const& commSession,
    TensorPtr const& outputIdsView, TensorPtr const& sequenceLengthView, TensorPtr const& cumLogProbsView,
    TensorPtr const& logProbsView, bool const returnLogProbs, int const peer)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("start send outputs of SlotDecoderBuffers to rank %d", peer);

    // Output ids and sequence lengths are always sent; log-prob buffers only on request.
    mRequest1 = commSession->sendAsync(*outputIdsView, peer, kMpiTagOffset);
    mRequest2 = commSession->sendAsync(*sequenceLengthView, peer, kMpiTagOffset + 1);
    if (returnLogProbs)
    {
        mRequest3 = commSession->sendAsync(*cumLogProbsView, peer, kMpiTagOffset + 2);
        mRequest4 = commSession->sendAsync(*logProbsView, peer, kMpiTagOffset + 3);
    }
    else
    {
        mRequest3 = nullptr;
        mRequest4 = nullptr;
    }

    // The reserved tag range must cover all four sends.
    static_assert(kMpiTagUpperBound >= kMpiTagOffset + 4);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
DecoderSlotAsyncSend::~DecoderSlotAsyncSend()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Wait for the mandatory sends, then the optional log-prob sends if issued.
    mRequest1->wait();
    mRequest2->wait();
    if (mRequest3 != nullptr)
    {
        mRequest3->wait();
    }
    if (mRequest4 != nullptr)
    {
        mRequest4->wait();
    }

    TLLM_LOG_DEBUG("end send outputs of SlotDecoderBuffers");
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
std::unique_ptr<DecoderStepAsyncSend> DecoderBuffers::asyncSend(std::shared_ptr<mpi::MpiComm> const& commSession,
    bool const returnLogProbs, SizeType32 const maxBeamWidth, bool const useMedusa, int const peer)
{
    // Optional buffers are passed as nullptr so the send handle skips them:
    // log probs only when requested, cache indirection only for beam search,
    // accepted paths only for Medusa.
    return std::make_unique<DecoderStepAsyncSend>(commSession, newOutputTokensHost, finishedSumHost,
        sequenceLengthsHost, returnLogProbs ? cumLogProbsHost : nullptr, returnLogProbs ? logProbsHost : nullptr,
        maxBeamWidth > 1 ? cacheIndirectionOutput : nullptr,
        useMedusa ? draftBuffers.acceptedLengthsCumSumDevice : nullptr,
        useMedusa ? draftBuffers.acceptedPackedPathsDevice : nullptr, finishReasonsHost, peer);
}
void DecoderBuffers::recv(std::shared_ptr<mpi::MpiComm> const& commSession, bool const returnLogProbs,
    SizeType32 const maxBeamWidth, bool const useMedusa, int const peer)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("start recv outputs of DecoderBuffers from rank %d", peer);

    // Tag offsets must mirror the order used by DecoderStepAsyncSend on the sender.
    auto constexpr tagBase = DecoderStepAsyncSend::kMpiTagOffset;
    commSession->recv(*newOutputTokensHost, peer, tagBase);
    commSession->recv(*finishedSumHost, peer, tagBase + 1);
    commSession->recv(*sequenceLengthsHost, peer, tagBase + 2);
    if (returnLogProbs)
    {
        commSession->recv(*cumLogProbsHost, peer, tagBase + 3);
        commSession->recv(*logProbsHost, peer, tagBase + 4);
    }
    if (maxBeamWidth > 1)
    {
        commSession->recv(*cacheIndirectionOutput, peer, tagBase + 5);
    }
    if (useMedusa)
    {
        commSession->recv(*draftBuffers.acceptedLengthsCumSumDevice, peer, tagBase + 6);
        commSession->recv(*draftBuffers.acceptedPackedPathsDevice, peer, tagBase + 7);
    }
    commSession->recv(*finishReasonsHost, peer, tagBase + 8);

    TLLM_LOG_DEBUG("end recv outputs of DecoderBuffers from rank %d", peer);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void DecoderBuffers::bcast(std::shared_ptr<mpi::MpiComm> const& commSession, bool const returnLogProbs,
    SizeType32 const maxBeamWidth, bool const useMedusa, int const root)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("start bcast outputs of DecoderBuffers from rank %d", root);

    // Kick off all broadcasts first, then wait; optional buffers produce null handles.
    auto request1 = commSession->bcastAsync(*newOutputTokensHost, root);
    auto request2 = commSession->bcastAsync(*finishedSumHost, root);
    auto request3 = commSession->bcastAsync(*sequenceLengthsHost, root);
    auto request4 = returnLogProbs ? commSession->bcastAsync(*cumLogProbsHost, root) : nullptr;
    auto request5 = returnLogProbs ? commSession->bcastAsync(*logProbsHost, root) : nullptr;
    auto request6 = maxBeamWidth > 1 ? commSession->bcastAsync(*cacheIndirectionOutput, root) : nullptr;
    auto request7 = useMedusa ? commSession->bcastAsync(*draftBuffers.acceptedLengthsCumSumDevice, root) : nullptr;
    auto request8 = useMedusa ? commSession->bcastAsync(*draftBuffers.acceptedPackedPathsDevice, root) : nullptr;
    auto request9 = commSession->bcastAsync(*finishReasonsHost, root);

    auto waitIfSet = [](auto const& request)
    {
        if (request)
        {
            request->wait();
        }
    };

    request1->wait();
    request2->wait();
    request3->wait();
    waitIfSet(request4);
    waitIfSet(request5);
    waitIfSet(request6);
    waitIfSet(request7);
    waitIfSet(request8);
    request9->wait();

    TLLM_LOG_DEBUG("end bcast outputs of DecoderBuffers from rank %d", root);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
std::unique_ptr<DecoderSlotAsyncSend> SlotDecoderBuffers::asyncSend(std::shared_ptr<mpi::MpiComm> const& commSession,
    TensorPtr const& outputIdsView, TensorPtr const& sequenceLengthView, TensorPtr const& cumLogProbsView,
    TensorPtr const& logProbsView, bool const returnLogProbs, int const peer)
{
    // Wrap the async sends in an RAII handle; its destructor waits for completion.
    return std::make_unique<DecoderSlotAsyncSend>(
        commSession, outputIdsView, sequenceLengthView, cumLogProbsView, logProbsView, returnLogProbs, peer);
}
std::unique_ptr<DecoderSlotAsyncSend> SlotDecoderBuffers::asyncSend(
    std::shared_ptr<mpi::MpiComm> const& commSession, bool const returnLogProbs, int const peer) const
{
    // Delegate to the overload taking explicit views, passing this buffer set's tensors.
    auto sendHandle = asyncSend(commSession, outputIds, sequenceLengths, cumLogProbs, logProbs, returnLogProbs, peer);
    return sendHandle;
}
void SlotDecoderBuffers::recv(
    std::shared_ptr<mpi::MpiComm> const& commSession, bool const returnLogProbs, int const peer)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("start recv outputs of SlotDecoderBuffers from rank %d", peer);

    // Tag offsets must mirror the order used by DecoderSlotAsyncSend on the sender.
    auto constexpr tagBase = DecoderSlotAsyncSend::kMpiTagOffset;
    commSession->recv(*outputIds, peer, tagBase);
    commSession->recv(*sequenceLengths, peer, tagBase + 1);
    if (returnLogProbs)
    {
        commSession->recv(*cumLogProbs, peer, tagBase + 2);
        commSession->recv(*logProbs, peer, tagBase + 3);
    }

    TLLM_LOG_DEBUG("end recv outputs of SlotDecoderBuffers from rank %d", peer);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
SlotDecoderBuffers::SlotDecoderBuffers(SizeType32 maxBeamWidth, SizeType32 maxSeqLen, BufferManager const& manager)
{
    // Hoist the two shapes shared by all buffers (DRY; the original rebuilt them per buffer).
    auto const beamShape = ITensor::makeShape({maxBeamWidth});
    auto const beamSeqShape = ITensor::makeShape({maxBeamWidth, maxSeqLen});

    // Each device buffer is paired with a host-pinned buffer of identical shape/dtype.
    outputIds = manager.gpu(beamSeqShape, nvinfer1::DataType::kINT32);
    outputIdsHost = BufferManager::pinned(beamSeqShape, nvinfer1::DataType::kINT32);
    sequenceLengths = manager.gpu(beamShape, nvinfer1::DataType::kINT32);
    sequenceLengthsHost = BufferManager::pinned(beamShape, nvinfer1::DataType::kINT32);
    cumLogProbs = manager.gpu(beamShape, nvinfer1::DataType::kFLOAT);
    cumLogProbsHost = BufferManager::pinned(beamShape, nvinfer1::DataType::kFLOAT);
    logProbs = manager.gpu(beamSeqShape, nvinfer1::DataType::kFLOAT);
    logProbsHost = BufferManager::pinned(beamSeqShape, nvinfer1::DataType::kFLOAT);
}
} // namespace tensorrt_llm::batch_manager