TensorRT-LLMs/cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp
Robin Kobus 1bd84c6d8c
feat: Allow individual gatherContext for each additional output (#3374)
* refactor: Update ExecutorConfig to use AdditionalModelOutput type

- Changed function signatures and member variables across multiple files, replacing std::optional<std::vector<std::string>> with std::optional<std::vector<executor::AdditionalModelOutput>> to include a gatherContext flag for each additional output.
- Updated related serialization and deserialization methods to accommodate the new type.
- Adjusted tests to reflect the changes in the output handling structure.

This refactor enhances the flexibility and maintainability of the output configuration in the executor and batch manager components.
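
A minimal sketch of a resulting call site, assuming AdditionalModelOutput can be brace-initialized from a name and its gatherContext flag; the tensor names are placeholders:

    std::vector<executor::AdditionalModelOutput> additionalOutputs{
        {"context_logits_alias", /*gatherContext=*/true},  // gather over all context tokens
        {"last_hidden_state", /*gatherContext=*/false}};   // gather only at the last token
    // Passed wherever std::optional<std::vector<executor::AdditionalModelOutput>> is accepted,
    // e.g. the RuntimeBuffers constructor in this file.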

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Remove equality operator from TrtGptModelOptionalParams

- Deleted the operator== implementation from TrtGptModelOptionalParams to simplify the class.
- Updated the pybind11 bindings to remove the exposure of the equality operator to Python.

This change streamlines the class definition and reduces unnecessary complexity in the bindings.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Enhance copyAdditionalOutputs to utilize AdditionalModelOutput

- Updated the copyAdditionalOutputs function to accept a vector of AdditionalModelOutput, allowing for the inclusion of the gatherContext flag.
- Adjusted the logic to handle context and non-context outputs separately, improving the output handling mechanism.
- Modified related unit tests to incorporate the new gatherContext parameter, ensuring comprehensive testing of the updated functionality.

This refactor improves the flexibility and clarity of output management in the batch processing workflow.
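
The gist of the split, as a rough sketch (the actual implementation in copyAdditionalOutputs also handles per-request slicing and is not reproduced here):

    for (auto const& output : additionalModelOutputs)
    {
        if (output.gatherContext)
        {
            // copy the slice covering every context token of the request
        }
        else
        {
            // copy only the generation (last-token) slice
        }
    }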

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Introduce findOutputTensor utility function for output tensor retrieval

- Added a new utility function, findOutputTensor, to encapsulate the logic for finding output tensors and checking their validity.
- Refactored copyAdditionalOutputs to utilize findOutputTensor, reducing code duplication and improving clarity.
- Enhanced error checking for additional context and generation output tensors.

This change streamlines the output tensor retrieval process, enhancing maintainability and readability in the batch processing workflow.
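
A hypothetical shape for such a helper, assuming a TensorMap keyed by tensor name (the real signature and validity checks live in the diff):

    auto findOutputTensor(std::string const& name, RuntimeBuffers::TensorMap const& outputMap)
    {
        auto const it = outputMap.find(name);
        TLLM_CHECK_WITH_INFO(it != outputMap.end(), "Additional output tensor '%s' not found.", name.c_str());
        return it;
    }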

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Check final indices of additional output tensors and update tests

- Added checks to verify the final indices of additional output tensors for context and generation outputs.
- Updated unit tests to verify the changes.
  - Added a lastTokenIds input tensor to the test engines.
  - The logits output now depends on the gatherContextLogits parameter.
- Removed the gatherContextOutputs parameter from the validate method in LlmRequest.
  - Context outputs do not depend on the computeContextLogits parameter.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fixup! refactor: Check final indices of additional output tensors and update tests

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fixup! refactor: Update ExecutorConfig to use AdditionalModelOutput type

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fixup! refactor: Remove equality operator from TrtGptModelOptionalParams

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* docs: Update executor.md

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* chore: Clean up includes

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-04-12 17:00:36 +08:00

1026 lines
44 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/batch_manager/encoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/loraBuffers.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/promptTuningBuffers.h"
#include "tensorrt_llm/batch_manager/rnnStateBuffers.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/batch_manager/transformerBuffers.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/stlUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <vector>
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::batch_manager
{
RuntimeBuffers::RuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen,
TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, bool gatherGenerationLogits, std::optional<SizeType32> maxNumTokens,
std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
create(maxBatchSize, maxBeamWidth, maxAttentionWindowVec, maxAttentionWindow, sinkTokenLen, runtime, modelConfig,
worldConfig, decodingConfig, gatherGenerationLogits, additionalModelOutputs);
// pre-allocate
setMaxBufferSizes(maxBatchSize, maxBeamWidth, modelConfig, maxNumTokens);
reshape(runtime, modelConfig, worldConfig, gatherGenerationLogits);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
RuntimeBuffers::~RuntimeBuffers() = default;
void RuntimeBuffers::reshape(TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
bool gatherGenerationLogits)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(runtimeBuffersReshape);
if (worldConfig.isLastPipelineParallelRank())
{
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
if (modelConfig.computeContextLogits() && (numContextRequests > 0))
{
// The logits buffer only needs to be re-allocated to [numContextTokens + numGenSequences, vocabSizePadded]
// when context logits must be returned and there are new requests that will execute the context phase.
auto const& engine = runtime.getEngine();
auto const& manager = runtime.getBufferManager();
auto const logitsType = engine.getTensorDataType(kLogitsTensorName);
logits = manager.gpu(ITensor::makeShape({numContextTokens + numGenSequences, vocabSizePadded}), logitsType);
}
else if (gatherGenerationLogits && modelConfig.getSpeculativeDecodingMode().isNone())
{
// If generation logits need to be returned, re-point the logits buffer into the cache to avoid overwrites,
// so that GenerationLogitsCache::kCACHE_LENGTH steps' logits can be written back together.
// logits shape: [1, maxBatchSize * maxBeamWidth, vocabSizePadded],
// which is large enough to cover both numContextRequests and numGenSequences
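// e.g. with an illustrative kCACHE_LENGTH of 8, the offset cycles 0, 1, ..., 7, 0, ... so eight consecutive
// steps write into distinct cache slices before any slice is reused.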
logits = ITensor::slice(generationLogitsCache.logits, generationLogitsCache.offset, 1);
generationLogitsCache.offset = (generationLogitsCache.offset + 1) % GenerationLogitsCache::kCACHE_LENGTH;
logits->squeeze(0);
}
else
{
logits->reshape(ITensor::makeShape({numLogits, vocabSizePadded}));
}
}
auto const numSequences = getNumSequences();
auto const numSequencesShape = ITensor::makeShape({numSequences});
requestTypes->reshape(numSequencesShape);
contextLengthsHost->reshape(numSequencesShape);
contextLengthsDevice->reshape(numSequencesShape);
sequenceLengthsHost->reshape(numSequencesShape);
sequenceLengthsDevice->reshape(numSequencesShape);
auto const numLogitsShape = ITensor::makeShape({numLogits});
lastTokenIdsHost->reshape(numLogitsShape);
lastTokenIdsDevice->reshape(numLogitsShape);
logitsIdsHost->reshape(numLogitsShape);
logitsIdsDevice->reshape(numLogitsShape);
if (transformerBuffers)
{
transformerBuffers->reshape(numSequences, numContextTokens + numGenTokens);
}
if (rnnStateBuffers)
{
rnnStateBuffers->reshape(numSequences);
}
if (modelConfig.useCrossAttention())
{
encoderBuffers->reshape();
}
if (modelConfig.useLoraPlugin())
{
loraBuffers->reshape(numSequences);
}
if (medusaBuffers)
{
medusaBuffers->reshape(
numContextRequests, numGenRequests, modelConfig.getSpeculativeDecodingModulePtr()->getMaxDecodingTokens());
}
if (lookaheadBuffers && modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
lookaheadBuffers->reshape(
numContextRequests, numGenRequests, modelConfig.getSpeculativeDecodingModulePtr()->getMaxDecodingTokens());
}
if (explicitDraftTokensBuffers)
{
explicitDraftTokensBuffers->reshape(numContextRequests, numGenRequests, modelConfig);
}
if (eagleBuffers)
{
eagleBuffers->reshape(numContextRequests, numGenRequests, modelConfig);
}
auto const numRequests = getNumRequests();
auto const numRequestsShape = ITensor::makeShape({numRequests});
seqSlots->reshape(numRequestsShape);
seqSlotsDevice->reshape(numRequestsShape);
sortedSeqSlots->reshape(numRequestsShape);
seqSlotRemappingHost->reshape(numRequestsShape);
seqSlotRemappingDevice->reshape(numRequestsShape);
auto const numTokens = getNumTokens();
inputsIds->reshape(ITensor::makeShape({numTokens}));
if (modelConfig.useMrope())
{
auto const mropeRotaryCosSinSize = modelConfig.getMaxPositionEmbeddings() * modelConfig.getRotaryEmbeddingDim();
mropeRotaryCosSin->reshape(ITensor::makeShape({numSequences, mropeRotaryCosSinSize}));
mropePositionDeltas->reshape(ITensor::makeShape({numSequences, 1}));
}
if (worldConfig.isPipelineParallel())
{
auto const hiddenSize = (!modelConfig.getPpReduceScatter() || worldConfig.isFirstPipelineParallelRank())
? modelConfig.getHiddenSize() * worldConfig.getTensorParallelism()
: modelConfig.getHiddenSize();
auto const hiddenStatesShape = ITensor::makeShape({numTokens, hiddenSize});
hiddenStates->reshape(hiddenStatesShape);
}
if (modelConfig.useLanguageAdapter())
{
languageAdapterRoutings->reshape(ITensor::makeShape({numTokens, 1}));
}
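// Reshape the tensors registered for the requested additional outputs: the engine must expose each with a
// dynamic first (token) dimension, which is fixed here to the current number of tokens.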
for (auto const& outputTensor : mAdditionalOutputTensors)
{
auto const& [name, tensor] = outputTensor;
auto const& engine = runtime.getEngine();
auto shape = engine.getTensorShape(name.c_str());
TLLM_CHECK_WITH_INFO(
shape.d[0] == -1, "First dimension of additional output tensor '%s' must be dynamic", name.c_str());
shape.d[0] = numTokens;
tensor->reshape(shape);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen,
TllmRuntime const& runtime, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, bool gatherGenerationLogits,
std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const& manager = runtime.getBufferManager();
auto const& engine = runtime.getEngine();
if (modelConfig.isTransformerBased())
{
transformerBuffers = std::make_unique<TransformerBuffers>(maxBatchSize, maxBeamWidth, maxAttentionWindowVec,
maxAttentionWindow, sinkTokenLen, runtime, modelConfig, worldConfig);
}
if (modelConfig.isRnnBased())
{
rnnStateBuffers = std::make_unique<RnnStateBuffers>(maxBatchSize, runtime);
}
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
inputsIds = manager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
mropeRotaryCosSin = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kFLOAT);
mropePositionDeltas = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
if (worldConfig.isLastPipelineParallelRank())
{
auto const logitsType = engine.getTensorDataType(batch_manager::RuntimeBuffers::kLogitsTensorName);
logits = manager.emptyTensor(MemoryType::kGPU, logitsType);
}
seqSlotRemappingHost = manager.emptyTensor(MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
seqSlotRemappingDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
// TODO: check which tensors can be allocated as pinned for max size
requestTypes = manager.emptyTensor(MemoryType::kCPU, TRTDataType<runtime::RequestType>::value);
contextLengthsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
contextLengthsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
sequenceLengthsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
sequenceLengthsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
lastTokenIdsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
lastTokenIdsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
logitsIdsHost = manager.emptyTensor(MemoryType::kCPU, nvinfer1::DataType::kINT32);
logitsIdsDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
inputsIds = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
if (worldConfig.isPipelineParallel())
{
hiddenStates = manager.emptyTensor(MemoryType::kGPU, modelConfig.getDataType());
}
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
seqSlots = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT32);
seqSlotsDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT32);
sortedSeqSlots = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT32);
cacheIndirDecoderIOBatchedCopySrcOffsets
= tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
cacheIndirDecoderIOBatchedCopyDstOffsets
= tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
cacheIndirDecoderIOBatchedCopySizes
= tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
mCacheIndirDecoderIOBatchedCopySrcOffsetsSliceDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT64);
mCacheIndirDecoderIOBatchedCopyDstOffsetsSliceDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT64);
mCacheIndirDecoderIOBatchedCopyCopySizesDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT64);
// Pre-allocate buffer for saving generation logits for models without draft tokens
if (gatherGenerationLogits
&& (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()
|| modelConfig.getSpeculativeDecodingMode().isNone())
&& worldConfig.isLastPipelineParallelRank())
{
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
auto const logitsType = engine.getTensorDataType(batch_manager::RuntimeBuffers::kLogitsTensorName);
generationLogitsCache.transposedLogits = manager.gpu(
ITensor::makeShape({maxBeamWidth, GenerationLogitsCache::kCACHE_LENGTH, vocabSizePadded}), logitsType);
generationLogitsCache.logits = manager.gpu(
ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH, maxBatchSize * maxBeamWidth, vocabSizePadded}),
logitsType);
generationLogitsCache.fragmentPointerDevice
= manager.gpu(ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
generationLogitsCache.fragmentPointerHost = tensorrt_llm::runtime::BufferManager::pinnedPool(
ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
}
if (modelConfig.useCrossAttention())
{
encoderBuffers = std::make_unique<EncoderBuffers>();
encoderBuffers->create(maxBatchSize, modelConfig, runtime);
}
if (modelConfig.usePromptTuning())
{
promptTuningBuffers = std::make_unique<PromptTuningBuffers>(maxBatchSize, manager, modelConfig, worldConfig);
}
if (modelConfig.useLoraPlugin())
{
loraBuffers = std::make_unique<LoraBuffers>(maxBatchSize, maxBeamWidth, runtime, modelConfig, worldConfig);
}
if (modelConfig.getSpeculativeDecodingMode().isMedusa())
{
medusaBuffers = std::make_unique<MedusaBuffers>(
maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig, decodingConfig, runtime);
}
else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
lookaheadBuffers.emplace(
maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig, decodingConfig, runtime);
}
else if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
{
explicitDraftTokensBuffers.emplace(maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig);
}
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
{
eagleBuffers.emplace(maxBatchSize, maxBeamWidth, manager, modelConfig, worldConfig, decodingConfig);
}
if (modelConfig.useLanguageAdapter())
{
languageAdapterRoutings = manager.emptyTensor(MemoryType::kGPU, TRTDataType<SizeType32>::value);
}
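// Allocate one empty GPU tensor per requested additional output; its final shape is set in reshape()
// once the number of tokens for the step is known.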
for (auto const& output : additionalModelOutputs.value_or(std::vector<executor::AdditionalModelOutput>{}))
{
auto const& engine = runtime.getEngine();
auto const dataType = engine.getTensorDataType(output.name.c_str());
mAdditionalOutputTensors.emplace(output.name, manager.emptyTensor(runtime::MemoryType::kGPU, dataType));
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::setMaxBufferSizes(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
runtime::ModelConfig const& modelConfig, std::optional<SizeType32> maxNumRuntimeTokens)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// `maxNumSequences` is reached when all requests are in generation
numContextRequests = 0;
numGenRequests = maxBatchSize;
numGenSequences = maxBatchSize * maxBeamWidth;
auto const maxDraftTokens = modelConfig.getMaxDecodingDraftTokens();
// Draft-Tokens and Beam-Search are mutually exclusive
numLogits = maxBatchSize * std::max(1 + maxDraftTokens, maxBeamWidth);
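// e.g. maxBatchSize = 8, maxDraftTokens = 3, maxBeamWidth = 1 (illustrative numbers) gives
// numLogits = 8 * max(1 + 3, 1) = 32.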
auto const maxNumModelTokens = modelConfig.getMaxNumTokens();
auto const maxNumContextTokens = maxBatchSize * modelConfig.getMaxInputLen();
auto const maxNumGenTokens = numLogits;
// For pre-allocation
numContextTokens = 0; // Set in `setBufferSizes` rather than here for `computeContextLogits`
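// Prefer the runtime token limit if given, then the engine's max-num-tokens setting, then the analytic bound.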
numGenTokens
= maxNumRuntimeTokens.value_or(maxNumModelTokens.value_or(std::max(maxNumContextTokens, maxNumGenTokens)));
if (modelConfig.useCrossAttention())
{
encoderBuffers->setMaxBufferSizes(maxBatchSize, modelConfig);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::setBufferSizes(RequestVector const& contextRequests, RequestVector const& genRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(runtimeBuffersSetBufferSizes);
// set context sizes
numContextRequests = static_cast<SizeType32>(contextRequests.size());
auto numContextLogits = numContextRequests;
numContextTokens = 0;
maxContextLength = 0;
for (auto const& llmReq : contextRequests)
{
auto const draftLength = llmReq->isLastContextChunk() ? llmReq->getNumDraftTokens() : 0;
numContextLogits += draftLength;
auto const contextChunkSize = llmReq->getContextChunkSize();
numContextTokens += contextChunkSize + draftLength;
if (maxContextLength < llmReq->mPromptLen)
{
maxContextLength = llmReq->mPromptLen;
}
}
// set generation sizes
numGenRequests = static_cast<SizeType32>(genRequests.size());
numGenSequences = 0;
numGenTokens = 0;
for (auto const& llmReq : genRequests)
{
auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth;
numGenSequences += reqBeamWidth;
auto const draftLen = llmReq->getNumDraftTokens();
numGenTokens += draftLen + reqBeamWidth;
}
numLogits = numContextLogits + numGenTokens;
if (encoderBuffers)
{
encoderBuffers->setBufferSizes(contextRequests, genRequests);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::prepareBuffersForCudaGraph(SizeType32 maxSequenceLength)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(prepareBuffersForCudaGraph);
TLLM_CHECK(numContextRequests == 0);
if (transformerBuffers)
{
// Set pastKeyValueLength for graph capturing. This way we capture the graph with
// maxKvCacheLengthRounded rounded up to the next multiple of kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE.
// MMHA will launch an excessive number of blocks and some of them will exit early during the actual launch.
// We can reuse the same graph for the next kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE iterations.
// Make sure the size does not overflow the maximum allowed pastKvCacheLength.
auto const pastKvCacheLength = std::min(maxSequenceLength - 1, maxKvCacheLengthRounded);
auto* pastKeyValueLengthsPtr = bufferCast<SizeType32>(*transformerBuffers->pastKeyValueLengths);
std::fill_n(pastKeyValueLengthsPtr, getNumSequences(), pastKvCacheLength);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests,
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers,
kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr,
rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable,
runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(runtimeBuffersSetFromInputs);
auto const& manager = runtime.getBufferManager();
auto const& stream = runtime.getStream();
// Fill requestTypes
{
auto* hostRequestTypes = bufferCast<runtime::RequestType>(*requestTypes);
std::fill_n(hostRequestTypes, numContextRequests, runtime::RequestType::kCONTEXT);
std::fill_n(hostRequestTypes + numContextRequests, numGenSequences, runtime::RequestType::kGENERATION);
}
SizeType32 totalInputSize = 0;
std::vector<TokenIdType> inputHost;
std::vector<SizeType32> positionIdsHost;
std::vector<SizeType32> positionIdsHostRow2;
std::vector<SizeType32> mropePositionDeltasHost;
std::vector<SizeType32> languageAdapterRoutingsHost;
auto* contextLengthsHostPtr = bufferCast<SizeType32>(*contextLengthsHost);
auto* sequenceLengthsHostPtr = bufferCast<SizeType32>(*sequenceLengthsHost);
auto* pastKeyValueLengthsPtr
= transformerBuffers ? bufferCast<SizeType32>(*transformerBuffers->pastKeyValueLengths) : nullptr;
SizeType32 totalNumLogits{0};
auto* logitsIdsHostPtr = bufferCast<SizeType32>(*logitsIdsHost);
bool const isChatGlm = modelConfig.getModelVariant() == ModelConfig::ModelVariant::kChatGlm;
bool const isGlm = modelConfig.getModelVariant() == ModelConfig::ModelVariant::kGlm;
auto const mropeRotaryCosSinSize = modelConfig.getMaxPositionEmbeddings() * modelConfig.getRotaryEmbeddingDim();
{
NVTX3_SCOPED_RANGE(seqSlotsLoop);
auto* seqSlotIndices = bufferCast<SizeType32>(*seqSlots);
SizeType32 batchIdx{0};
for (auto const& requests : {contextRequests, genRequests})
{
for (auto const& llmReq : requests)
{
// Get position of the current sequence in the decoder
auto const seqSlot = llmReq->mSeqSlot.value();
seqSlotIndices[batchIdx] = seqSlot;
++batchIdx;
}
}
TLLM_CHECK(seqSlots->getSize() == static_cast<std::size_t>(batchIdx));
manager.copy(*seqSlots, *seqSlotsDevice);
}
// context preparation loop
if (!contextRequests.empty())
{
NVTX3_SCOPED_RANGE(contextPrepareLoop);
numContextLogits.resize(contextRequests.size());
SizeType32 batchIdx{0};
for (auto const& llmReq : contextRequests)
{
TLLM_CHECK_WITH_INFO(llmReq->isContextInitState() || llmReq->isDisaggGenerationTransmissionComplete(),
"The request should be in context phase or disaggregated generation tranmissionComplete phase.");
TLLM_CHECK_WITH_INFO(
llmReq->getMaxNumGeneratedTokens() == 0, "Context request should not have generated tokens.");
auto const& reqTokens = llmReq->getTokens(0);
auto const& draftTokens = llmReq->getDraftTokens();
auto const draftLength = llmReq->getNumDraftTokens();
auto const& positionIds = llmReq->getPositionIds();
auto const contextChunkSize = llmReq->getContextChunkSize();
auto const beginCompute = llmReq->getContextCurrentPosition();
auto const endCompute = beginCompute + contextChunkSize;
inputHost.insert(inputHost.end(), reqTokens.begin() + beginCompute, reqTokens.begin() + endCompute);
logitsIdsHostPtr[totalNumLogits++] = contextChunkSize;
numContextLogits.at(batchIdx) = modelConfig.computeContextLogits() ? contextChunkSize : 1;
if (llmReq->isLastContextChunk())
{
inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end());
std::fill_n(logitsIdsHostPtr + totalNumLogits, draftLength, 1);
totalNumLogits += draftLength;
}
auto const inputLength = contextChunkSize + (llmReq->isLastContextChunk() ? draftLength : 0);
contextLengthsHostPtr[batchIdx] = inputLength;
auto const sequenceLen = inputLength + llmReq->getContextCurrentPosition();
sequenceLengthsHostPtr[batchIdx] = sequenceLen;
if (static_cast<bool>(pastKeyValueLengthsPtr))
{
pastKeyValueLengthsPtr[batchIdx] = beginCompute + inputLength;
}
if (positionIds.has_value())
{
TLLM_CHECK_WITH_INFO(!(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization");
positionIdsHost.insert(positionIdsHost.end(), positionIds.value()->begin() + beginCompute,
positionIds.value()->begin() + endCompute);
}
else
{
if (isChatGlm)
{
// Specialize for ChatGLM-6B with 2D-Position-Embedding
positionIdsHost.resize(totalInputSize + inputLength);
std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0);
positionIdsHost.back() = positionIdsHost.back() - 1;
positionIdsHostRow2.resize(totalInputSize + inputLength);
positionIdsHostRow2.back() = 1;
}
else if (isGlm)
{
// Specialize for GLM-10B with 2D-Position-Embedding and special value of the mask id position
auto start = inputHost.begin() + totalInputSize;
auto end = start + inputLength;
auto it = std::find_if(
start, end, [](SizeType32 id) { return id == 50260 || id == 50263 || id == 50264; });
llmReq->mMaskPosition = (it != end) ? std::distance(start, it) : maxContextLength;
positionIdsHost.resize(totalInputSize + inputLength);
std::iota(std::begin(positionIdsHost) + totalInputSize, std::end(positionIdsHost), 0);
positionIdsHost.back() = llmReq->mMaskPosition;
positionIdsHostRow2.resize(totalInputSize + inputLength);
positionIdsHostRow2.back() = 1;
}
else
{
// Other models
positionIdsHost.resize(totalInputSize + inputLength);
std::iota(std::begin(positionIdsHost) + totalInputSize,
std::begin(positionIdsHost) + totalInputSize + inputLength, beginCompute);
}
}
if (modelConfig.useMrope())
{
auto optMropeRotaryCosSin = llmReq->getMropeRotaryCosSin().value();
TLLM_CHECK_WITH_INFO(optMropeRotaryCosSin->getShape().d[0] == mropeRotaryCosSinSize,
"Provided MropeRotarySinCos is %ld and expected is %d.\n", optMropeRotaryCosSin->getShape().d[0],
int(mropeRotaryCosSinSize));
auto const mropeRotaryCosSinCtx = ITensor::slice(mropeRotaryCosSin, batchIdx, 1);
manager.copy(*optMropeRotaryCosSin, *mropeRotaryCosSinCtx);
}
if (modelConfig.useLanguageAdapter())
{
auto const languageAdapterRouting = llmReq->getLanguageAdapterRouting(
modelConfig.getNumLanguages().value(), endCompute - beginCompute);
languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(),
std::begin(languageAdapterRouting), std::end(languageAdapterRouting));
}
totalInputSize += inputLength;
++batchIdx;
}
if (rnnStateBuffers)
{
rnnStateBuffers->fillSlotMappings(contextRequests, rnnStateManagerPtr);
}
if (transformerBuffers && maxBeamWidth > 1)
{
transformerBuffers->resetCacheIndirection(contextRequests, maxBeamWidth, maxAttentionWindow,
decoderBuffers.cacheIndirectionInput, decoderBuffers.cacheIndirectionOutput, manager);
}
}
// generation preparation loop
if (!genRequests.empty())
{
NVTX3_SCOPED_RANGE(genPrepareLoop);
auto const numContextRequests = static_cast<SizeType32>(contextRequests.size());
auto numSequences = numContextRequests;
for (auto const& llmReq : genRequests)
{
auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth;
auto const draftLength = llmReq->getNumDraftTokens();
auto const& draftTokens = llmReq->getDraftTokens();
auto const numLogits = draftLength + reqBeamWidth;
TLLM_CHECK(draftLength == 0 || reqBeamWidth == 1);
auto const promptLen = llmReq->mPromptLen;
auto const sequenceLen = promptLen + llmReq->getMaxNumGeneratedTokens();
auto const& positionIds = llmReq->getPositionIds();
for (int beam = 0; beam < reqBeamWidth; ++beam)
{
auto const lastToken = llmReq->getLastTokens(beam);
auto const numTokens = llmReq->getNumTokens(beam);
inputHost.push_back(lastToken);
// If the model updates generation position ids, do not append them here.
if (!modelConfig.getSpeculativeDecodingMode().updatesPositionIds())
{
if (positionIds.has_value())
{
TLLM_CHECK_WITH_INFO(
!(isChatGlm || isGlm), "ChatGLM-6B and Glm only use the default initialization");
auto last_context_position_id = positionIds.value()->back();
positionIdsHost.push_back(
static_cast<SizeType32>(last_context_position_id + sequenceLen - promptLen));
}
else
{
if (isChatGlm) // ChatGLM-6B
{
positionIdsHost.push_back(static_cast<SizeType32>(promptLen - 2));
positionIdsHostRow2.push_back(static_cast<SizeType32>(sequenceLen - promptLen + 1));
}
else if (isGlm)
{
positionIdsHost.push_back(llmReq->mMaskPosition);
positionIdsHostRow2.push_back(static_cast<SizeType32>(sequenceLen - promptLen + 1));
}
else // GPT / ChatGLM2-6B / ChatGLM3-6B / BART
{
// positionIds is just the number of tokens - 1
positionIdsHost.push_back(numTokens - 1);
}
}
}
if (draftLength > 0)
{
inputHost.insert(inputHost.end(), draftTokens->begin(), draftTokens->end());
}
if (modelConfig.useMrope())
{
auto optMropePositionDeltas = llmReq->getMropePositionDeltas().value();
mropePositionDeltasHost.push_back(optMropePositionDeltas);
}
if (modelConfig.useLanguageAdapter())
{
// Generation requests only have one token per sequence
auto const languageAdapterRouting
= llmReq->getLanguageAdapterRouting(modelConfig.getNumLanguages().value(), 1);
languageAdapterRoutingsHost.insert(languageAdapterRoutingsHost.end(),
std::begin(languageAdapterRouting), std::end(languageAdapterRouting));
}
}
if (static_cast<bool>(pastKeyValueLengthsPtr))
{
SizeType32 pastKeyValueLength = sequenceLen - 1;
std::fill_n(pastKeyValueLengthsPtr + numSequences, reqBeamWidth, pastKeyValueLength);
}
totalInputSize += numLogits;
std::fill_n(logitsIdsHostPtr + totalNumLogits, numLogits, 1);
totalNumLogits += numLogits;
if (rnnStateBuffers)
{
auto const seqSlot = llmReq->mSeqSlot.value();
auto& rnnStateManager = *rnnStateManagerPtr;
rnnStateManager.fillSlotMapping(*rnnStateBuffers->slotMappingHost, numSequences, seqSlot, reqBeamWidth);
}
numSequences += reqBeamWidth;
}
if (transformerBuffers && maxBeamWidth > 1)
{
transformerBuffers->copyCacheIndirection(genRequests, decoderBuffers.cacheIndirectionOutput, stream);
}
numSequences = numContextRequests;
for (auto const& llmReq : genRequests)
{
auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth;
auto const draftLength = llmReq->getNumDraftTokens();
auto const contextQLength = llmReq->mPromptLen + draftLength;
auto const sequenceLen = contextQLength + llmReq->getMaxNumGeneratedTokens();
std::fill_n(contextLengthsHostPtr + numSequences, reqBeamWidth, contextQLength);
std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen);
numSequences += reqBeamWidth;
}
if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
{
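// Build a permutation ordering batch entries by sequence slot: illustrative seqSlots [5, 2, 7] give
// seqSlotRemapping [1, 0, 2] and sortedSeqSlots [2, 5, 7], which the KV cache rewind consumes.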
auto remappingSeqSlotIndices = BufferRange<SizeType32>(*seqSlotRemappingHost);
auto const* seqSlotIndices = bufferCast<SizeType32>(*seqSlots);
std::iota(remappingSeqSlotIndices.begin(), remappingSeqSlotIndices.end(), 0);
std::sort(remappingSeqSlotIndices.begin(), remappingSeqSlotIndices.end(),
[&seqSlotIndices](SizeType32 a, SizeType32 b) { return seqSlotIndices[a] < seqSlotIndices[b]; });
manager.copy(*seqSlotRemappingHost, *seqSlotRemappingDevice);
manager.copy(*seqSlots, *sortedSeqSlots);
auto sortedSeqSlotIndices = BufferRange<SizeType32>(*sortedSeqSlots);
std::sort(sortedSeqSlotIndices.begin(), sortedSeqSlotIndices.end());
}
if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
// copy from lookahead decoding buffer
lookaheadBuffers->setFromInputs(numContextRequests, numGenRequests, *requestTypes, *seqSlots,
decoderBuffers.lookaheadBuffers.value(), runtime, modelConfig, worldConfig);
}
}
// check skipCrossAttnBlocks
if (transformerBuffers && modelConfig.skipCrossAttnBlocks())
{
bool isSkipCrossAttn = true;
for (auto const& requests : {contextRequests, genRequests})
{
for (auto const& llmReq : requests)
{
bool tmpValue = false;
if (llmReq->getSkipCrossAttnBlocks() != nullptr)
{
manager.copy(*llmReq->getSkipCrossAttnBlocks(), &tmpValue);
}
isSkipCrossAttn &= tmpValue;
}
}
transformerBuffers->copySkipCrossAttnBlocks(isSkipCrossAttn, runtime);
}
if (isChatGlm || isGlm)
{
positionIdsHost.reserve(totalInputSize * 2);
positionIdsHost.insert(positionIdsHost.end(), positionIdsHostRow2.begin(), positionIdsHostRow2.end());
}
if (modelConfig.useCrossAttention())
{
encoderBuffers->fill(contextRequests, genRequests, manager);
}
if (modelConfig.usePromptTuning())
{
promptTuningBuffers->fill(contextRequests, genRequests, manager, modelConfig.usePackedInput());
}
if (modelConfig.useLoraPlugin())
{
loraBuffers->fill(contextRequests, genRequests, peftTable, manager, modelConfig, worldConfig);
}
if (modelConfig.useMrope())
{
if (!mropePositionDeltasHost.empty())
{
auto mropePositionDeltasGen = ITensor::slice(mropePositionDeltas, 0, numGenSequences);
manager.copy(mropePositionDeltasHost.data(), *mropePositionDeltasGen);
}
}
{
NVTX3_SCOPED_RANGE(bufferCopies);
manager.copy(inputHost.data(), *inputsIds);
// In the generation phase, the device pointer of context lengths needs to be tiled.
manager.copy(*contextLengthsHost, *contextLengthsDevice);
manager.copy(*sequenceLengthsHost, *sequenceLengthsDevice);
manager.copy(*logitsIdsHost, *logitsIdsDevice);
auto const logitsIdsHostRange = BufferRange<SizeType32>(*logitsIdsHost);
auto lastTokenIdsHostRange = BufferRange<SizeType32>(*lastTokenIdsHost);
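// lastTokenIds is the inclusive prefix sum of logitsIds, e.g. an illustrative logitsIds of [4, 1, 1]
// yields lastTokenIds [4, 5, 6], i.e. the one-past-the-last-token offset of each logit in the flattened input.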
common::stl_utils::inclusiveScan(
logitsIdsHostRange.begin(), logitsIdsHostRange.end(), lastTokenIdsHostRange.begin());
manager.copy(*lastTokenIdsHost, *lastTokenIdsDevice);
if (transformerBuffers)
{
TensorPtr decoderPositionIds = modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()
? lookaheadBuffers->positionIdsDevice
: nullptr;
transformerBuffers->copyPositionIds(runtime, positionIdsHost, isChatGlm || isGlm, decoderPositionIds);
}
if (rnnStateBuffers)
{
rnnStateBuffers->copySlotMappingH2D(runtime);
}
if (modelConfig.useLanguageAdapter())
{
manager.copy(languageAdapterRoutingsHost.data(), *languageAdapterRoutings);
}
}
if (transformerBuffers && static_cast<bool>(kvCacheManagerPtr))
{
transformerBuffers->copyKvBlockOffsets(
contextRequests, genRequests, kvCacheManagerPtr, crossKvCacheManagerPtr, manager);
}
if (modelConfig.useCrossAttention())
{
transformerBuffers->copyCrossAttentionMasks(contextRequests, genRequests, contextLengthsDevice,
encoderBuffers->inputLengths, maxContextLength, encoderBuffers->getMaxInputLengthInBatch(), runtime);
}
maxKvCacheLengthRounded = 0;
if (static_cast<bool>(pastKeyValueLengthsPtr))
{
auto const maxKvCacheLength
= *std::max_element(pastKeyValueLengthsPtr, pastKeyValueLengthsPtr + getNumSequences());
// Round up kv cache length
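// e.g. an illustrative round size of 256 maps any max KV cache length in 1..256 to 256 and in 257..512 to 512,
// so the same captured CUDA graph can be reused across that whole range.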
maxKvCacheLengthRounded = common::ceilDiv(maxKvCacheLength, kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE)
* kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE;
}
if (modelConfig.getSpeculativeDecodingMode().needsDecoderPrologue())
{
if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
{
prepareExplicitDraftTokenBuffers(decoderBuffers, runtime, modelConfig, worldConfig);
}
if (modelConfig.getSpeculativeDecodingMode().isEagle())
{
prepareEagleBuffers(contextRequests, genRequests, decoderBuffers, runtime, modelConfig, worldConfig);
}
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::prepareExplicitDraftTokenBuffers(DecoderBuffers& decoderBuffers, TllmRuntime const& runtime,
ModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(explicitDraftTokensBuffers);
explicitDraftTokensBuffers->setFromInputs(numContextRequests, numGenRequests, *requestTypes, *seqSlots,
decoderBuffers.explicitDraftTokensBuffers, *transformerBuffers->positionIds, modelConfig, worldConfig,
runtime.getBufferManager(), runtime.getStream());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void RuntimeBuffers::prepareEagleBuffers(RequestVector const& contextRequests, RequestVector const& genRequests,
DecoderBuffers& decoderBuffers, TllmRuntime const& runtime, ModelConfig const& modelConfig,
WorldConfig const& worldConfig)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(eagleBuffers);
eagleBuffers->setFromInputs(contextRequests, genRequests, *requestTypes, *seqSlots, decoderBuffers.eagleBuffers,
runtime.getBufferManager(), modelConfig, worldConfig);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
std::tuple<SizeType32, RuntimeBuffers::TensorMap const&, RuntimeBuffers::TensorMap&> RuntimeBuffers::prepareStep(
RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth,
SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers, kv_cache_manager::BaseKVCacheManager* kvCacheManager,
kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, rnn_state_manager::RnnStateManager* rnnStateManager,
PeftTable const& peftTable, TllmRuntime const& runtime, ModelConfig const& modelConfig,
WorldConfig const& worldConfig, bool gatherGenerationLogits)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(runtimeBuffersPrepareStep);
setBufferSizes(contextRequests, genRequests);
reshape(runtime, modelConfig, worldConfig, gatherGenerationLogits);
setFromInputs(contextRequests, genRequests, maxBeamWidth, maxAttentionWindow, decoderBuffers, kvCacheManager,
crossKvCacheManager, rnnStateManager, peftTable, runtime, modelConfig, worldConfig);
fillIOMaps(modelConfig, worldConfig);
auto const numTokens = getNumTokens();
auto const optProfileId = runtime.getOptProfileId(numTokens, ModelConfig::getOptProfilesSplitPoints());
setContextIndex(optProfileId);
TLLM_LOG_DEBUG("numTokens: %d, optProfileId: %d", numTokens, optProfileId);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return {optProfileId, inputMap, outputMap};
}
void RuntimeBuffers::fillIOMaps(ModelConfig const& modelConfig, WorldConfig const& worldConfig)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(runtimeBuffersFillIOMaps);
inputMap.clear();
outputMap.clear();
if (transformerBuffers)
{
transformerBuffers->getBuffers(inputMap, outputMap, modelConfig);
}
if (rnnStateBuffers)
{
rnnStateBuffers->getBuffers(inputMap);
}
if (worldConfig.isLastPipelineParallelRank())
{
// feed a view to the TensorRT runtime so that reshaping does not change the logits buffer
outputMap.insert_or_assign(kLogitsTensorName, ITensor::view(logits));
}
else
{
outputMap.insert_or_assign(kHiddenStatesOutputTensorName, hiddenStates);
}
if (worldConfig.isFirstPipelineParallelRank())
{
inputMap.insert_or_assign(kInputIdsTensorName, inputsIds);
}
else
{
inputMap.insert_or_assign(kHiddenStatesInputTensorName, hiddenStates);
}
inputMap.insert_or_assign(kLastTokenIdsTensorName, lastTokenIdsDevice);
inputMap.insert_or_assign(kHostRequestTypesTensorName, requestTypes);
// In the generation phase, we still pass context lengths.
inputMap.insert_or_assign(kContextLengthsTensorName, contextLengthsDevice);
inputMap.insert_or_assign(kHostContextLengthsTensorName, contextLengthsHost);
inputMap.insert_or_assign(kSequenceLengthsTensorName, sequenceLengthsDevice);
if (modelConfig.useCrossAttention())
{
encoderBuffers->insertInputTensors(inputMap);
}
if (modelConfig.usePromptTuning())
{
auto const& promptTuningParams = promptTuningBuffers->mPromptTuningParams;
inputMap.insert_or_assign(kPromptEmbeddingTableTensorName, promptTuningParams.embeddingTable);
inputMap.insert_or_assign(kTasksTensorName, promptTuningParams.tasks);
inputMap.insert_or_assign(kPromptVocabSizeTensorName, promptTuningParams.vocabSize);
}
if (modelConfig.useMrope())
{
inputMap.insert_or_assign(kMRopeRotaryCosSinTensorName, mropeRotaryCosSin);
inputMap.insert_or_assign(kMRopePositionDeltasTensorName, mropePositionDeltas);
}
if (modelConfig.useLoraPlugin())
{
loraBuffers->insertInputTensors(inputMap, loraBuffers->mLoraWeightsPointersHost,
loraBuffers->mLoraAdapterSizesHost, modelConfig, worldConfig);
}
if (modelConfig.useLanguageAdapter())
{
inputMap.insert_or_assign("language_adapter_routings", languageAdapterRoutings);
}
if (medusaBuffers)
{
medusaBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
}
if (lookaheadBuffers)
{
lookaheadBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
}
if (explicitDraftTokensBuffers)
{
explicitDraftTokensBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
}
if (eagleBuffers)
{
eagleBuffers->insertInputTensors(inputMap, outputMap, worldConfig);
}
for (auto const& outputTensor : mAdditionalOutputTensors)
{
outputMap.insert_or_assign(outputTensor.first, outputTensor.second);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::batch_manager