TensorRT-LLMs/cpp/tensorrt_llm/batch_manager/generateRequestOptions.cpp
Robin Kobus 4e370a509a
refactor: Copy sequence lengths once in decoder setup (#4102)
* refactor: Copy sequence lengths once in decoder setup

* refactor: Update DecoderInputBuffers to remove duplicated buffers

- Renamed and reorganized buffer variables in decoderBuffers.h and decoderBuffers.cpp for better readability.
- Adjusted references in generateRequestOptions.cpp to align with the new buffer structure.

* refactor: Move getEmbeddingBias to anonymous namespace

* refactor: Filter context requests

* refactor: GenerateRequestOptions using more fine-grained functions

- Added a new method `createDecoderRequests` to encapsulate the logic for creating decoder requests from finished context requests.
- Updated the `operator()` method to use the new method, improving code clarity and maintainability.

* refactor: Update TRTLLMDecoder

- Updated the `generate_request_options` call.
- Updated the `make_decoding_batch_input_output` call.

* refactor: Remove const where we modify input buffers

- Changed `DecoderInputBuffers` parameters from const references to non-const references in multiple functions to allow modifications.
- Updated related function calls to ensure compatibility with the new parameter types.

* fixup! refactor: Copy sequence lengths once in decoder setup

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-05-16 22:03:55 +08:00

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/generateRequestOptions.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/batch_manager/utils/logitsThread.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <NvInferRuntimeBase.h>

using namespace tensorrt_llm::runtime;

namespace te = tensorrt_llm::executor;
namespace tr = tensorrt_llm::runtime;

namespace tensorrt_llm::batch_manager
{

using SizeType32 = GenerateRequestOptions::SizeType32;
using TensorPtr = GenerateRequestOptions::TensorPtr;

namespace
{
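/// @brief Write the batch slot and current sequence length of each context request to the host setup buffers,
/// copy both to the device and fill the decoder's sequence length tensor with a single batched kernel.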
void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffers& inputBuffers,
ITensor& sequenceLengths, SizeType32 beamWidth, runtime::BufferManager const& manager,
runtime::CudaStream const& stream)
{
auto const batchSize = contextRequests.size();
auto batchSlotsView = tr::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize);
auto fillValuesView = tr::ITensor::slice(inputBuffers.fillValues, 0, batchSize);
auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlotsView);
auto fillValuesRange = tr::BufferRange<SizeType32>(*fillValuesView);
// fill buffers on host
SizeType32 batchIdx{0};
for (auto const& llmReq : contextRequests)
{
auto const currentSequenceLen = llmReq->mPromptLen + llmReq->getMaxNumGeneratedTokens();
// Get position of the current sequence in the decoder
auto const seqSlot = llmReq->mSeqSlot.value();
batchSlotsRange[batchIdx] = seqSlot;
fillValuesRange[batchIdx] = currentSequenceLen;
++batchIdx;
}
// copy sequence lengths
{
auto batchSlotsDeviceView = tr::ITensor::slice(inputBuffers.setupBatchSlotsDevice, 0, batchSize);
auto fillValuesViewDevice = tr::ITensor::slice(inputBuffers.fillValuesDevice, 0, batchSize);
manager.copy(*batchSlotsView, *batchSlotsDeviceView);
manager.copy(*fillValuesView, *fillValuesViewDevice);
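        // Scatter each request's sequence length into the decoder's sequenceLengths tensor, one entry per beam
        // at the request's batch slot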
tr::kernels::invokeFillBatch(sequenceLengths, *batchSlotsDeviceView, beamWidth, *fillValuesViewDevice, stream);
}
}

/// @brief Retrieve the embedding bias from the request. If the tensor's data type does not match the logits
/// type, a converted copy of the tensor is returned instead.
[[nodiscard]] TensorPtr getEmbeddingBias(nvinfer1::DataType logitsType, TensorPtr const& tensor)
{
    // If the embedding bias type is the same as the logits type, we can return the tensor right away
if (tensor->getDataType() == logitsType)
{
return tensor;
}
// Support FP32 input for FP16 embedding bias (in the case of FP8 models)
if (tensor->getDataType() == nvinfer1::DataType::kFLOAT && logitsType == nvinfer1::DataType::kHALF)
{
// Do a deep copy of the tensor to the expected type
TLLM_LOG_WARNING(
"Embedding bias data type must be same as model logits type, will copy the tensor from float to half");
TLLM_CHECK_WITH_INFO(
tensor->getMemoryType() != MemoryType::kGPU, "Embedding bias tensor needs to be in CPU memory for casting");
auto const shape = tensor->getShape();
TLLM_CHECK(shape.nbDims == 2); // [1, vocabSizePadded]
TLLM_CHECK(shape.d[0] == 1);
auto newTensor = tensorrt_llm::runtime::BufferManager::pinnedPool(shape, logitsType);
auto const tensorRange = BufferRange<float>(*tensor);
auto newTensorRange = BufferRange<half>(*newTensor);
std::transform(tensorRange.begin(), tensorRange.end(), newTensorRange.begin(),
[](float value) -> half { return static_cast<half>(value); });
return newTensor;
}
TLLM_THROW("Embedding bias data type must be same as model logits type.");
}
} // namespace
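
/// @brief Set up decoding for the context requests that finished their last context chunk: copy their sequence
/// lengths into the decoder state and create the corresponding decoder requests and sampling configs.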
std::tuple<TensorPtr, std::vector<decoder_batch::Request>, std::vector<SamplingConfig>>
GenerateRequestOptions::operator()(tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
te::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, BufferManager const& bufferManager,
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
SizeType32 beamWidth, runtime::CudaStream const& stream, OptionalRef<RuntimeBuffers const> buffers) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(GenerateRequestOptions);
RequestVector finishedContextRequests;
std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
[](auto const& llmReq) { return llmReq->isLastContextChunk(); });
copySequenceLengths(
finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, bufferManager, stream);
auto decoderRequests = createDecoderRequests(finishedContextRequests, inputBuffers.inputsIds, decodingConfig,
bufferManager, logitsType, modelConfig, worldConfig, buffers);
auto const batchSize = finishedContextRequests.size();
std::vector<SamplingConfig> samplingConfigs;
samplingConfigs.reserve(batchSize);
for (auto const& llmReq : finishedContextRequests)
{
samplingConfigs.push_back(llmReq->mSamplingConfig);
}
TensorPtr batchSlotsView = runtime::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return {std::move(batchSlotsView), std::move(decoderRequests), std::move(samplingConfigs)};
}
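
/// @brief Pack the prompt tokens of all finished context requests into the shared inputIds buffer and create one
/// decoder_batch::Request per request, attaching per-request options (draft tokens, Medusa/Lookahead/Eagle
/// configs, embedding bias, bad and stop word lists) according to the model's speculative decoding mode.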
[[nodiscard]] std::vector<runtime::decoder_batch::Request> GenerateRequestOptions::createDecoderRequests(
RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
executor::DecodingConfig const& decodingConfig, BufferManager const& bufferManager, nvinfer1::DataType logitsType,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
OptionalRef<RuntimeBuffers const> buffers) const
{
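    // Sum up the prompt lengths to size the shared buffer that holds all prompt tokens back-to-back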
unsigned decoderInputSize{0};
for (auto const& llmReq : finishedContextRequests)
{
auto const& reqTokens = llmReq->getTokens(0);
decoderInputSize += reqTokens.size();
}
inputIds->resize(decoderInputSize);
std::vector<decoder_batch::Request> decoderRequests;
decoderRequests.reserve(finishedContextRequests.size());
SizeType32 inputOffset{0};
for (auto const& llmReq : finishedContextRequests)
{
auto const promptLen = llmReq->getPromptLen();
auto const& reqTokens = llmReq->getTokens(0);
TLLM_CHECK(reqTokens.size() == static_cast<decltype(reqTokens.size())>(promptLen));
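        // Copy this request's prompt tokens into its slice of the shared input ids buffer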
TensorPtr inputView = ITensor::slice(inputIds, inputOffset, promptLen);
bufferManager.copy(reqTokens.data(), *inputView);
auto decoderRequest = decoder_batch::Request{inputView, promptLen, llmReq->mMaxNewTokens, llmReq->mEndId};
llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs;
if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
{
if (llmReq->hasDraftTokens())
{
auto const& draftTokens = llmReq->getDraftTokens();
decoderRequest.draftTokens = bufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
auto const& draftLogits = llmReq->getDraftLogits();
if (draftLogits.has_value())
{
decoderRequest.draftLogits
= retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), bufferManager);
}
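            // The decoder can accept all draft tokens plus one newly generated token per engine step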
decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1;
}
else
{
decoderRequest.generatedTokensPerEngineStep = 1;
}
}
else if (!modelConfig.getSpeculativeDecodingMode().isNone())
{
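        // Other speculative decoding modes produce up to maxDecodingTokens tokens per engine step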
decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
}
if (modelConfig.getSpeculativeDecodingMode().isMedusa())
{
TLLM_CHECK(buffers);
llmReq->mSamplingConfig.topKMedusaHeads = {buffers->medusaBuffers->mTopKs};
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
decoderRequest.medusaPaths = ITensor::slice(buffers->medusaBuffers->medusaPathsDevice, 0, 1);
decoderRequest.medusaTreeIds = ITensor::slice(buffers->medusaBuffers->medusaTreeIdsDevice, 0, 1);
}
else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
decoderRequest.lookaheadRuntimeConfig = llmReq->getLookaheadConfig()
? llmReq->getLookaheadConfig()
: decodingConfig.getLookaheadDecodingConfig();
}
else if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
{
// Only Explicit draft tokens model needs dtype to WAR the lack of bf16 decoder.
decoderRequest.dtype = modelConfig.getDataType();
}
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
{
decoderRequest.eagleConfig
= llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
}
if (llmReq->getEmbeddingBias().has_value())
{
decoderRequest.embeddingBias = getEmbeddingBias(logitsType, llmReq->getEmbeddingBias().value());
}
if (llmReq->getBadWordsList().has_value())
{
// Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects
decoderRequest.badWordsList = bufferManager.copyFrom(*llmReq->getBadWordsList().value(), MemoryType::kGPU);
decoderRequest.badWordsList->squeeze(0);
}
if (llmReq->getStopWordsList().has_value())
{
decoderRequest.stopWordsList
= bufferManager.copyFrom(*llmReq->getStopWordsList().value(), MemoryType::kGPU);
decoderRequest.stopWordsList->squeeze(0);
}
decoderRequests.push_back(decoderRequest);
inputOffset += promptLen;
}
return decoderRequests;
}
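
/// @brief Retrieve the draft logits for externally provided draft tokens. Without the fast logits path, the
/// tensor is simply copied to pinned memory. With fast logits, the leader in orchestrator mode receives the
/// logits from the draft model and broadcasts them to the other tensor parallel ranks.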
std::shared_ptr<runtime::ITensor> GenerateRequestOptions::retrieveDraftLogits(tr::ModelConfig const& modelConfig,
tr::WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
BufferManager const& bufferManager) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (!mSpeculativeDecodingFastLogits)
{
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL);
}
if (mIsLeaderInOrchMode)
{
te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo;
std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo));
auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value();
// Broadcast to other ranks if needed
if (worldConfig.isTensorParallel())
{
auto const& commSession = COMM_SESSION;
auto shape = logits->getShape();
commSession.bcastValue(shape.d[0], 0);
commSession.bcastValue(shape.d[1], 0);
commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return logits;
}
// Get logits from leader rank
auto const& commSession = COMM_SESSION;
int64_t dims[2];
commSession.bcastValue(dims[0], 0);
commSession.bcastValue(dims[1], 0);
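    // Allocate a pinned buffer with the broadcast shape and receive the logits data from the leader rank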
auto const logitsDtype = modelConfig.getLogitsDtype();
auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return logits;
}
} // namespace tensorrt_llm::batch_manager