TensorRT-LLMs/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "trtGptModelInflightBatching.h"
#include "tensorrt_llm/batch_manager/allocateKvCache.h"
#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/guidedDecoder.h"
#include "tensorrt_llm/batch_manager/handleContextLogits.h"
#include "tensorrt_llm/batch_manager/handleGenerationLogits.h"
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
#include "tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/pauseRequests.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/promptTuningBuffers.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/batch_manager/transformerBuffers.h"
#include "tensorrt_llm/batch_manager/updateDecoderBuffers.h"
#include "tensorrt_llm/batch_manager/utils/debugUtils.h"
#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h"
#include "tensorrt_llm/batch_manager/utils/logitsThread.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/common/timestampUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/lookaheadModule.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"
#include <algorithm>
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
namespace tensorrt_llm::batch_manager
{
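// Returns false when the executor config requests KV-cache block reuse but the engine cannot support it
// (paged context FMHA is missing, or the model returns context logits). fixExecutorConfig applies the
// corresponding downgrades and logs a warning for each one.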
bool TrtGptModelInflightBatching::executorConfigIsValid(
ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig)
{
// Make sure logic in this function matches fixExecutorConfig
if (executorConfig.getKvCacheConfig().getEnableBlockReuse())
{
if (!modelConfig.getPagedContextFMHA())
{
return false;
}
// Context logits cannot be returned for reused tokens, so disable reuse
if (modelConfig.computeContextLogits())
{
return false;
}
}
return true;
}
executor::ExecutorConfig TrtGptModelInflightBatching::fixExecutorConfig(
ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig)
{
// Make sure logic in this function matches executorConfigIsValid
if (executorConfig.getKvCacheConfig().getEnableBlockReuse())
{
auto kvCacheConfig = executorConfig.getKvCacheConfig();
if (!modelConfig.getPagedContextFMHA())
{
TLLM_LOG_WARNING(
"Fixing executorConfig: KV cache reuse disabled because model was not built with paged context FMHA "
"support");
kvCacheConfig.setEnableBlockReuse(false);
}
if (modelConfig.computeContextLogits())
{
TLLM_LOG_WARNING(
"Fixing executorConfig: KV cache reuse disabled because model was built to return context logits");
kvCacheConfig.setEnableBlockReuse(false);
}
auto fixedExecutorConfig = executor::ExecutorConfig(executorConfig);
fixedExecutorConfig.setKvCacheConfig(kvCacheConfig);
return fixedExecutorConfig;
}
return executorConfig;
}
TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer1::ILogger> logger,
ModelConfig const& modelConfig, WorldConfig const& worldConfig, RawEngine const& rawEngine, bool ctxGenFusion,
executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode)
: TrtGptModel(modelConfig, worldConfig, executorConfig)
, mModelConfig(modelConfig)
, mWorldConfig(worldConfig)
, mDevice{runtime::utils::initDevice(worldConfig)}
, mDecodingConfig{executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{})}
, mExtendedRuntimePerfKnobConfig{executorConfig.getExtendedRuntimePerfKnobConfig()}
, mDebugConfig{executorConfig.getDebugConfig()}
, mAdditionalModelOutputs{worldConfig.isLastPipelineParallelRank() ? executorConfig.getAdditionalModelOutputs()
: std::nullopt}
, mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
, mRuntime{std::make_unique<TllmRuntime>(rawEngine, mLogger.get(), executorConfig.getUseGpuDirectStorage(),
executorConfig.getGpuWeightsPercent(), modelConfig.useShapeInference())}
, mCopyBufferManager{std::make_shared<CudaStream>()}
, mCtxGenFusion(ctxGenFusion)
, mOperatingBeamWidth{getMaxBeamWidth()}
, mGatherGenerationLogits{executorConfig.getGatherGenerationLogits()}
, mPromptTableOffloading{executorConfig.getPromptTableOffloading()}
, mIsLeaderInOrchMode{isLeaderInOrchMode}
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_INFO("gatherContextLogits: %d", mModelConfig.computeContextLogits());
TLLM_LOG_INFO("gatherGenerationLogits: %d", getGatherGenerationLogits());
if (!(mModelConfig.supportsInflightBatching()))
{
throw std::runtime_error(
"TrtGptModelInflightBatching requires GPT attention/Mamba Conv 1d plugin with "
"packed input and paged KV cache.");
}
if (mWorldConfig.isTensorParallel())
{
mRuntime->initializeUserBuffer(mWorldConfig, mModelConfig.getMaxBatchSize(), mModelConfig.getMaxBeamWidth(),
mModelConfig.getMaxSequenceLen(), mModelConfig.getHiddenSize(), getMaxNumTokens());
}
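    // With pipeline parallelism, the number of micro batches equals the number of pipeline stages; otherwise two
    // micro batches are used when TRT overlap scheduling is enabled and one when it is not. Without
    // context-generation fusion, each micro batch needs separate context and generation buffers, which doubles
    // mNumBuffers.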
if (mWorldConfig.isPipelineParallel())
{
mNumMicroBatches = mWorldConfig.getPipelineParallelism();
}
else
{
mNumMicroBatches = isTrtOverlap() ? 2 : 1;
}
mNumBuffers = (mCtxGenFusion ? 1 : 2) * mNumMicroBatches;
auto const kvCacheConfig = KvCacheConfig(executorConfig.getKvCacheConfig());
if (!kvCacheConfig.onboardBlocks)
{
        TLLM_CHECK_WITH_INFO(!mModelConfig.getPagedContextFMHA(),
            "KV cache blocks need to be onboarded when paged context FMHA is used.");
}
if (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
{
TLLM_CHECK_WITH_INFO(kvCacheConfig.enableBlockReuse,
"KV cache block reuse must be enabled for speculative decoding target model");
}
if (mCtxGenFusion)
{
        TLLM_CHECK_WITH_INFO(
            !mModelConfig.isRnnBased(), "RNN-based models don't support context-generation fusion.");
        TLLM_CHECK_WITH_INFO(mModelConfig.isTransformerBased(),
            "Only transformer-based models support context-generation fusion for now.");
}
if (mModelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
mSeamlessLADMaxDraftLen = modelConfig.getMaxDecodingDraftTokens();
// TODO: enable it when speculativeDecodingMode is None and run with '--lookahead_config'
mUseSeamlessLookahead = false;
}
setupSpeculativeDecodingModule(mDecodingConfig);
if (mWorldConfig.isLastPipelineParallelRank() && executorConfig.getGuidedDecodingConfig())
{
mGuidedDecoder = std::make_unique<GuidedDecoder>(executorConfig.getGuidedDecodingConfig().value(),
getMaxNumSequences(), mModelConfig.getVocabSizePadded(mWorldConfig.getSize()),
mModelConfig.getLogitsDtype(), mRuntime->getBufferManager());
}
createRuntimeContexts();
if (mWorldConfig.isTensorParallel())
{
createCustomAllReduceWorkspace();
}
if (mModelConfig.isTransformerBased())
{
createRuntimePerfKnobsTensor(mExtendedRuntimePerfKnobConfig);
}
auto& memCounter = MemoryCounters::getInstance();
auto const gpuUsage1 = memCounter.getGpu();
createBuffers(mDecodingConfig, mAdditionalModelOutputs);
auto const gpuUsage2 = memCounter.getGpu();
TLLM_LOG_INFO("[MemUsageChange] Allocated %s GPU memory for runtime buffers.",
memCounter.bytesToString(gpuUsage2 - gpuUsage1).c_str());
createDecoder(mDecodingConfig.getDecodingMode());
auto const gpuUsage3 = memCounter.getGpu();
TLLM_LOG_INFO("[MemUsageChange] Allocated %s GPU memory for decoder.",
memCounter.bytesToString(gpuUsage3 - gpuUsage2).c_str());
if (modelConfig.getManageWeightsType() != ModelConfig::ManageWeightsType::kDisabled)
{
mRuntime->loadManagedWeights(rawEngine, worldConfig.getLocalRank());
}
if (mModelConfig.useLoraPlugin())
{
auto const peftCacheManagerConfig
= PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig()));
mPeftCacheManager = std::make_shared<PeftCacheManager>(
peftCacheManagerConfig, mModelConfig, mWorldConfig, mRuntime->getBufferManager());
}
else
{
mPeftCacheManager = std::make_shared<NoOpPeftCacheManager>();
}
if (mModelConfig.isRnnBased())
{
createRnnStateManager();
}
if (mModelConfig.isTransformerBased() && modelConfig.isKVCacheEnabled())
{
auto cacheTransceiverConfig
= executorConfig.getCacheTransceiverConfig().value_or(executor::CacheTransceiverConfig());
auto cacheTransPreAllocaSize
= kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(cacheTransceiverConfig.getMaxNumTokens());
auto const [freePrimaryMemBytes, freeSecondaryMemBytes]
= BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig);
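        // When the model uses cross attention, the free KV-cache memory is split between the self and cross
        // cache managers according to crossKvCacheFraction; e.g. a fraction of 0.2 would give 80% of the free
        // bytes to the self cache and 20% to the cross cache.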
if (mModelConfig.useCrossAttention())
{
TLLM_CHECK_WITH_INFO(kvCacheConfig.crossKvCacheFraction.has_value(),
"Must set crossKvCacheFraction for encoder-decoder model");
auto const crossKvCacheFraction = kvCacheConfig.crossKvCacheFraction.value();
mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF,
freePrimaryMemBytes * (1.0f - crossKvCacheFraction),
freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize);
mCrossKvCacheManager
= createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS, freePrimaryMemBytes * crossKvCacheFraction,
freeSecondaryMemBytes * crossKvCacheFraction, cacheTransPreAllocaSize);
TLLM_LOG_INFO("This is an Encoder-Decoder model, set %0.1f cross KV cache fraction based on the config.",
crossKvCacheFraction);
}
else
{
TLLM_CHECK_WITH_INFO(!kvCacheConfig.crossKvCacheFraction.has_value(),
"Do not set crossKvCacheFraction for decoder-only model");
mKvCacheManager = createKvCacheManager(
kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes, freeSecondaryMemBytes, cacheTransPreAllocaSize);
}
mCacheTransceiver
= CacheTransceiverFactory::createCacheTransceiver(mKvCacheManager.get(), mModelConfig, mWorldConfig,
executor::kv_cache::CacheState::AttentionType::kDEFAULT, executorConfig.getCacheTransceiverConfig());
}
if (mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
{
TLLM_CHECK_WITH_INFO(
mModelConfig.isKVCacheEnabled(), "When needsKVCacheRewind() returns true, KV cache needs to be enabled.");
auto const& blockManager = mKvCacheManager->getBlockManager();
TLLM_CHECK_WITH_INFO(blockManager.getNumPools() == 1,
"Rewinding KV cache blocks for models with multiple pools is not supported");
        // The following two checks are redundant given the single-pool check above, but they do not rely on that
        // implementation detail.
TLLM_CHECK_WITH_INFO(
!blockManager.isVariableWindow(), "Rewinding KV cache blocks for variable SWA models isn't supported");
auto const maxBlocksPerSeq = blockManager.getMaxBlockPerSeqWhenSingleWindowSize();
auto const isUseOneMoreBlock = kv_cache_manager::BlockManager::isUseOneMoreBlock(
getMaxAttentionWindow(), getMaxSequenceLen(), getMaxBeamWidth());
// TODO(oargov): VGQA is not supported, assume all layers have the same num_kv_heads
TLLM_CHECK_WITH_INFO(
!blockManager.isVariableGQA(), "Rewinding KV cache blocks for variable GQA models isn't supported");
auto const numKvHeads = mModelConfig.getNbKvHeads(0);
mRewindInputs = RewindInputs{maxBlocksPerSeq, isUseOneMoreBlock, numKvHeads};
}
if (mWorldConfig.isPipelineParallel())
{
mAsyncSendWaitThread = std::make_unique<tensorrt_llm::mpi::MpiWaitThread>(
"asyncSendWaitThread",
[this]()
{
mDecStepAsyncSndHdls.clear();
mDecSlotAsyncSndHdls.clear();
},
[this]() { TLLM_CUDA_CHECK(cudaSetDevice(mWorldConfig.getDevice())); });
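        // Split the session communicator into a pipeline-parallel communicator: ranks that share the same
        // tensor-parallel rank get the same color and are ordered by their pipeline-parallel rank.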
auto const& commSession = COMM_SESSION;
mMpiCommPipelinePara = std::make_unique<tensorrt_llm::mpi::MpiComm>(
commSession.split(mWorldConfig.getTensorParallelRank(), mWorldConfig.getPipelineParallelRank()));
mDecSlotAsyncSndHdls.reserve(getMaxBatchSize());
}
if (mWorldConfig.isTensorParallel())
{
auto const& commSession = COMM_SESSION;
mMpiCommTensorPara = std::make_unique<tensorrt_llm::mpi::MpiComm>(
commSession.split(mWorldConfig.getPipelineParallelRank(), mWorldConfig.getTensorParallelRank()));
}
mSeqSlotManager
= std::make_shared<SequenceSlotManager>(getMaxNumSequences(), executorConfig.getMaxSeqIdleMicroseconds());
mMicroBatchScheduledRequests.resize(mNumMicroBatches);
mDecoderFinishedEvents.resize(mNumMicroBatches);
mPeftTables.resize(mNumMicroBatches);
if (modelConfig.isRnnBased())
{
        TLLM_CHECK_WITH_INFO(modelConfig.getMaxBeamWidth() == 1, "RNN-based models don't support beam search yet.");
        TLLM_CHECK_WITH_INFO(
            !executorConfig.getEnableChunkedContext(), "RNN-based models don't support chunked context yet.");
        TLLM_CHECK_WITH_INFO(
            modelConfig.getSpeculativeDecodingMode().isNone(), "RNN-based models don't support speculative decoding.");
}
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig;
if (executorConfig.getEnableChunkedContext())
{
TLLM_CHECK_WITH_INFO(modelConfig.isKVCacheEnabled() && mModelConfig.getPagedContextFMHA(),
"Chunked context requires context FMHA, paged kv_cache and paged context FMHA all enabled at the same "
"time.");
SizeType32 chunkUnitSize = mKvCacheManager->getTokensPerBlock();
// If sliding window attention is used, then make sure the unit size aligns with the paged context fmha's kv
// step size.
if (getMaxInputLen() > getMaxAttentionWindow()) // TODO(nhaber): minAttentionWindow
{
chunkUnitSize = std::max(/* maxKvStepSizeInFmha */ 256, chunkUnitSize);
TLLM_LOG_INFO("ChunkUnitSize is set to %d as sliding window attention is used.", chunkUnitSize);
}
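        // Example with hypothetical numbers: with tokensPerBlock == 64 and sliding window attention in use,
        // chunkUnitSize becomes std::max(256, 64) == 256, i.e. each context chunk spans four KV-cache blocks.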
ctxChunkConfig = batch_scheduler::ContextChunkingConfig{
executorConfig.getSchedulerConfig().getContextChunkingPolicy().value_or(
executor::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED),
chunkUnitSize};
}
auto maxNumTokens = getMaxNumTokens();
TLLM_CHECK_WITH_INFO(maxNumTokens, "Max number of tokens is not set in model config.");
// Max context size is limited by `max_num_tokens` for chunked-context or context-FMHA,
// or by `max_input_len` of the model.
auto const maxContextLength = (executorConfig.getEnableChunkedContext() || mModelConfig.getContextFMHA())
? maxNumTokens
: std::make_optional<SizeType32>(mModelConfig.getMaxInputLen());
mMaxBatchSizeTunerRecommended = 0;
mMaxBatchSizeRuntime = getMaxBatchSize();
mMaxNumTokensStatic = maxNumTokens;
mMaxNumTokensTunerRecommended = 0;
mMaxNumTokensRuntime = maxNumTokens;
if (mKvCacheManager && ctxChunkConfig)
{
TLLM_CHECK_WITH_INFO(ctxChunkConfig.value().chunkUnitSize % mKvCacheManager->getTokensPerBlock() == 0,
"To prevent cache fragmentation, the context chunk unit size (%d) should be divisible by the number of "
"tokens per kv-cache block (%d).",
ctxChunkConfig.value().chunkUnitSize, mKvCacheManager->getTokensPerBlock());
}
mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxNumSequences(),
executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), mKvCacheManager != nullptr,
mWorldConfig.isPipelineParallel());
mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(ctxChunkConfig, maxContextLength);
if (ctxChunkConfig)
{
if (maxContextLength)
{
ctxChunkConfig.value().chunkUnitSize
= std::min(ctxChunkConfig.value().chunkUnitSize, maxContextLength.value());
}
        TLLM_CHECK_WITH_INFO(ctxChunkConfig.value().chunkUnitSize > 0,
            "Context chunk unit size (%d) must be a positive integer.", ctxChunkConfig.value().chunkUnitSize);
}
else
{
if (maxContextLength && maxNumTokens)
{
TLLM_CHECK_WITH_INFO(maxContextLength.value() <= maxNumTokens.value(),
"Without enabling chunked context, the max context length (%d) needs to be less than or equal to the "
"max number of tokens (%d).",
maxContextLength.value(), maxNumTokens.value());
}
}
mPauseRequests = std::make_unique<PauseRequests>(getMaxInputLen());
mAssignReqSeqSlots = std::make_unique<AssignReqSeqSlots>();
mAllocateKvCache = std::make_unique<AllocateKvCache>();
if (isCudaGraphMode())
{
        // Limit the CUDA graph cache size. Depending on the model, one graph takes 4-10 MB of GPU memory.
SizeType32 cudaGraphCacheSize
= std::min(getMaxBatchSize(), std::max(mExtendedRuntimePerfKnobConfig.getCudaGraphCacheSize(), 1));
        // We can't have a common cache for all micro batches, as a CUDA graph is tied to the memory pointers of
        // the runtime buffers.
mCudaGraphExecutorCaches.resize(mNumBuffers, utils::CudaGraphExecutorCache(cudaGraphCacheSize));
}
mSpeculativeDecodingFastLogits
= executorConfig.getSpecDecConfig().has_value() && executorConfig.getSpecDecConfig()->fastLogits;
if (mSpeculativeDecodingFastLogits && modelConfig.getSpeculativeDecodingMode().isNone() && mIsLeaderInOrchMode)
{
mDraftModelSendLogitsThread = std::make_unique<std::thread>(&utils::draftModelSendLogitsThread, mDevice,
&mDraftModelThreadShouldExit, &mDraftRequestsWaitingToSendLogits, mSeqSlotManager, getMaxInputLen(),
mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
}
mCreateNewDecoderRequests = std::make_unique<CreateNewDecoderRequests>(
mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, isNormalizeLogProbs());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
TrtGptModelInflightBatching::~TrtGptModelInflightBatching()
{
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(true);
TLLM_CHECK_WITH_INFO(mCacheTransceiver->checkGenTransferComplete(), "Generation transfer not complete");
}
if (mAsyncSendWaitThread)
{
mAsyncSendWaitThread.reset(nullptr);
}
if (mDraftModelSendLogitsThread)
{
mDraftModelThreadShouldExit = true;
mDraftModelSendLogitsThread->join();
mDraftModelSendLogitsThread.reset(nullptr);
}
}
void TrtGptModelInflightBatching::setupSpeculativeDecodingModule(executor::DecodingConfig const& decodingConfig)
{
if (mModelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens()
|| mModelConfig.getSpeculativeDecodingMode().isEagle())
{
TLLM_CHECK_WITH_INFO(mCtxGenFusion, "Current speculative decoding mode requires context-gen fusion IFB");
}
if (mModelConfig.getSpeculativeDecodingMode().isLookaheadDecoding() && decodingConfig.getLookaheadDecodingConfig())
{
// FIXME choose defaults
auto maxLookaheadConfig = decodingConfig.getLookaheadDecodingConfig().value();
SizeType32 maxDraftTokens{0};
SizeType32 maxDraftPathLen{0};
std::tie(std::ignore, std::ignore, maxDraftTokens, maxDraftPathLen)
= maxLookaheadConfig.calculateSpeculativeResource();
TLLM_CHECK(maxDraftTokens <= mModelConfig.getMaxDecodingDraftTokens());
mModelConfig.getSpeculativeDecodingModulePtr()->setMaxDraftTokens(maxDraftTokens);
mModelConfig.getSpeculativeDecodingModulePtr()->setMaxDraftPathLen(maxDraftPathLen);
auto lookaheadModulePtr
= std::dynamic_pointer_cast<runtime::LookaheadModule>(mModelConfig.getSpeculativeDecodingModulePtr());
lookaheadModulePtr->setExecutionConfig(maxLookaheadConfig);
}
}
void TrtGptModelInflightBatching::reshapeKvTensors(OffsetTableDimensions const& dims)
{
TLLM_CHECK(mBuffers.size() == static_cast<size_t>(mNumBuffers));
auto const& manager = mRuntime->getBufferManager();
for (auto& buffers : mBuffers)
{
TLLM_CHECK(buffers->transformerBuffers);
        // Any method that operates on transformerBuffers must distinguish between the self and cross cache, because
        // transformerBuffers is not managed by the KVCacheManager. The same rule applies to the KV pool pointers
        // below.
buffers->transformerBuffers->reshapeKvTensors(
getMaxBatchSize(), mOperatingBeamWidth, dims.maxBlocksPerSeq, dims.cacheType, dims.numPools, manager);
}
}
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
std::pair<BlocksPerWindow, std::vector<SizeType32>>
TrtGptModelInflightBatching::clampWindowSizesToFitAtLeastOneSequence(BlocksPerWindow const& blocksPerWindow)
{
// At this point, we can only validate that the cheapest sequence in terms of kv-cache resources still fits. More
// validation is needed on a per-request basis, once the prompt / output lengths and the actual beam width are
// known.
auto const promptLength = getMaxInputLen();
auto const outputLength
= getMaxSequenceLen() - promptLength; // This makes it the best case scenario, as context tokens are 'cheaper'
// in terms of kv-cache resources on average.
auto const sinkTokenLength = getSinkTokenLen();
auto const maxBeamWidth = getMaxBeamWidth();
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
auto const& oldMaxAttentionWindowVec = getMaxAttentionWindowVec();
std::vector<SizeType32> newMaxAttentionWindowVec;
BlocksPerWindow newBlocksPerWindow;
newMaxAttentionWindowVec.reserve(oldMaxAttentionWindowVec.size());
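    // For each attention window size, check whether the cheapest possible sequence still fits into the primary
    // blocks available for that window. If it does not, shrink the window to the largest value that does fit and
    // re-key the block budget under the new window size.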
for (auto const windowSize : oldMaxAttentionWindowVec)
{
auto const bestCaseBlockRequirements = kv_cache_manager::KVCacheManager::calculateMaxBlockRequirements(
promptLength, outputLength, sinkTokenLength, windowSize, maxBeamWidth, tokensPerBlock);
auto const [numPrimaryBlocks, numSecondaryBlocks] = blocksPerWindow.at(windowSize);
if (bestCaseBlockRequirements > numPrimaryBlocks)
{
auto const newMaxAttentionWindow = KVCacheManager::calculateMaxAttentionWindow(
promptLength, outputLength, sinkTokenLength, numPrimaryBlocks, maxBeamWidth, tokensPerBlock);
newMaxAttentionWindowVec.push_back(newMaxAttentionWindow);
newBlocksPerWindow[newMaxAttentionWindow] = std::make_tuple(numPrimaryBlocks, numSecondaryBlocks);
}
else
{
newMaxAttentionWindowVec.push_back(windowSize);
newBlocksPerWindow[windowSize] = std::make_tuple(numPrimaryBlocks, numSecondaryBlocks);
}
}
if (newMaxAttentionWindowVec == getMaxAttentionWindowVec())
{
return {blocksPerWindow, newMaxAttentionWindowVec};
}
TLLM_LOG_WARNING("maxAttentionWindowVec too large to fit at least one sequence in kvCache. Old: %s, New: %s",
common::vec2str(getMaxAttentionWindowVec()).c_str(), common::vec2str(newMaxAttentionWindowVec).c_str());
setMaxAttentionWindowVec(newMaxAttentionWindowVec);
if (getMaxSequenceLen() < getMaxAttentionWindow())
{
TLLM_LOG_WARNING("maxSequenceLen is reduced to maxAttentionWindow: %d", getMaxAttentionWindow());
setMaxSequenceLen(getMaxAttentionWindow());
if (getMaxInputLen() > getMaxSequenceLen() - 1)
{
setMaxInputLen(getMaxSequenceLen() - 1);
TLLM_LOG_WARNING("maxInputLen is reduced to %d", getMaxInputLen());
}
}
// createBuffers depends on:
// maxAttentionWindow; maxAttentionWindowVec; maxSequenceLen;
// TODO: This is problematic, as createBuffers edits the state of trtGptModelInflightBatching, but
// what if there are different window values for cross+self etc. in encoder+decoder scenario...
createBuffers(mDecodingConfig, mAdditionalModelOutputs);
createDecoder(mDecodingConfig.getDecodingMode());
return {newBlocksPerWindow, newMaxAttentionWindowVec};
}
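// Builds the KV cache manager for either the self or the cross cache: derives the per-window block budgets from the
// free memory, clamps the window sizes for the self cache so that at least one sequence fits, constructs the
// KVCacheManager, reshapes the runtime KV tensors, allocates the pools, and registers the pool pointer and layer
// mapping tensors as static engine inputs.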
std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::createKvCacheManager(
KvCacheConfig const& kvCacheConfig, KvCacheType kvCacheType, uint64_t freePrimaryMemBytes,
uint64_t freeSecondaryMemBytes, size_t extraCostMemory)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
bool isCrossAttention = kvCacheType == KvCacheType::kCROSS;
TLLM_CHECK_WITH_INFO(
mModelConfig.isTransformerBased(), "KvCacheManager is only needed by transformer based model.");
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
auto const kvDtype = mModelConfig.getKvDataType();
bool enableCyclicKvCache = false;
for (SizeType32 maxAttenWin : getMaxAttentionWindowVec())
{
if (maxAttenWin != getMaxSequenceLen())
{
enableCyclicKvCache = true;
break;
}
}
    // The assertion below should be removed once SWA/VSWA is no longer cyclic.
TLLM_CHECK_WITH_INFO(
getMaxBeamWidth() == 1 || !enableCyclicKvCache, "Can't support cyclic kv cache with beam search.");
// init KV cache block manager
auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = mModelConfig.getNumKvHeadsPerLayerLocalRange(
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank(), isCrossAttention);
auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
auto maxAttentionWindowVec = getMaxAttentionWindowVec();
if (kvCacheType != KvCacheType::kSELF) // TODO(nhaber): more foolproof way of initing cross-kvcache-manager
{
maxAttentionWindowVec = std::vector<SizeType32>{mModelConfig.getMaxEncoderLen()};
}
auto const numLayers = static_cast<SizeType32>(numKvHeadsPerLayer.size());
auto const windowSizeToLayers = KVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, numLayers);
auto blocksPerWindow = KVCacheManager::calculateMaxNumBlocks(kvCacheConfig, isCrossAttention, kvDtype, mModelConfig,
mWorldConfig, windowSizeToLayers, freePrimaryMemBytes, freeSecondaryMemBytes, extraCostMemory, 2);
    // Now check whether any of the window sizes is too large for at least one sequence to fit in the KV cache.
    // This can happen if, e.g., maxSeqLen is deduced from the model and is too large, and the user also didn't
    // provide maxAttentionWindow, which then defaults to maxSeqLen.
if (kvCacheType == KvCacheType::kSELF)
{
std::tie(blocksPerWindow, maxAttentionWindowVec) = clampWindowSizesToFitAtLeastOneSequence(blocksPerWindow);
}
kv_cache_manager::TempAttentionWindowInputs tempAttentionWindowInputs;
tempAttentionWindowInputs.pagedContextFMHA = mModelConfig.getPagedContextFMHA();
tempAttentionWindowInputs.maxInputLen = getMaxInputLen();
tempAttentionWindowInputs.maxNumTokens = getMaxNumTokens().value();
if (kvCacheType == KvCacheType::kCROSS && kvCacheConfig.enableBlockReuse)
{
TLLM_LOG_INFO(
"Cross KV cache does not support reuse because cross attention depends on encoder and decoder input ids. "
"Thus, KV cache reuse is disabled for cross KV cache.");
}
auto const enableBlockReuse = kvCacheType == KvCacheType::kSELF ? kvCacheConfig.enableBlockReuse : false;
auto const sizePerHead = mModelConfig.getSizePerHead();
auto kvCacheManager = std::make_unique<KVCacheManager>(numKvHeadsPerLayer, sizePerHead, tokensPerBlock,
blocksPerWindow, getMaxNumSequences(), getMaxBeamWidth(), maxAttentionWindowVec, tempAttentionWindowInputs,
kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse,
kvCacheConfig.onboardBlocks, kvCacheType, kvCacheConfig.secondaryOffloadMinPriority,
kvCacheConfig.eventBufferMaxSize > 0
? std::make_unique<kv_cache_manager::KVCacheEventManager>(kvCacheConfig.eventBufferMaxSize)
: nullptr,
false, kvCacheConfig.enablePartialReuse, kvCacheConfig.copyOnPartialReuse);
reshapeKvTensors(kvCacheManager->getOffsetTableDimensions());
kvCacheManager->allocatePools(kvCacheConfig.useUvm);
TensorMap inputBuffers;
TensorPtr poolPointers = kvCacheManager->getBlockPoolPointers();
TensorPtr poolMapping = kvCacheManager->getLayerToPoolMapping();
if (kvCacheType == KvCacheType::kSELF)
{
inputBuffers.insert_or_assign("host_kv_cache_pool_pointers", std::move(poolPointers));
inputBuffers.insert_or_assign("host_kv_cache_pool_mapping", std::move(poolMapping));
}
else
{
inputBuffers.insert_or_assign("host_cross_kv_cache_pool_pointers", std::move(poolPointers));
inputBuffers.insert_or_assign("host_cross_kv_cache_pool_mapping", std::move(poolMapping));
}
mRuntime->setStaticInputTensors(inputBuffers);
// Emit the `created` event
kvCacheManager->flushIterationEvents();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return kvCacheManager;
}
void TrtGptModelInflightBatching::createRnnStateManager()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(mModelConfig.isRnnBased(), "RnnStateManager is only needed by RNN based model.");
mRnnStateManager = std::make_unique<RnnStateManager>(
getMaxNumSequences(), mModelConfig, mWorldConfig, mRuntime->getBufferManager());
TensorMap inputBuffers;
mRnnStateManager->getPtrBuffers(inputBuffers, mModelConfig, mWorldConfig);
mRuntime->setStaticInputTensors(inputBuffers);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::createCustomAllReduceWorkspace()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(mWorldConfig.isTensorParallel());
auto const& manager = mRuntime->getBufferManager();
auto const hiddenSize = mModelConfig.getHiddenSize();
mAllReduceBuffers = std::make_unique<AllReduceBuffers>(getMaxBatchSize(), getMaxBeamWidth(), getMaxSequenceLen(),
hiddenSize, manager, mWorldConfig, mRuntime->isUserBufferEnabled());
TensorMap inputBuffers;
inputBuffers.insert_or_assign("all_reduce_workspace", mAllReduceBuffers->mAllReduceCommPtrs);
mRuntime->setStaticInputTensors(inputBuffers);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::createRuntimePerfKnobsTensor(
executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
SizeType32 constexpr perfKnobSize{16};
mExtendedRuntimePerfKnobsHost = BufferManager::cpu(ITensor::makeShape({perfKnobSize}), nvinfer1::DataType::kINT64);
auto* runtimePerfKnobsHostPtr = bufferCast<int64_t>(*mExtendedRuntimePerfKnobsHost);
std::fill_n(runtimePerfKnobsHostPtr, perfKnobSize, -1);
SizeType32 multiBlockModeVal = extendedRuntimePerfKnobConfig.getMultiBlockMode() ? 1 : 0;
SizeType32 enableContextFMHAFP32AccVal = extendedRuntimePerfKnobConfig.getEnableContextFMHAFP32Acc() ? 1 : 0;
runtimePerfKnobsHostPtr[0] = multiBlockModeVal;
runtimePerfKnobsHostPtr[1] = enableContextFMHAFP32AccVal;
TensorMap inputBuffers;
inputBuffers.insert_or_assign("host_runtime_perf_knobs", mExtendedRuntimePerfKnobsHost);
mRuntime->setStaticInputTensors(inputBuffers);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::terminateRequest(LlmRequestPtr const& llmReq, bool pause)
{
utils::terminateRequest(
*mSeqSlotManager, *llmReq, getMaxInputLen(), mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager, pause);
}
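// Registers a request for termination. The actual termination happens in forwardSync, which consumes
// mReqIdsToTerminate for the requests of the current micro batch.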
void TrtGptModelInflightBatching::terminateRequestSync(
LlmRequestPtr const& llmRequest, executor::FinishReason finishReason)
{
TLLM_LOG_DEBUG("Registering termination for request %lu with finish reason %d", llmRequest->mRequestId,
static_cast<int>(finishReason));
mReqIdsToTerminate.try_emplace(llmRequest->mRequestId, finishReason);
}
TrtGptModelInflightBatching::IterationStatsIFB TrtGptModelInflightBatching::fillIterationStats(
ScheduledRequests const& scheduledRequests, RequestVector const& requestsToPause)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(fillIterationStats);
IterationStatsIFB iterationStatsIfb{mMicroBatchId};
iterationStatsIfb.numCtxRequests = scheduledRequests.contextRequests.size();
iterationStatsIfb.numGenRequests = scheduledRequests.generationRequests.size();
iterationStatsIfb.avgNumDecodedTokensPerIter = 0;
auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
auto const& buffers = mBuffers.at(contextBufferId);
iterationStatsIfb.numCtxTokens = buffers->getNumContextTokens();
for (auto const& llmReq : scheduledRequests.contextRequests)
{
iterationStatsIfb.scheduledRequests.insert(llmReq->mRequestId);
}
for (auto const& llmReq : scheduledRequests.generationRequests)
{
iterationStatsIfb.scheduledRequests.insert(llmReq->mRequestId);
iterationStatsIfb.avgNumDecodedTokensPerIter += llmReq->getAvgDecodedTokensPerIter();
}
if (iterationStatsIfb.numGenRequests > 0)
{
iterationStatsIfb.avgNumDecodedTokensPerIter /= iterationStatsIfb.numGenRequests;
TLLM_LOG_DEBUG(
"iterationStatsIfb.avgNumDecodedTokensPerIter = %.2f", iterationStatsIfb.avgNumDecodedTokensPerIter);
}
for (auto const& llmReq : requestsToPause)
{
iterationStatsIfb.pausedRequests.insert(llmReq->mRequestId);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return iterationStatsIfb;
}
void TrtGptModelInflightBatching::forwardSync()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE_WITH_NAME(range, "TrtGptModelInflightBatching::forwardSync");
TLLM_CUDA_CHECK(cudaSetDevice(mWorldConfig.getDevice()));
if (!mWorldConfig.isLastPipelineParallelRank())
{
mAsyncSendWaitThread->waitStop();
}
auto& currRequests = mMicroBatchScheduledRequests.at(mMicroBatchId);
if (!currRequests.empty())
{
if (!mWorldConfig.isPipelineParallel() || !mWorldConfig.isLastPipelineParallelRank())
{
for (auto& hdl : mDecStepAsyncSndHdls)
{
TLLM_CHECK_WITH_INFO(hdl.get() == nullptr, "decoderSync handle must be nullptr.");
}
// Wait for decoding for requests in flight for the current micro batch
auto& decoderWaitEvent = mDecoderFinishedEvents.at(mMicroBatchId);
mDecStepAsyncSndHdls = decoderSync(currRequests, decoderWaitEvent);
decoderWaitEvent.reset();
if (!mWorldConfig.isLastPipelineParallelRank())
{
mAsyncSendWaitThread->notifyStart();
}
}
else
{
for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
for (SizeType32 beam = 0; beam < llmReq->mSamplingConfig.beamWidth; ++beam)
{
llmReq->setNumPreDecodedTokens(0, beam);
}
if (llmReq->isGenerationToCompleteState())
{
llmReq->setState(LlmRequestState::kGENERATION_COMPLETE);
terminateRequest(llmReq);
}
}
}
}
(*mPauseRequests)(currRequests.contextRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
(*mPauseRequests)(currRequests.generationRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
if (!mReqIdsToTerminate.empty())
{
for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
if (mReqIdsToTerminate.count(llmReq->mRequestId) != 0U)
{
if (!llmReq->isGenerationCompleteState())
{
TLLM_LOG_DEBUG("Terminating request %lu with finish reason %d", llmReq->mRequestId,
static_cast<int>(mReqIdsToTerminate[llmReq->mRequestId]));
terminateRequest(llmReq);
llmReq->finishByReason(mReqIdsToTerminate[llmReq->mRequestId]);
llmReq->clearGeneratedTokens();
}
mReqIdsToTerminate.erase(llmReq->mRequestId);
}
}
}
}
// Finished context requests have been moved to generationRequests by moveFinishedContextRequestsToGeneration
for (auto const& llmReq : currRequests.generationRequests)
{
// If a context-only request is finished, send its KV cache and mark it.
if (llmReq->isContextOnlyRequest() && llmReq->isContextFinished())
{
// TODO: skip if sending layer-wise
{
TLLM_CHECK_WITH_INFO(
mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration.");
mCacheTransceiver->respondAndSendAsync(llmReq.get());
}
mSeqSlotManager->freeSequenceSlot(llmReq->mRequestId);
}
}
}
// report profile data
auto const bufferId = getFusedBufferId();
auto const contextId = mBuffers[bufferId]->getContextIndex();
if (mRuntime->hasLayerProfiler(contextId))
{
mRuntime->reportToProfiler(contextId);
}
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(0);
}
++mIterCounter;
if (mKvCacheManager)
{
mKvCacheManager->flushIterationEvents();
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::storeContextBlocks(std::shared_ptr<LlmRequest> const& llmReq)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // TMJ - Note
    // Make context blocks reusable immediately after the context phase finishes.
    // For chunked contexts, this occurs in the step that processes the last context chunk.
    // isLastContextChunk() is always true for non-chunked contexts.
    // This check is made in the code that calls storeContextBlocks, so it is omitted here.
if (mKvCacheManager)
{
mKvCacheManager->storeContextBlocks(*llmReq);
}
if (mCrossKvCacheManager)
{
mCrossKvCacheManager->storeContextBlocks(*llmReq);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::resetIterationStats()
{
mLastIterationStatsIFB = IterationStatsIFB{mMicroBatchId};
}
void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE_WITH_NAME(range, "TrtGptModelInflightBatching::forwardAsync");
TLLM_CUDA_CHECK(cudaSetDevice(mWorldConfig.getDevice()));
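    // One forwardAsync iteration: verify the active requests, run the capacity and micro batch schedulers, assign
    // sequence slots and allocate KV cache for the scheduled requests, prepare the PEFT tables, execute the engine
    // for the batch, launch the asynchronous decoder step, update the request states, and advance to the next
    // micro batch.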
try
{
verifyRequests(activeRequests);
if (mModelConfig.isTransformerBased() && getKVCacheManager() && mCacheTransceiver)
{
checkDisaggGenTransferStatus(activeRequests);
}
auto& currRequests = mMicroBatchScheduledRequests.at(mMicroBatchId);
        // Get a new set of requests for this iteration.
        // The scheduler will not include any requests that are (i) still in the encoder state (encoder-decoder
        // models) OR (ii) already in flight for the decoder model.
TLLM_LOG_DEBUG("Running DECODER request scheduler");
auto [fittingRequests, fittingDisaggGenInitRequests, requestsToPause]
= (*mCapacityScheduler)(activeRequests, mKvCacheManager, mPeftCacheManager, mCrossKvCacheManager);
// Remove from fitting requests the requests that cannot be scheduled due to disagg KV cache transfer
if (mModelConfig.isTransformerBased() && getKVCacheManager() && mCacheTransceiver)
{
prepareDisaggGenInitRequests(activeRequests, fittingDisaggGenInitRequests);
}
if (fittingRequests.empty() && fittingDisaggGenInitRequests.empty())
{
            TLLM_LOG_WARNING(
                "CapacityScheduler didn't schedule any requests, probably because of insufficient resources such "
                "as KV cache; will try to wait for KV cache transfers to complete");
if (mCacheTransceiver)
{
mCacheTransceiver->checkContextTransferStatus(1);
// will free kvCache in next iteration.
}
}
std::tie(currRequests.contextRequests, currRequests.generationRequests)
= (*mMicroBatchScheduler)(fittingRequests, mInflightReqIds, mMaxBatchSizeRuntime, mMaxNumTokensRuntime);
TLLM_CHECK(currRequests.size() <= static_cast<size_t>(getMaxBatchSize()));
(*mPauseRequests)(requestsToPause, mInflightReqIds, mReqIdsToPause, false, *mSeqSlotManager, mKvCacheManager,
mCrossKvCacheManager, mPeftCacheManager);
if (mUseSeamlessLookahead)
{
changeSpecDecMode(currRequests);
}
if (!currRequests.empty())
{
TLLM_LOG_DEBUG("Running DECODER model with batch size: %lu", currRequests.size());
// For overlap don't store inflight requests, so they are not skipped in scheduler
if (!isTrtOverlap())
{
NVTX3_SCOPED_RANGE(updateInflightReqIds);
// Add requests to in-flight set, so they can be skipped in other micro batches
for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
TLLM_LOG_DEBUG("request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
mInflightReqIds.insert(llmReq->mRequestId);
}
}
}
(*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);
if (mKvCacheManager)
{
(*mAllocateKvCache)(*mKvCacheManager, currRequests.contextRequests, currRequests.generationRequests,
mModelConfig, mCrossKvCacheManager);
}
mPeftTables.at(mMicroBatchId)
= mPeftCacheManager->ensureBatch(currRequests.contextRequests, currRequests.generationRequests, true);
            // Do decoder setup before the context phase if the model needs to set up buffers for it.
if (mModelConfig.getSpeculativeDecodingMode().needsDecoderPrologue())
{
auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
setupDecoderStep(currRequests.contextRequests, *mBuffers.at(contextBufferId),
mDecoderInputBuffers.at(getFusedBufferId()));
}
else
{
prepareDistGenBufferAndDecoder(currRequests.generationRequests);
}
executeBatch(currRequests);
if (mWorldConfig.isLastPipelineParallelRank() && mGuidedDecoder)
{
                // XGrammar: build the mask cache for context requests and run mask generation for all requests.
                // These need to overlap with the kernel execution of the forward step.
mGuidedDecoder->build(currRequests);
}
sync_check_cuda_error(mRuntime->getStream().get());
            // Postpone decoder setup if the model does not need to set up buffers for the context phase.
if (!mModelConfig.getSpeculativeDecodingMode().needsDecoderPrologue())
{
auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
setupDecoderStep(currRequests.contextRequests, *mBuffers.at(contextBufferId),
mDecoderInputBuffers.at(getFusedBufferId()));
}
sync_check_cuda_error(mRuntime->getStream().get());
if (isTrtOverlap())
{
// WAR: Because the decoder is not stateless (yet) a sync is needed between
// decoder execution and next decoder step preparation.
auto const prevMicroBatchId = getPrevMicroBatchId(mMicroBatchId);
auto& prevDecoderFinishedEvent = mDecoderFinishedEvents.at(prevMicroBatchId);
if (prevDecoderFinishedEvent)
{
prevDecoderFinishedEvent->synchronize();
}
}
auto& decoderFinishedEvent = mDecoderFinishedEvents.at(mMicroBatchId);
TLLM_CHECK_WITH_INFO(!decoderFinishedEvent.has_value(), "decoderFinishedEvent must be nullopt.");
decoderFinishedEvent = mWorldConfig.isLastPipelineParallelRank()
? std::make_optional(decoderStepAsync(currRequests))
: std::nullopt;
mLastIterationStatsIFB = fillIterationStats(currRequests, requestsToPause);
for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
if (llmReq->isContextInitState())
{
llmReq->moveToNextContextChunk();
if (llmReq->getContextRemainingLength() == 0)
{
TLLM_LOG_DEBUG("[RANK %d] request with ID %lu finishes decoder ctx phase",
COMM_SESSION.getRank(), llmReq->mRequestId);
llmReq->setState(LlmRequestState::kGENERATION_IN_PROGRESS);
// for encoder-decoder models, free encoder output buffers after decoder context phase is
// completed
if (llmReq->getEncoderTokens().has_value())
{
llmReq->freeEncoderOutputBuffers();
}
storeContextBlocks(llmReq);
if (isTrtOverlap() && llmReq->willCompleteNextIteration())
{
// This prohibits the request from being scheduled for another iteration if only one
// iteration is expected.
llmReq->setState(LlmRequestState::kGENERATION_TO_COMPLETE);
}
}
}
else if (llmReq->isGenerationInProgressState())
{
TLLM_LOG_DEBUG("request with ID %lu forwards a step in decoder gen phase", llmReq->mRequestId);
}
}
}
utils::moveFinishedContextRequestsToGeneration(currRequests);
}
else
{
mLastIterationStatsIFB = IterationStatsIFB{mMicroBatchId};
}
if (mWorldConfig.isPipelineParallel() && mWorldConfig.isLastPipelineParallelRank())
{
mAsyncSendWaitThread->waitStop();
if (!currRequests.empty())
{
for (auto& hdl : mDecStepAsyncSndHdls)
{
TLLM_CHECK_WITH_INFO(hdl.get() == nullptr, "decoderSync handle must be nullptr.");
}
// Wait for decoding for requests in flight for the current micro batch
auto& decoderFinishedEvent = mDecoderFinishedEvents.at(mMicroBatchId);
mDecStepAsyncSndHdls = decoderSync(currRequests, decoderFinishedEvent);
decoderFinishedEvent.reset();
mAsyncSendWaitThread->notifyStart();
}
}
// Update the micro batch ID
mMicroBatchId = getNextMicroBatchId(mMicroBatchId);
}
// In case of error, we need to free the batch slot associated with those requests
catch (std::exception const&)
{
try
{
for (auto const& llmReq : activeRequests)
{
terminateRequest(llmReq);
}
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("forwardAsync catch-all catch block that runs `terminateRequest` has failed with:");
TLLM_LOG_EXCEPTION(e);
TLLM_LOG_ERROR("Rethrowing *outer* exception:");
}
throw;
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::setRuntimeBatchSize(SizeType32 runtimeMaxBatchSize)
{
mMaxBatchSizeTunerRecommended = runtimeMaxBatchSize;
mMaxBatchSizeRuntime = std::min(getMaxBatchSize(), runtimeMaxBatchSize);
}
SizeType32 TrtGptModelInflightBatching::getRuntimeBatchSize() const
{
return mMaxBatchSizeRuntime;
}
void TrtGptModelInflightBatching::setRuntimeMaxNumTokens(SizeType32 runtimeMaxNumTokens)
{
mMaxNumTokensTunerRecommended = runtimeMaxNumTokens;
mMaxNumTokensRuntime
= (mMaxNumTokensStatic) ? std::min(mMaxNumTokensStatic.value(), runtimeMaxNumTokens) : runtimeMaxNumTokens;
}
void TrtGptModelInflightBatching::updatePeftCache(std::shared_ptr<LlmRequest> const& llmRequest)
{
mPeftCacheManager->addRequestPeft(llmRequest, true);
}
runtime::BufferManager const& TrtGptModelInflightBatching::getBufferManager() const
{
return mRuntime->getBufferManager();
}
BufferManager::CudaStreamPtr TrtGptModelInflightBatching::getRuntimeStreamPtr() const
{
return mRuntime->getStreamPtr();
}
void TrtGptModelInflightBatching::executeContext(SizeType32 runtimeContextId, SizeType32 bufferId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(executeContext);
auto const& currBatchState = mBuffers[bufferId]->getBatchState();
bool hasCudaGraph = false;
    // CUDA graphs are only used for generation-only batches; if the batch contains any context requests,
    // do not capture/launch a graph and execute the engine as is.
if (isCudaGraphMode() && !currBatchState.isAnyContext())
{
auto cudaGraphOpt = mCudaGraphExecutorCaches[bufferId].get(currBatchState);
// If graph exists for current batch state, launch it.
if (cudaGraphOpt.has_value())
{
hasCudaGraph = true;
}
}
// If there is no graph for current state, execute the engine.
if (!hasCudaGraph)
{
auto enqueueSuccessful = mRuntime->executeContext(runtimeContextId);
if (!enqueueSuccessful)
{
throw std::runtime_error("Executing TRT engine failed!");
}
}
else
{
// Launch graph.
auto cudaGraphOpt = mCudaGraphExecutorCaches[bufferId].get(currBatchState);
cudaGraphOpt.value()->launch(mRuntime->getStream());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::setLayerProfiler()
{
mRuntime->setLayerProfiler();
}
std::string TrtGptModelInflightBatching::getLayerProfileInfo() const
{
return mRuntime->getLayerProfileInfo();
}
void TrtGptModelInflightBatching::verifyRequests(RequestList const& activeRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(verifyRequests);
if (activeRequests.empty())
{
return;
}
auto const& firstRequest = activeRequests.front();
auto const firstRequestId = firstRequest->mRequestId;
auto const firstBeamWidth = firstRequest->mSamplingConfig.beamWidth;
for (auto const& llmReq : activeRequests)
{
auto const beamWidth = llmReq->mSamplingConfig.beamWidth;
auto const draftLength = llmReq->getNumDraftTokens();
auto const maxDraftLength = mModelConfig.getMaxDecodingDraftTokens();
TLLM_CHECK_WITH_INFO(beamWidth == 1 || draftLength == 0, "Can't use speculative decoding with beam search.");
TLLM_CHECK_WITH_INFO(draftLength <= maxDraftLength,
"Number of draft tokens (%d) is larger than maximum number of draft tokens (%d)", draftLength,
maxDraftLength);
// FIXME: Remove this check when varying beam width is supported
{
TLLM_CHECK_WITH_INFO(beamWidth == firstBeamWidth,
"All active requests must have same beam width, "
"but request %lu with beam width %d differs from first request %lu with beam width %d",
llmReq->mRequestId, beamWidth, firstRequestId, firstBeamWidth);
}
}
if (firstBeamWidth != mOperatingBeamWidth)
{
changeBeamWidth(firstBeamWidth);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::executeBatch(ScheduledRequests const& scheduledRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(executeBatch);
if (!mCtxGenFusion)
{
if (!scheduledRequests.contextRequests.empty())
{
auto const bufferId = getContextBufferId();
executeStep(scheduledRequests.contextRequests, {}, bufferId);
}
if (!scheduledRequests.generationRequests.empty())
{
auto const bufferId = getGenerationBufferId();
executeStep({}, scheduledRequests.generationRequests, bufferId);
}
}
else
{
auto const bufferId = getFusedBufferId();
executeStep(scheduledRequests.contextRequests, scheduledRequests.generationRequests, bufferId);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::createRuntimeContexts()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mRuntime->clearContexts();
auto const numProfiles = mRuntime->getNbProfiles();
for (auto i = 0; i < numProfiles; ++i)
{
mRuntime->addContext(i);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
namespace
{
// TODO: move this somewhere else?
/**
* This function logic is also implemented in tensorrt_llm/_torch/pyexecutor/_util.py get_decoding_mode().
*/
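// Resolution order as implemented below: an explicitly requested, non-Auto decoding mode is used as the default;
// otherwise TopKTopP is chosen for beamWidth == 1 and BeamSearch for larger beam widths. The default is then
// overridden whenever it conflicts with the model's speculative decoding mode (Medusa, Lookahead, explicit draft
// tokens, Eagle, or external draft tokens).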
executor::DecodingMode getDecodingMode(SpeculativeDecodingMode specDecodingMode,
std::optional<executor::DecodingMode> const& decodingModeOpt, runtime::SizeType32 const beamWidth)
{
auto getDefaultDecodingMode = [beamWidth](std::optional<executor::DecodingMode> const& decodingModeOpt)
{
if (decodingModeOpt.has_value() && !decodingModeOpt->isAuto())
{
return decodingModeOpt.value();
}
return (beamWidth == 1) ? executor::DecodingMode::TopKTopP() : executor::DecodingMode::BeamSearch();
};
auto decodingMode = getDefaultDecodingMode(decodingModeOpt);
    // Log when Variable-Beam-Width-Search (a special mode of Beam-Search) is enabled.
if (decodingMode.isBeamSearch() && decodingMode.isUseVariableBeamWidthSearch())
{
TLLM_LOG_INFO("Variable-Beam-Width-Search is enabled");
}
// Overwrite decoding mode when beam width is one.
if (beamWidth == 1 && decodingMode.isBeamSearch())
{
TLLM_LOG_WARNING(
"Beam width is set to 1, but decoding mode is BeamSearch. Overwriting decoding mode to TopKTopP.");
decodingMode = executor::DecodingMode::TopKTopP();
}
// Overwrite decoding mode when Medusa is used.
if (specDecodingMode.isMedusa() && !decodingMode.isMedusa())
{
TLLM_LOG_WARNING("Model is Medusa, but decoding mode is not Medusa. Overwriting decoding mode to Medusa.");
decodingMode = executor::DecodingMode::Medusa();
}
// Overwrite decoding mode when Medusa is not used.
if (!specDecodingMode.isMedusa() && decodingMode.isMedusa())
{
TLLM_LOG_WARNING("Model is not Medusa, but decoding mode is Medusa. Overwriting decoding mode.");
decodingMode = getDefaultDecodingMode(decodingModeOpt);
}
// Overwrite decoding mode when lookahead decoding is used.
if (specDecodingMode.isLookaheadDecoding() && !decodingMode.isLookahead())
{
TLLM_LOG_WARNING(
"Model is Lookahead, but decoding mode is not Lookahead. Overwriting decoding mode to Lookahead.");
decodingMode = executor::DecodingMode::Lookahead();
}
// Overwrite decoding mode when lookahead decoding is not used.
if (!specDecodingMode.isLookaheadDecoding() && decodingMode.isLookahead())
{
TLLM_LOG_WARNING(
"Model is not built with Lookahead decoding, but decoding mode is Lookahead. Overwriting decoding "
"mode.");
decodingMode = getDefaultDecodingMode(decodingModeOpt);
}
// Overwrite decoding mode when 'explicit draft tokens' is used.
if (specDecodingMode.isExplicitDraftTokens() && !decodingMode.isExplicitDraftTokens())
{
TLLM_LOG_WARNING(
"Model is built with 'explicit draft tokens' decoding, but decoding mode is something else. Overwriting "
"decoding mode.");
decodingMode = executor::DecodingMode::ExplicitDraftTokens();
}
// Overwrite decoding mode when 'explicit draft tokens' is not used.
if (!specDecodingMode.isExplicitDraftTokens() && decodingMode.isExplicitDraftTokens())
{
TLLM_LOG_WARNING(
"Model is not built with 'explicit draft tokens' decoding, but decoding mode is set to it. Overwriting "
"decoding "
"mode to default.");
decodingMode = getDefaultDecodingMode(decodingModeOpt);
}
// Overwrite decoding mode when EAGLE is used.
if (specDecodingMode.isEagle() && !decodingMode.isEagle())
{
TLLM_LOG_WARNING("Model is Eagle, but decoding mode is not Eagle. Overwriting decoding mode to Eagle.");
decodingMode = executor::DecodingMode::Eagle();
}
// Overwrite decoding mode when Eagle is not used.
if (!specDecodingMode.isEagle() && decodingMode.isEagle())
{
TLLM_LOG_WARNING("Model is not Eagle, but decoding mode is Eagle. Overwriting decoding mode.");
decodingMode = getDefaultDecodingMode(decodingModeOpt);
}
if (specDecodingMode.isDraftTokensExternal())
{
TLLM_LOG_WARNING("Overwriting decoding mode to external draft token");
decodingMode = executor::DecodingMode::ExternalDraftTokens();
}
TLLM_LOG_DEBUG("DecodingMode: %s", decodingMode.getName());
return decodingMode;
}
} // namespace
void TrtGptModelInflightBatching::createDecoder(std::optional<executor::DecodingMode> const& decodingModeOpt)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (mWorldConfig.isLastPipelineParallelRank())
{
auto decoderType = mRuntime->getEngine().getTensorDataType("logits");
auto const decodingMode
= getDecodingMode(mModelConfig.getSpeculativeDecodingMode(), decodingModeOpt, mOperatingBeamWidth);
if (decodingMode.isExplicitDraftTokens())
{
            // There are no logits in the explicit draft tokens model.
            decoderType = mModelConfig.getDataType();
            // The decoder is not instantiated for bf16. We use half to get the same data size
            // and explicitly pass the dtype to ReDrafter, which has bf16 kernels.
if (decoderType == nvinfer1::DataType::kBF16)
{
decoderType = nvinfer1::DataType::kHALF;
}
}
mDecoder = std::make_unique<runtime::GptDecoderBatched>(mRuntime->getStreamPtr());
mDecoder->setup(
decodingMode, getMaxNumSequences(), mOperatingBeamWidth, decoderType, mModelConfig, mWorldConfig);
mDecoderState = std::make_unique<runtime::decoder::DecoderState>(decoderType, mRuntime->getBufferManager());
if (!mModelConfig.getSpeculativeDecodingMode().isNone())
{
mDecoderState->allocateSpeculativeDecodingBuffers(
mModelConfig.getSpeculativeDecodingMode(), decoderType, mRuntime->getBufferManager());
}
mDecoderState->setup(getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow(), getSinkTokenLen(),
getMaxSequenceLen(), mModelConfig, mWorldConfig, mRuntime->getBufferManager());
mDecoderState->setupSpeculativeDecoding(mDecoderState->getSpeculativeDecodingMode(),
mModelConfig.getMaxDecodingTokens(), mModelConfig, mWorldConfig, mRuntime->getBufferManager());
}
else
{
auto constexpr decoderDummyType = TRTDataType<float>::value;
mDecoderState
= std::make_unique<runtime::decoder::DecoderState>(decoderDummyType, mRuntime->getBufferManager());
mDecoderState->setupCacheIndirection(getMaxNumSequences(), mOperatingBeamWidth, getMaxAttentionWindow());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::createBuffers(executor::DecodingConfig const& decodingConfig,
std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
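    // One RuntimeBuffers instance is created per engine buffer (context/generation split times the number of micro
    // batches), decoder input/output buffers are created per micro batch, and slot decoder buffers are created per
    // sequence slot.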
mBuffers.clear();
for (SizeType32 i = 0; i < mNumBuffers; ++i)
{
mBuffers.emplace_back(
std::make_unique<RuntimeBuffers>(getMaxBatchSize(), mOperatingBeamWidth, getMaxAttentionWindowVec(),
getMaxAttentionWindow(), getSinkTokenLen(), *mRuntime, mModelConfig, mWorldConfig, decodingConfig,
getGatherGenerationLogits(), getMaxNumTokens(), additionalModelOutputs, mPromptTableOffloading));
}
mDecoderInputBuffers.clear();
mDecoderOutputBuffers.clear();
for (SizeType32 i = 0; i < mNumMicroBatches; ++i)
{
mDecoderInputBuffers.emplace_back(
getMaxNumSequences(), getMaxBatchSize(), mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
mDecoderOutputBuffers.emplace_back(getMaxNumSequences(), mOperatingBeamWidth, getMaxSequenceLen(),
mModelConfig.getMaxDecodingTokens(), mRuntime->getBufferManager());
}
mDecoderBuffers = std::make_unique<DecoderBuffers>(getMaxNumSequences(), mModelConfig.getMaxDecodingTokens(),
mRuntime->getBufferManager(), mModelConfig, mWorldConfig);
mSlotDecoderBuffers.clear();
for (SizeType32 i = 0; i < getMaxNumSequences(); ++i)
{
mSlotDecoderBuffers.emplace_back(std::make_unique<SlotDecoderBuffers>(
mOperatingBeamWidth, getMaxSequenceLen(), mRuntime->getBufferManager()));
}
mDecodingInputs.resize(mNumMicroBatches);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::prepareDisaggGenInitRequests(
RequestList const& activeRequests, RequestVector& newGenReqs)
{
NVTX3_SCOPED_RANGE(prepareDisaggGenInitRequests);
    // Allocate KV cache for the new generation requests by treating them as context requests.
(*mAllocateKvCache)(*mKvCacheManager, newGenReqs, {}, mModelConfig, mCrossKvCacheManager);
// Initiate KV cache transfer
auto timeStart = std::chrono::steady_clock::now();
if (tc::getEnvDisaggBenchmarkGenOnly())
{
TLLM_LOG_DEBUG("Disaggregated generation only benchmark mode is enabled");
for (auto& req : newGenReqs)
{
req->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
}
return;
}
auto const genInitReqNum = std::count_if(activeRequests.begin(), activeRequests.end(),
[](auto const& req) { return req->isDisaggGenerationInitState(); });
// Loop over the new disagg gen requests and trigger reception of their KV cache
for (auto& newGenReq : newGenReqs)
{
TLLM_CHECK_WITH_INFO(
mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration.");
if (common::getEnvDisableKVCacheTransferOverlap())
{
mCacheTransceiver->requestAndReceiveSync(newGenReq.get());
}
else
{
mCacheTransceiver->requestAndReceiveAsync(newGenReq.get());
}
}
if (!common::getEnvDisableKVCacheTransferOverlap())
{
auto const blockTransfer = std::all_of(activeRequests.begin(), activeRequests.end(),
[](auto const& req) { return req->isDisaggGenerationTransmissionInProgress(); });
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"newGenReqs.size():%ld requests, activeRequests.size():%ld checkGenTransferStatus :%d original "
"gen_only_requests_num:%ld",
newGenReqs.size(), activeRequests.size(), blockTransfer, genInitReqNum);
mCacheTransceiver->checkGenTransferStatus(blockTransfer ? 1 : 0);
auto timeEnd = std::chrono::steady_clock::now();
auto duration = std::chrono::duration<float, std::milli>(timeEnd - timeStart).count();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"receiveDisaggGenCache time:%f ms, "
"blockTransfer:%d,genInitReqNum:%ld,newGenReqs.size():%ld,activeRequests.size():%ld",
duration, blockTransfer, genInitReqNum, newGenReqs.size(), activeRequests.size());
}
return;
}
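// Polls the cache transceiver for pending KV-cache transfers of disaggregated-generation requests when
// transfer overlap is enabled; it blocks for at least one completed transfer only if all active requests
// are still in transmission.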
void TrtGptModelInflightBatching::checkDisaggGenTransferStatus(RequestList const& activeRequests)
{
NVTX3_SCOPED_RANGE(checkDisaggGenTransferStatus);
if (common::getEnvDisableKVCacheTransferOverlap())
{
return;
}
auto timeStart = std::chrono::steady_clock::now();
// TODO:
auto const needCheck = std::any_of(activeRequests.begin(), activeRequests.end(),
[](auto const& req) { return req->isDisaggGenerationTransmissionInProgress(); });
if (needCheck)
{
auto const needCheckOne = std::all_of(activeRequests.begin(), activeRequests.end(),
[](auto const& req) { return req->isDisaggGenerationTransmissionInProgress(); });
int atLeastNum = needCheckOne ? 1 : 0;
TLLM_LOG_DEBUG(
mpi::MpiComm::world().getRank(), "noPreppared requests, checkGenTransferStatus atLeastNum:%d", atLeastNum);
mCacheTransceiver->checkGenTransferStatus(atLeastNum);
auto timeEnd = std::chrono::steady_clock::now();
auto duration = std::chrono::duration<float, std::milli>(timeEnd - timeStart).count();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"no Prepare checkDisaggGenTransferStatus time:%f ms, "
"needCheckOne:%d,needCheck:%ld,activeRequests.size():%ld",
duration, needCheckOne, needCheck, activeRequests.size());
}
}
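// For requests whose KV-cache transfer has completed, prepares the runtime buffers and the decoder step,
// then moves them to generation state and appends the first generation token(s) produced by the context phase.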
void TrtGptModelInflightBatching::prepareDistGenBufferAndDecoder(RequestVector const& generationRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// Set up the decoder step for disaggregated generation requests
RequestVector cacheTransCompleteRequests;
for (auto const& request : generationRequests)
{
if (request->isDisaggGenerationTransmissionComplete())
{
cacheTransCompleteRequests.push_back(request);
}
}
if (!cacheTransCompleteRequests.empty())
{
auto timeStart = std::chrono::steady_clock::now();
auto const bufferId = getFusedBufferId();
auto& runtimeBuffers = *mBuffers[bufferId];
runtimeBuffers.prepareStep(cacheTransCompleteRequests, {}, getMaxBeamWidth(), getMaxAttentionWindow(),
*mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(), mRnnStateManager.get(),
mPeftTables[mMicroBatchId], *mRuntime, mModelConfig, mWorldConfig, getGatherGenerationLogits(),
isTrtOverlap());
auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
setupDecoderStep(
cacheTransCompleteRequests, *mBuffers.at(contextBufferId), mDecoderInputBuffers.at(getFusedBufferId()));
sync_check_cuda_error(mRuntime->getStream().get());
auto timeEnd = std::chrono::steady_clock::now();
auto duration = std::chrono::duration<float, std::milli>(timeEnd - timeStart).count();
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"prepareDistGenBufferAndDecoder time:%f ms , cacheTransCompleteRequests.size():%ld", duration,
cacheTransCompleteRequests.size());
}
for (auto& request : cacheTransCompleteRequests)
{
request->setState(LlmRequestState::kGENERATION_IN_PROGRESS);
request->setContextCurrentPosition(request->mPromptLen);
request->setDecodingIter(1);
auto const reqBeamWidth = request->mSamplingConfig.beamWidth;
auto firstGenTokens = request->getContextPhaseParams().value().getFirstGenTokens();
for (SizeType32 beam = 0; beam < reqBeamWidth; ++beam)
{
request->addNewToken(firstGenTokens.at(beam), beam);
}
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::debugIOTensors(RequestVector const& contextRequests,
RequestVector const& generationRequests, TensorMap const& inputMap, TensorMap const& outputMap)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(mDebugConfig);
auto const& manager = mRuntime->getBufferManager();
auto requestIds = utils::collectRequestIds(contextRequests, generationRequests);
if (mDebugConfig->getDebugTensorsMaxIterations() > 0)
{
mLastIterationDebugTensors.clear();
mLastIterationDebugTensors = utils::storeIOTensors(*mDebugConfig, requestIds, inputMap, outputMap, manager);
}
else
{
utils::dumpIOTensors(*mDebugConfig, mIterCounter, requestIds, inputMap, outputMap, mWorldConfig, manager);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
std::tuple<SizeType32, runtime::StringPtrMap<runtime::ITensor> const&, runtime::StringPtrMap<runtime::ITensor>&>
TrtGptModelInflightBatching::prepareBuffers(
RequestVector const& contextRequests, RequestVector const& generationRequests, SizeType32 bufferId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(prepareBuffers);
auto& runtimeBuffers = *mBuffers.at(bufferId);
auto allNewTokens = mWorldConfig.isLastPipelineParallelRank()
? RuntimeBuffers::OptionalRef<runtime::ITensor const>(mDecoderState->getAllNewTokens())
: std::nullopt;
auto [optProfileId, inputMap, outputMap] = runtimeBuffers.prepareStep(contextRequests, generationRequests,
mOperatingBeamWidth, getMaxAttentionWindow(), *mDecoderState, mKvCacheManager.get(), mCrossKvCacheManager.get(),
mRnnStateManager.get(), mPeftTables[bufferId], *mRuntime, mModelConfig, mWorldConfig,
getGatherGenerationLogits(), isTrtOverlap(), allNewTokens);
// For Variable-Beam-Width-Search
mRuntime->setCurrentBeamWidths(
tensorrt_llm::batch_manager::utils::getRequestBeamWidths(contextRequests, generationRequests));
mRuntime->setInputTensors(optProfileId, inputMap);
mRuntime->setOutputTensors(optProfileId, outputMap);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return {optProfileId, inputMap, outputMap};
}
void TrtGptModelInflightBatching::prepareGraph(SizeType32 bufferId, SizeType32 optProfileId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(prepareGraph);
auto const nextBatchState = mBuffers[bufferId]->getBatchState();
auto cudaGraphOpt = mCudaGraphExecutorCaches[bufferId].get(nextBatchState);
// If graph is not found in the cache, capture it.
if (!cudaGraphOpt.has_value())
{
// We need to prepare some tensors once again to properly set values for graph capture.
// Graph capture requires setting some tensors (e.g. past_kv_len)
// to round_up(max_kv_cache_len, kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE)
// in order to capture the kernels with a large enough grid.
mBuffers[bufferId]->prepareBuffersForCudaGraph(getMaxSequenceLen());
auto cudaGraph = std::make_shared<utils::CudaGraphExecutor>();
cudaGraph->prepareNextGraph(mRuntime, optProfileId);
mCudaGraphExecutorCaches[bufferId].put(nextBatchState, cudaGraph);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
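// Runs a single engine step for the scheduled requests: prepares the buffers, optionally registers
// layer-wise KV-cache sending for context-only requests, executes the TRT engine, captures a CUDA graph
// for pure generation batches, and dumps or copies additional outputs if configured.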
void TrtGptModelInflightBatching::executeStep(
RequestVector const& contextRequests, RequestVector const& generationRequests, SizeType32 bufferId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE_WITH_NAME(range,
"executeStep: " + std::to_string(contextRequests.size()) + " ctx reqs, "
+ std::to_string(generationRequests.size()) + " gen reqs");
if (mPromptTableOffloading)
{
prefetchNextPromptTableChunk(contextRequests, /* isFirstChunk */ true, bufferId);
}
auto [optProfileId, inputMap, outputMap] = prepareBuffers(contextRequests, generationRequests, bufferId);
if (mBuffers[bufferId]->transformerBuffers)
{
// Create the context progress object; it remains nullptr if layer-wise transfer is not needed
std::shared_ptr<ContextProgress> progress = nullptr;
RequestVector layerWiseRequests;
if (common::getEnvDisaggLayerwise())
{
for (auto const& request : contextRequests)
{
bool const enableLayerWise = request->isContextOnlyRequest() && request->isLastContextChunk();
if (enableLayerWise)
{
layerWiseRequests.push_back(request);
}
}
}
// TODO: support layer-wise cross kv cache in encoder-decoder models
if (!layerWiseRequests.empty() && !mModelConfig.useCrossAttention())
{
int const numLayers = mModelConfig.getNbAttentionLayers(
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
progress = std::make_shared<ContextProgress>(numLayers);
}
bufferCast<void*>(*mBuffers[bufferId]->transformerBuffers->contextProgressHost)[0] = progress.get();
if (progress)
{
TLLM_CHECK_WITH_INFO(
mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration.");
mCacheTransceiver->respondAndSendLayerWise(layerWiseRequests, progress);
}
}
if (mPromptTableOffloading)
{
prefetchNextPromptTableChunk(contextRequests, /* isFirstChunk */ false, bufferId);
}
executeContext(optProfileId, bufferId);
// If batch state has any context request, do not capture this graph.
if (isCudaGraphMode() && contextRequests.empty())
{
// Capture graph of current batch state during engine execution.
// This is based on the assumptions that
// a) We can hide CPU graph capture behind the GPU engine execution.
// b) Batch size in the next iterations won't change and we can reuse the graph multiple times.
prepareGraph(bufferId, optProfileId);
}
if (mDebugConfig)
{
debugIOTensors(contextRequests, generationRequests, inputMap, outputMap);
}
if (mAdditionalModelOutputs.has_value() && !mAdditionalModelOutputs.value().empty())
{
utils::copyAdditionalOutputs(
mAdditionalModelOutputs.value(), contextRequests, generationRequests, outputMap, getBufferManager());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
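// Creates new decoder requests for the scheduled context requests and configures the underlying decoder
// with their sampling configs; only done on the last pipeline-parallel rank and only when there are
// context requests.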
void TrtGptModelInflightBatching::setupDecoderStep(
RequestVector const& contextRequests, RuntimeBuffers const& buffers, DecoderInputBuffers& inputBuffers)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(setupDecoderStep);
if (mWorldConfig.isLastPipelineParallelRank() && !contextRequests.empty())
{
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests,
mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(),
*mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers);
auto const localBatchSize = batchSlots->getSize();
if (localBatchSize > 0)
{
auto samplingConfig = SamplingConfig(samplingConfigs);
mDecoder->getUnderlyingDecoder().setup(samplingConfig, localBatchSize, batchSlots,
{mDecoderState->getJointDecodingOutput()}, mModelConfig.getDataType(), lookaheadPrompt,
lookaheadAlgoConfigs);
auto const& stream = mDecoder->getDecoderStream();
CudaEvent event{};
stream->record(event);
mRuntime->getStreamPtr()->wait(event);
}
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
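// Finalizes a request after decoding: copies cached generation logits, relays context/generation logits
// across pipeline-parallel ranks, and, for beam width > 1, gathers the beam-search tokens and log probs
// from the decoder slot buffers into the request.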
void TrtGptModelInflightBatching::postProcessRequest(
LlmRequest& llmReq, std::vector<SizeType32> const& numDroppedTokens)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const seqSlot = llmReq.mSeqSlot.value();
auto const reqBeamWidth = llmReq.getBeamWidthByIter(true);
auto const& bufferManager = getBufferManager();
if (llmReq.getReturnGenerationLogits() && !llmReq.getGenerationLogitsFragments().empty())
{
TLLM_CHECK(!llmReq.isStreaming());
auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId();
auto& genRuntimeBuffers = *mBuffers.at(genBufferId);
auto constexpr beforeDecoder = false;
utils::copyGenerationLogits(
genRuntimeBuffers.generationLogitsCache, bufferManager, llmReq, beforeDecoder, numDroppedTokens);
bufferManager.getStream().synchronize();
}
if (mWorldConfig.isPipelineParallel())
{
// Send context logits from last to first PP rank
if (llmReq.getReturnContextLogits())
{
if (mWorldConfig.isLastPipelineParallelRank())
{
mMpiCommPipelinePara->send(
*(llmReq.getContextLogitsHost()), 0, mpi::MpiTag::kTrtGptModelInflightBatchingContextLogits);
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
mMpiCommPipelinePara->recv(*(llmReq.getContextLogitsHost()), mWorldConfig.getPipelineParallelism() - 1,
mpi::MpiTag::kTrtGptModelInflightBatchingContextLogits);
}
}
// Send generation logits from last to first PP rank
if (llmReq.getReturnGenerationLogits())
{
if (mWorldConfig.isLastPipelineParallelRank())
{
mMpiCommPipelinePara->send(
*(llmReq.getGenerationLogitsHost()), 0, mpi::MpiTag::kTrtGptModelInflightBatchingGenerationLogits);
}
else if (mWorldConfig.isFirstPipelineParallelRank())
{
mMpiCommPipelinePara->recv(*(llmReq.getGenerationLogitsHost()),
mWorldConfig.getPipelineParallelism() - 1,
mpi::MpiTag::kTrtGptModelInflightBatchingGenerationLogits);
}
}
}
if (reqBeamWidth == 1)
{
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return;
}
// Update the host outputs of mSlotDecoderBuffers[seqSlot] and synchronize
getDecoderSlotHostOutputs(seqSlot, llmReq.returnLogProbs(), llmReq.mSamplingConfig, llmReq.isStreaming());
auto const* outputIdsHostData = bufferCast<TokenIdType>(*mSlotDecoderBuffers[seqSlot]->outputIdsHost);
auto const* sequenceLengthsHostData = bufferCast<SizeType32>(*mSlotDecoderBuffers[seqSlot]->sequenceLengthsHost);
auto const* cumLogProbsHostData = bufferCast<float>(*mSlotDecoderBuffers[seqSlot]->cumLogProbsHost);
auto logProbsHost = mSlotDecoderBuffers[seqSlot]->logProbsHost;
auto const* logProbsHostData = bufferCast<float>(*logProbsHost);
auto const& outputIdsShape = mSlotDecoderBuffers[seqSlot]->outputIdsHost->getShape();
auto const maxSeqLength = outputIdsShape.d[1];
std::vector<std::vector<TokenIdType>> generatedTokens(reqBeamWidth);
for (SizeType32 beam = 0; beam < reqBeamWidth; ++beam)
{
auto const* const begin = outputIdsHostData + tc::flat_index2(beam, llmReq.mPromptLen, maxSeqLength);
auto const generatedLength = sequenceLengthsHostData[beam] - llmReq.mPromptLen;
auto const* const end = begin + generatedLength;
generatedTokens[beam].assign(begin, end);
if (llmReq.returnLogProbs())
{
llmReq.setCumLogProb(cumLogProbsHostData[beam], beam);
auto const beginLogProbsOffset = reqBeamWidth == 1 ? llmReq.mPromptLen : 0;
auto const* const begin = logProbsHostData + beam * logProbsHost->getShape().d[1] + beginLogProbsOffset;
auto const* const end = begin + generatedLength;
LlmRequest::VecLogProbs logProbs(begin, end);
llmReq.setLogProbs(logProbs, beam);
}
}
// Store the gathered generated tokens in the request
llmReq.setGeneratedTokens(generatedTokens);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
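// Gathers the finalized outputs of one decoder slot (output ids, sequence lengths and optionally log
// probs), exchanges them between pipeline-parallel ranks if necessary, and copies them to host buffers.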
void TrtGptModelInflightBatching::getDecoderSlotHostOutputs(
SizeType32 seqSlot, bool returnLogProbs, SamplingConfig const& samplingConfig, bool streaming)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (mWorldConfig.isLastPipelineParallelRank())
{
auto event = mDecoder->finalize(*mDecoderState, seqSlot, samplingConfig, streaming);
// Make sure that postprocessing is done before copying outputIds
mCopyBufferManager.getStream().wait(event.get());
auto sequenceLengths = mDecoderState->getSequenceLengths(seqSlot);
auto outputIds = mDecoderState->getGatheredIds(seqSlot);
auto cumLogProbs = mDecoderState->getCumLogProbs(seqSlot);
auto logProbs = mDecoderState->getLogProbs(seqSlot);
mCopyBufferManager.copy(*sequenceLengths, *mSlotDecoderBuffers[seqSlot]->sequenceLengths);
mCopyBufferManager.copy(*outputIds, *mSlotDecoderBuffers[seqSlot]->outputIds);
if (returnLogProbs)
{
mCopyBufferManager.copy(*cumLogProbs, *mSlotDecoderBuffers[seqSlot]->cumLogProbs);
mCopyBufferManager.copy(*logProbs, *mSlotDecoderBuffers[seqSlot]->logProbs);
}
if (mWorldConfig.isPipelineParallel())
{
// Make sure that postprocessing is done before sending outputIds
event.synchronize();
auto const peerSend = 0;
mDecSlotAsyncSndHdls.emplace_back(std::make_unique<DecoderSlotAsyncSend>(
outputIds, sequenceLengths, cumLogProbs, logProbs, returnLogProbs, *mMpiCommPipelinePara, peerSend));
}
}
else
{
auto const peerRecv = mWorldConfig.getPipelineParallelRank() == 0 ? mWorldConfig.getPipelineParallelism() - 1
: mWorldConfig.getPipelineParallelRank() - 1;
DecoderSlotAsyncSend::recv(*mSlotDecoderBuffers[seqSlot], returnLogProbs, *mMpiCommPipelinePara, peerRecv);
auto const peerSend = mWorldConfig.getPipelineParallelRank() + 1;
if (peerSend != mWorldConfig.getPipelineParallelism() - 1)
{
mDecSlotAsyncSndHdls.emplace_back(std::make_unique<DecoderSlotAsyncSend>(
*mSlotDecoderBuffers[seqSlot], returnLogProbs, *mMpiCommPipelinePara, peerSend));
}
}
sync_check_cuda_error(mRuntime->getStream().get());
// Synchronize the copy stream with the runtime stream so that the slot outputs, filled above either by
// copy or by receive, are complete before they are copied to host on the copy stream.
runtime::CudaEvent beforeEvent{};
mRuntime->getStreamPtr()->record(beforeEvent);
mCopyBufferManager.getStream().wait(beforeEvent);
mCopyBufferManager.copy(*mSlotDecoderBuffers[seqSlot]->outputIds, *mSlotDecoderBuffers[seqSlot]->outputIdsHost);
mCopyBufferManager.copy(
*mSlotDecoderBuffers[seqSlot]->sequenceLengths, *mSlotDecoderBuffers[seqSlot]->sequenceLengthsHost);
if (returnLogProbs)
{
mCopyBufferManager.copy(
*mSlotDecoderBuffers[seqSlot]->cumLogProbs, *mSlotDecoderBuffers[seqSlot]->cumLogProbsHost);
mCopyBufferManager.copy(*mSlotDecoderBuffers[seqSlot]->logProbs, *mSlotDecoderBuffers[seqSlot]->logProbsHost);
}
// Make sure copy is done before continuing on host
mCopyBufferManager.getStream().synchronize();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
namespace
{
// Check if any of the requests needs log probs; if so, they have to be fetched from the decoder and communicated
bool batchReturnLogProbs(ScheduledRequests const& scheduledRequests)
{
auto pred = [](auto const& llmReq) { return llmReq->returnLogProbs(); };
return std::any_of(scheduledRequests.contextRequests.begin(), scheduledRequests.contextRequests.end(), pred)
|| std::any_of(scheduledRequests.generationRequests.begin(), scheduledRequests.generationRequests.end(), pred);
}
} // namespace
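// Runs one decoder step asynchronously: gathers context and generation logits into the decoder input
// buffers, applies logits post-processors and guided decoding, launches the decoder forward pass and
// returns an event that is recorded once the decoder output buffers have been updated.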
runtime::CudaEvent TrtGptModelInflightBatching::decoderStepAsync(ScheduledRequests const& scheduledRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(decoderStepAsync);
auto& decoderInputBuffers = mDecoderInputBuffers.at(getFusedBufferId());
auto& seqSlotLogits = decoderInputBuffers.logits;
auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
auto& contextRuntimeBuffers = mBuffers.at(contextBufferId);
auto const logitsIndex = (*mHandleContextLogits)(decoderInputBuffers, scheduledRequests.contextRequests,
contextRuntimeBuffers->logits, contextRuntimeBuffers->numContextLogits, mModelConfig,
mRuntime->getBufferManager(), mDecoderBuffers->draftBuffers, contextRuntimeBuffers->mMedusaBuffers);
auto const genLogitsIndex = mCtxGenFusion ? logitsIndex : 0;
auto const genBufferId = mCtxGenFusion ? getFusedBufferId() : getGenerationBufferId();
auto& genRuntimeBuffers = mBuffers.at(genBufferId);
(*mHandleGenerationLogits)(decoderInputBuffers, scheduledRequests.generationRequests, genRuntimeBuffers->logits,
genLogitsIndex, mModelConfig, mRuntime->getBufferManager(), *genRuntimeBuffers, mDecoderBuffers->draftBuffers);
if (mOperatingBeamWidth > 1)
{
copyCacheIndirectionFromOutputsToInputs(scheduledRequests, genBufferId);
}
mLogitsPostProcessorIsApplied
= (*mLogitsPostProcessor)(scheduledRequests.contextRequests, scheduledRequests.generationRequests,
mReplicateLogitsPostProcessor, seqSlotLogits, mWorldConfig, *mRuntime, mLogitsPostProcessorBatched);
if (mGuidedDecoder)
{
mGuidedDecoder->execute(scheduledRequests, mRuntime->getBufferManager(), seqSlotLogits);
}
auto const fusedBufferId = getFusedBufferId();
auto& fusedRuntimeBuffers = mBuffers.at(fusedBufferId);
auto& decodingInput = mDecodingInputs.at(mMicroBatchId);
decodingInput = (*mMakeDecodingBatchInputOutput)(scheduledRequests.contextRequests,
scheduledRequests.generationRequests, *mDecoderBuffers, mDecoderInputBuffers.at(fusedBufferId), *mDecoderState,
mModelConfig, getMaxNumSequences(), *fusedRuntimeBuffers);
auto decoderFinishEvent = mDecoder->forwardAsync(*mDecoderState, *decodingInput);
auto const returnLogProbs = batchReturnLogProbs(scheduledRequests);
auto updateDecoderBuffersEvent
= (*mUpdateDecoderBuffers)(mModelConfig, *mDecoderBuffers, mDecoderOutputBuffers.at(fusedBufferId),
mRuntime->getBufferManager(), *mDecoderState, returnLogProbs, decoderFinishEvent);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return updateDecoderBuffersEvent;
}
void TrtGptModelInflightBatching::copyCacheIndirectionFromOutputsToInputs(
ScheduledRequests const& scheduledRequests, SizeType32 genBufferId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(copyCacheIndirectionFromOutputsToInputs);
auto& genRuntimeBuffers = *mBuffers.at(genBufferId);
auto* srcOffsetsPtr = bufferCast<SizeType64>(*genRuntimeBuffers.cacheIndirDecoderIOBatchedCopySrcOffsets);
auto* dstOffsetsPtr = bufferCast<SizeType64>(*genRuntimeBuffers.cacheIndirDecoderIOBatchedCopyDstOffsets);
auto* copySizesPtr = bufferCast<SizeType64>(*genRuntimeBuffers.cacheIndirDecoderIOBatchedCopySizes);
// The max beam width and max attention window are taken from the cache indirection shape
auto const& cacheIndirShape = mDecoderState->getCacheIndirectionOutput()->getShape();
auto const maxBeamWidth = cacheIndirShape.d[1];
auto const maxAttentionWindow = cacheIndirShape.d[2];
auto const slotOffset = maxBeamWidth * maxAttentionWindow;
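// Example with hypothetical sizes: for maxBeamWidth = 4 and maxAttentionWindow = 2048, slotOffset = 8192,
// so the request in sequence slot 3 copies reqBeamWidth * 2048 entries starting at element offset 24576.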
SizeType32 batchIdx{0};
SizeType64 maxCopySize{0};
auto& manager = mRuntime->getBufferManager();
for (auto const& requests : {scheduledRequests.contextRequests, scheduledRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
auto const reqBeamWidth = llmReq->getBeamWidthByIter();
auto const seqSlot = llmReq->mSeqSlot.value();
auto const copySize = reqBeamWidth * maxAttentionWindow;
srcOffsetsPtr[batchIdx] = seqSlot * slotOffset;
dstOffsetsPtr[batchIdx] = seqSlot * slotOffset;
copySizesPtr[batchIdx] = copySize;
maxCopySize = std::max(maxCopySize, copySize);
batchIdx++;
}
}
if (batchIdx != 0)
{
auto const srcOffsetsSlice
= ITensor::slice(genRuntimeBuffers.cacheIndirDecoderIOBatchedCopySrcOffsets, 0, batchIdx);
auto const srcOffsetsSliceDeviceSlice
= ITensor::slice(genRuntimeBuffers.mCacheIndirDecoderIOBatchedCopySrcOffsetsSliceDevice, 0, batchIdx);
manager.copy(srcOffsetsSlice->data(), *srcOffsetsSliceDeviceSlice,
runtime::MemoryType::kGPU); // Explicitly move to device for faster access.
auto const dstOffsetsSlice
= ITensor::slice(genRuntimeBuffers.cacheIndirDecoderIOBatchedCopyDstOffsets, 0, batchIdx);
auto const dstOffsetsSliceDeviceSlice
= ITensor::slice(genRuntimeBuffers.mCacheIndirDecoderIOBatchedCopyDstOffsetsSliceDevice, 0, batchIdx);
manager.copy(dstOffsetsSlice->data(), *dstOffsetsSliceDeviceSlice,
runtime::MemoryType::kGPU); // Explicitly move to device for faster access.
auto const sizesSlice = ITensor::slice(genRuntimeBuffers.cacheIndirDecoderIOBatchedCopySizes, 0, batchIdx);
auto const copySizesDeviceSlice
= ITensor::slice(genRuntimeBuffers.mCacheIndirDecoderIOBatchedCopyCopySizesDevice, 0, batchIdx);
manager.copy(sizesSlice->data(), *copySizesDeviceSlice); // Explicitly move to device for faster access.
runtime::kernels::invokeCopyBatch(*mDecoderState->getCacheIndirectionOutput(),
*mDecoderState->getCacheIndirectionInput(), *srcOffsetsSliceDeviceSlice, *dstOffsetsSliceDeviceSlice,
*copySizesDeviceSlice, maxCopySize, manager.getStream());
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
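// Broadcasts the decoder outputs to the tensor-parallel ranks (when post-decoder broadcast is enabled)
// and relays them along the pipeline-parallel ranks so that every rank can update its requests;
// returns the asynchronous send handles to be completed later.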
std::vector<std::unique_ptr<DecoderStepAsyncSend>> TrtGptModelInflightBatching::communicateDecoderBuffers(
bool returnLogProbs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(communicateDecoderBuffers);
auto& decoderOutputBuffers = mDecoderOutputBuffers.at(getFusedBufferId());
std::vector<std::unique_ptr<DecoderStepAsyncSend>> asyncHandles;
if (mWorldConfig.isLastPipelineParallelRank())
{
if (broadcastPostDecoder())
{
DecoderStepAsyncSend::bcast(decoderOutputBuffers, mDecoderBuffers->draftBuffers,
mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth,
mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommTensorPara, 0);
}
if (mWorldConfig.isPipelineParallel())
{
auto const peerSend = 0;
asyncHandles.emplace_back(
std::make_unique<DecoderStepAsyncSend>(decoderOutputBuffers, mDecoderBuffers->draftBuffers,
mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth,
mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerSend));
}
}
else
{
auto const peerRecv = mWorldConfig.isFirstPipelineParallelRank() ? mWorldConfig.getPipelineParallelism() - 1
: mWorldConfig.getPipelineParallelRank() - 1;
DecoderStepAsyncSend::recv(decoderOutputBuffers, mDecoderBuffers->draftBuffers,
mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth,
mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerRecv);
auto const peerSend = mWorldConfig.getPipelineParallelRank() + 1;
if (peerSend != mWorldConfig.getPipelineParallelism() - 1)
{
asyncHandles.emplace_back(
std::make_unique<DecoderStepAsyncSend>(decoderOutputBuffers, mDecoderBuffers->draftBuffers,
mDecoderState->getCacheIndirectionOutput(), returnLogProbs, mOperatingBeamWidth,
mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind(), *mMpiCommPipelinePara, peerSend));
}
}
TLLM_CHECK_WITH_INFO(asyncHandles.size() <= 2, "Up to two decoder step async handles expected");
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return asyncHandles;
}
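// Updates the scheduled requests from the host-side decoder outputs: appends newly accepted tokens,
// records log probs and finish reasons, refreshes draft tokens for speculative decoding, rewinds the
// KV cache for rejected draft tokens and advances the request states.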
void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& scheduledRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(updateRequests);
auto const& decoderOutputBuffers = mDecoderOutputBuffers.at(getFusedBufferId());
auto const hostNewOutputTokensShape = decoderOutputBuffers.newOutputTokensHost->getShape();
auto const* const hostNewOutputTokensData
= bufferCast<TokenIdType const>(*decoderOutputBuffers.newOutputTokensHost);
auto const* const sequenceLengthsHostData = bufferCast<SizeType32 const>(*decoderOutputBuffers.sequenceLengthsHost);
auto const* const decoderFinishedSumPtr = bufferCast<SizeType32 const>(*decoderOutputBuffers.finishedSumHost);
auto const* const cumLogProbsPtr = bufferCast<float const>(*decoderOutputBuffers.cumLogProbsHost);
auto const* const logProbsPtr = bufferCast<float const>(*decoderOutputBuffers.logProbsHost);
auto const* const nextDraftTokensHostData = mModelConfig.getSpeculativeDecodingMode().predictsDraftTokens()
? bufferCast<TokenIdType const>(*mDecoderBuffers->draftBuffers.nextDraftTokensHost)
: nullptr;
auto const* const nextDraftTokensLengthsHostData = mModelConfig.getSpeculativeDecodingMode().predictsDraftTokens()
&& mModelConfig.getSpeculativeDecodingMode().variableDraftLength()
? bufferCast<SizeType32 const>(*mDecoderBuffers->draftBuffers.nextDraftTokensLengthsHost)
: nullptr;
auto const* const finishReasonsHostData
= bufferCast<kernels::FinishedState>(*decoderOutputBuffers.finishReasonsHost);
// Update only requests that ran through the decoder
for (auto const& llmReq : scheduledRequests.generationRequests)
{
if (llmReq->isGenerationCompleteState())
{
continue;
}
auto const reqBeamWidth = llmReq->getBeamWidthByIter(true);
auto const seqSlot = llmReq->mSeqSlot.value();
auto const currentNumOfTokens = llmReq->getMaxBeamNumTokens();
// Save the accepted token logits from target model
if (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal() && llmReq->getReturnGenerationLogits()
&& llmReq->hasDraftTokens())
{
TLLM_CHECK_WITH_INFO(reqBeamWidth == 1, "Speculative decoding only works for beam width == 1");
SizeType32 numAcceptedTokens
= sequenceLengthsHostData[seqSlot * mOperatingBeamWidth + 0] - llmReq->getMaxBeamNumTokens();
auto const& generationLogitsHost = llmReq->getGenerationLogitsHost();
auto shape = generationLogitsHost->getShape();
shape.d[1] = numAcceptedTokens;
generationLogitsHost->reshape(shape);
}
std::vector<SizeType32> numNewTokens(reqBeamWidth);
std::vector<SizeType32> numDroppedTokens(reqBeamWidth);
// numGeneratedTokens is the number of tokens generated by the decoder.
// Some tokens might be dropped due to end token or rejected draft tokens.
auto const numGeneratedTokens = llmReq->getNumDraftTokens() + 1;
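// Example (hypothetical numbers): with 4 draft tokens of which 2 are accepted, numGeneratedTokens = 5 and
// the sequence length advances by 3 (2 accepted draft tokens + 1 newly sampled token), so numNewTokens = 3
// and numDroppedTokens = 2.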
for (SizeType32 beam = 0; beam < reqBeamWidth; ++beam)
{
// Sequence length is only advanced for accepted tokens.
auto const seqLen = sequenceLengthsHostData[seqSlot * mOperatingBeamWidth + beam];
// Actual number of tokens that should be added to the request.
auto const numNewOutputTokens = seqLen - llmReq->getNumTokens(beam);
if (reqBeamWidth == 1)
{
TLLM_CHECK_WITH_INFO(numGeneratedTokens >= numNewOutputTokens,
"numNewOutputTokens must not be greater than numGeneratedTokens: "
"numGeneratedTokens %d < numNewOutputTokens %d",
numGeneratedTokens, numNewOutputTokens);
}
numNewTokens[beam] = std::min(numGeneratedTokens, numNewOutputTokens);
numDroppedTokens[beam] = numGeneratedTokens - numNewTokens[beam];
for (SizeType32 step = 0; step < numNewTokens[beam]; ++step)
{
auto const newTokenIdx = tc::flat_index(hostNewOutputTokensShape.d, step, seqSlot, beam);
auto const newToken = hostNewOutputTokensData[newTokenIdx];
llmReq->addNewToken(newToken, beam);
TLLM_LOG_DEBUG("request ID %ld beam %d newToken %d", llmReq->mRequestId, beam, newToken);
if (llmReq->returnLogProbs())
{
auto const cumLogProb = cumLogProbsPtr[seqSlot * mOperatingBeamWidth + beam];
llmReq->setCumLogProb(cumLogProb, beam);
auto const beginLogProbsOffset = reqBeamWidth == 1 ? llmReq->mPromptLen : 0;
SizeType32 offset
= (seqSlot * mOperatingBeamWidth + beam) * getMaxSequenceLen() + beginLogProbsOffset;
auto const generatedLength = seqLen - llmReq->mPromptLen;
std::vector<float> logProbs(logProbsPtr + offset, logProbsPtr + offset + generatedLength);
llmReq->setLogProbs(logProbs, beam);
}
}
auto const finishReason = finishReasonsHostData[seqSlot * mOperatingBeamWidth + beam];
llmReq->setFinishedReason(finishReason.toFinishReason(), beam);
TLLM_LOG_DEBUG("[RANK %d] decoderSync: request ID %lu beam %d tokens %s finished %d",
COMM_SESSION.getRank(), llmReq->mRequestId, beam, common::vec2str(llmReq->getTokens(beam)).c_str(),
static_cast<int>(finishReason.toFinishReason()));
}
// Set number of tokens predicted per runtime iteration. Will be > 1 for speculative decoding.
llmReq->updateNumTokensPerIteration(llmReq->getMaxBeamNumTokens() - currentNumOfTokens, mModelConfig);
// Fill new draft tokens for the next step
if (decoderFinishedSumPtr[seqSlot] != reqBeamWidth
&& (mModelConfig.getSpeculativeDecodingMode().predictsDraftTokens()
|| mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind()))
{
auto const maxDraftTokensLen = mModelConfig.getMaxDecodingDraftTokens();
auto prevDraftTokensLen = llmReq->getNumDraftTokens();
// We overallocate the KV cache for EAGLE to maxDecodingTokens + maxPathLen in order to fit both
// the base model verification (needs up to maxDecodingTokens) and
// the drafter (needs up to maxPathLen accepted tokens and maxDecodingDraftTokens new draft tokens).
if (mModelConfig.getSpeculativeDecodingMode().isEagle())
{
prevDraftTokensLen = mModelConfig.getSpeculativeDecodingModule().getMaxDecodingTokens()
+ mModelConfig.getSpeculativeDecodingModule().getMaxPathLen() - 1;
}
auto nextDraftTokensLen = mModelConfig.getSpeculativeDecodingModule().getMaxDecodingDraftTokens();
if (mModelConfig.getSpeculativeDecodingMode().variableDraftLength())
{
nextDraftTokensLen = nextDraftTokensLengthsHostData[seqSlot];
}
TLLM_CHECK(nextDraftTokensLen <= maxDraftTokensLen);
auto draftTokensShared
= std::make_shared<std::vector<TokenIdType>>(nextDraftTokensHostData + seqSlot * maxDraftTokensLen,
nextDraftTokensHostData + seqSlot * maxDraftTokensLen + nextDraftTokensLen);
llmReq->setDraftTokens(draftTokensShared);
// For all phases except context, which does not have draft tokens
if (!llmReq->isGenerationCompleteState() && prevDraftTokensLen != 0
&& mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
{
// The -1 accounts for the current 'main' token
auto const acceptedTokensLen = llmReq->getMaxBeamNumTokens() - currentNumOfTokens - 1;
auto const rewindLength = prevDraftTokensLen - acceptedTokensLen;
TLLM_LOG_DEBUG("request ID %lu (seqSlot %d): accepted %d of %d draft tokens, rewind %d tokens",
llmReq->mRequestId, seqSlot, acceptedTokensLen, prevDraftTokensLen, rewindLength);
TLLM_CHECK(0 <= acceptedTokensLen && acceptedTokensLen <= prevDraftTokensLen);
// At this point, KV cache rows are already gathered and moved to the right location.
// We can safely rewind (draft - accepted) tokens
mKvCacheManager->rewindKVCache(llmReq->mRequestId, rewindLength);
}
}
// Terminate if the request has finished or if it is the target model of external-draft-token speculative decoding
if (decoderFinishedSumPtr[seqSlot] == reqBeamWidth
|| (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal() && llmReq->hasDraftTokens()))
{
postProcessRequest(*llmReq, numDroppedTokens);
if (!mWorldConfig.isPipelineParallel() || !mWorldConfig.isLastPipelineParallelRank())
{
if (llmReq->getReturnGenerationLogits() && mSpeculativeDecodingFastLogits && mIsLeaderInOrchMode)
{
mDraftRequestsWaitingToSendLogits.push_back(llmReq);
}
else
{
terminateRequest(llmReq);
}
llmReq->setState(LlmRequestState::kGENERATION_COMPLETE);
}
else
{
llmReq->setState(LlmRequestState::kGENERATION_TO_COMPLETE);
}
}
else
{
// Gather tokens in the case of streaming with beam search
if (llmReq->isStreaming() && llmReq->mSamplingConfig.beamWidth > 1)
{
postProcessRequest(*llmReq, numDroppedTokens);
}
if (llmReq->isContextInitState())
{
llmReq->setState(LlmRequestState::kGENERATION_IN_PROGRESS);
}
if (isTrtOverlap() && llmReq->willCompleteNextIteration())
{
// This state prohibits the request from being scheduled for another iteration. It assumes that the next
// iteration has already been scheduled and the request can finish in the next call to updateRequests().
llmReq->setState(LlmRequestState::kGENERATION_TO_COMPLETE);
}
}
if (llmReq->getReturnPerfMetrics())
{
llmReq->updatePerfMetrics(mIterCounter);
}
llmReq->advanceDecodingIter();
if (mWorldConfig.isPipelineParallel() && mWorldConfig.isLastPipelineParallelRank())
{
for (SizeType32 beam = 0; beam < reqBeamWidth; ++beam)
{
llmReq->setNumPreDecodedTokens(numNewTokens[beam], beam);
}
}
}
if (mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
{
SizeType32 numSequences{0};
for (auto const& requests : {scheduledRequests.contextRequests, scheduledRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
auto const reqBeamWidth = llmReq->mSamplingConfig.beamWidth;
numSequences += reqBeamWidth;
}
}
TLLM_CHECK_WITH_INFO(mCtxGenFusion, "Current speculative decoding mode requires context-gen fusion IFB");
rewindKVCacheBlocks(numSequences);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
std::vector<std::unique_ptr<DecoderStepAsyncSend>> TrtGptModelInflightBatching::decoderSync(
ScheduledRequests const& scheduledRequests, std::optional<runtime::CudaEvent> const& decoderFinishEvent)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(decoderSync);
if (mWorldConfig.isLastPipelineParallelRank())
{
decoderFinishEvent->synchronize();
}
auto const returnLogProbs = batchReturnLogProbs(scheduledRequests);
auto asyncHandles = communicateDecoderBuffers(returnLogProbs);
updateRequests(scheduledRequests);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return asyncHandles;
}
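// Rewinds the KV-cache block contents after speculative decoding by relocating the entries of accepted
// draft tokens; the rewind length is the maximum draft length, or a per-sequence length when the draft
// length is variable.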
void TrtGptModelInflightBatching::rewindKVCacheBlocks(SizeType32 numSequences)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const bufferId = getFusedBufferId();
auto& runtimeBuffers = *mBuffers.at(bufferId);
auto localNbLayers = mModelConfig.getNbAttentionLayers(
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
if (mWorldConfig.isLastPipelineParallelRank() && mModelConfig.getSpeculativeDecodingMode().isEagle())
{
// Do not rewind the last KV caches, which belong to the EagleNet drafter; those KV caches are managed separately.
auto eagleModulePtr
= std::dynamic_pointer_cast<runtime::EagleModule>(mModelConfig.getSpeculativeDecodingModulePtr());
localNbLayers -= eagleModulePtr->getNumTransformerLayers();
}
auto const tokensPerBlock = mModelConfig.getTokensPerBlock();
auto const elemSize = BufferDataType(mModelConfig.getKvDataType()).getSize();
auto const sizeInBytesPerKVHead = mModelConfig.getSizePerHead() * elemSize;
auto const poolPointers = mKvCacheManager->getBlockPoolPointers();
auto* const* pointerArrayPtr = bufferCast<void*>(*poolPointers);
auto const* offsetArrayPtr
= bufferCast<tk::KVCacheIndex>(*runtimeBuffers.transformerBuffers->kvCacheBlockOffsetsDevice);
auto commonRewindLen = mModelConfig.getSpeculativeDecodingModule().getMaxDecodingDraftTokens();
SizeType32 const* rewindLens = nullptr;
if (mModelConfig.getSpeculativeDecodingMode().variableDraftLength())
{
commonRewindLen = 0;
rewindLens = bufferCast<SizeType32 const>(*mDecoderBuffers->draftBuffers.prevDraftTokensLengthsHost);
}
tensorrt_llm::runtime::kernels::invokeUpdateKVBlockArrayDraftTokenLocation(
*mDecoderBuffers->draftBuffers.acceptedLengthsCumSumDevice,
*mDecoderBuffers->draftBuffers.acceptedPackedPathsDevice, *runtimeBuffers.sequenceLengthsDevice,
pointerArrayPtr, offsetArrayPtr, localNbLayers, numSequences, mRewindInputs.numKvHeads, sizeInBytesPerKVHead,
commonRewindLen, rewindLens, *runtimeBuffers.seqSlotRemappingDevice, *runtimeBuffers.sortedSeqSlots,
getMaxAttentionWindow(), mRewindInputs.maxBlocksPerSeq, tokensPerBlock, mRewindInputs.isUseOneMoreBlock,
mRuntime->getStreamPtr()->get());
sync_check_cuda_error(mRuntime->getStream().get());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
nvinfer1::DataType TrtGptModelInflightBatching::getLogitDataType() const
{
return mModelConfig.getLogitsDtype();
}
void TrtGptModelInflightBatching::changeBeamWidth(SizeType32 beamWidth)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(mInflightReqIds.empty());
TLLM_CHECK_WITH_INFO(beamWidth <= getMaxBeamWidth(),
"Requested beam width %d is larger than configured max beam width %d", beamWidth, getMaxBeamWidth());
TLLM_LOG_DEBUG("Changing operating beam width from %d to %d", mOperatingBeamWidth, beamWidth);
mOperatingBeamWidth = beamWidth;
createBuffers(mDecodingConfig, mAdditionalModelOutputs);
createDecoder(mDecodingConfig.getDecodingMode());
if (static_cast<bool>(mKvCacheManager))
{
auto const dims = mKvCacheManager->getOffsetTableDimensions();
reshapeKvTensors(dims);
}
if (static_cast<bool>(mCrossKvCacheManager))
{
auto const dims = mCrossKvCacheManager->getOffsetTableDimensions();
reshapeKvTensors(dims);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
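// Seamlessly switches between no speculative decoding and lookahead decoding, based on the number of
// scheduled requests and whether their sampling configs are compatible with lookahead
// (default top-k/top-p, no banned words, no vocabulary-wide penalties, beam width 1).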
void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& scheduledRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if ((!mModelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()
&& !mModelConfig.getSpeculativeDecodingMode().isNone())
|| scheduledRequests.empty() || mSeamlessLADMaxDraftLen == 0 || getGatherGenerationLogits()
|| mModelConfig.isRnnBased())
{
return;
}
bool canUseLookahead = false;
auto maxNumRequestForLad = mDecodingConfig.getLookaheadDecodingMaxNumRequest();
SizeType32 numRequests = scheduledRequests.contextRequests.size() + scheduledRequests.generationRequests.size();
if (numRequests > maxNumRequestForLad)
{
if (mModelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
canUseLookahead = false;
}
else
{
return;
}
}
{
bool useTopKTopP = false;
bool useBanWords = false;
bool useTempAccVocabPenalties = false; // use temperature or penalties that require accumulation over the vocabulary.
SizeType32 beamWidth = 1;
for (auto const& requests : {scheduledRequests.contextRequests, scheduledRequests.generationRequests})
{
for (auto const& llmReq : requests)
{
useTopKTopP |= !(llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.topK, layers::DefaultDecodingParams::getTopK())
|| llmReq->mSamplingConfig.useDefaultValues(llmReq->mSamplingConfig.topK, 1));
useTopKTopP |= !llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.topP, layers::DefaultDecodingParams::getTopP());
useBanWords |= llmReq->getBadWordsList().has_value();
useBanWords |= !llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.noRepeatNgramSize, layers::DefaultDecodingParams::getNoRepeatNgramSize());
useTempAccVocabPenalties |= !llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.temperature, layers::DefaultDecodingParams::getTemperature());
useTempAccVocabPenalties |= !llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.repetitionPenalty, layers::DefaultDecodingParams::getRepetitionPenalty());
useTempAccVocabPenalties |= !llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.presencePenalty, layers::DefaultDecodingParams::getPresencePenalty());
useTempAccVocabPenalties |= !llmReq->mSamplingConfig.useDefaultValues(
llmReq->mSamplingConfig.frequencyPenalty, layers::DefaultDecodingParams::getFrequencyPenalty());
beamWidth = llmReq->mSamplingConfig.beamWidth;
if (useTopKTopP || useBanWords || useTempAccVocabPenalties || beamWidth > 1)
{
break;
}
}
canUseLookahead = !(useTopKTopP || useBanWords || useTempAccVocabPenalties || beamWidth > 1);
}
}
// Change speculative decoding mode
auto const bufferId = mCtxGenFusion
? getFusedBufferId()
: (!scheduledRequests.contextRequests.empty() ? getContextBufferId() : getGenerationBufferId());
// TODO: enable lookahead for generation requests.
bool canChangeToLookahead = scheduledRequests.generationRequests.empty();
if (mModelConfig.getSpeculativeDecodingMode().isNone() && canUseLookahead && canChangeToLookahead)
{
// None -> Lookahead
mModelConfig.enableSeamlessLookaheadDecoding(mSeamlessLADMaxDraftLen);
mDecodingConfig.enableSeamlessLookaheadDecoding();
setupSpeculativeDecodingModule(mDecodingConfig);
mBuffers.at(bufferId)->mLookaheadBuffers->enableLookaheadDecoding(
getMaxBatchSize(), mModelConfig.getMaxDecodingTokens());
mDecoderOutputBuffers.at(getFusedBufferId())
.enableLookaheadDecoding(getMaxNumSequences(), mModelConfig.getMaxDecodingTokens());
createDecoder(mDecodingConfig.getDecodingMode());
}
else if (mModelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()
&& (!canUseLookahead || numRequests > maxNumRequestForLad))
{
// Lookahead -> None
mModelConfig.disableSeamlessLookaheadDecoding();
mDecodingConfig.setDecodingMode(executor::DecodingMode::Auto());
mBuffers.at(bufferId)->mLookaheadBuffers->disableLookaheadDecoding();
mDecoderOutputBuffers.at(getFusedBufferId()).disableLookaheadDecoding(getMaxNumSequences());
mDecoder->disableLookahead(
scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()).setupBatchSlots);
mDecoderState->disableLookahead(scheduledRequests.generationRequests);
for (auto const& llmReq : scheduledRequests.generationRequests)
{
if (llmReq->getNumDraftTokens() > 0)
{
llmReq->discardDraftTokens(llmReq->getNumDraftTokens());
}
}
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void TrtGptModelInflightBatching::getCurrentIterationStats(executor::IterationStats& stats) const
{
stats.iter = mIterCounter;
// Max batch size and max num tokens can be tuned at runtime
stats.maxBatchSizeStatic = getMaxBatchSize();
stats.maxBatchSizeTunerRecommended = mMaxBatchSizeTunerRecommended;
stats.maxBatchSizeRuntime = mMaxBatchSizeRuntime;
stats.maxNumTokensStatic = mMaxNumTokensStatic.value_or(0);
stats.maxNumTokensTunerRecommended = mMaxNumTokensTunerRecommended;
stats.maxNumTokensRuntime = mMaxNumTokensRuntime.value_or(0);
// KVCacheManager statistics
auto const& kvCacheManager = getKVCacheManager();
if (kvCacheManager)
{
executor::KvCacheStats kvStats{};
auto kvCacheStats = kvCacheManager->getKvCacheStats();
kvStats.maxNumBlocks = kvCacheStats.maxNumBlocks;
kvStats.freeNumBlocks = kvCacheStats.freeNumBlocks;
kvStats.usedNumBlocks = kvCacheStats.usedNumBlocks;
kvStats.tokensPerBlock = kvCacheStats.toksPerBlock;
kvStats.allocTotalBlocks = kvCacheStats.allocTotalBlocks;
kvStats.allocNewBlocks = kvCacheStats.allocNewBlocks;
kvStats.reusedBlocks = kvCacheStats.reusedBlocks;
kvStats.missedBlocks = kvCacheStats.missedBlocks;
kvStats.cacheHitRate = kvCacheStats.cacheHitRate;
stats.kvCacheStats = kvStats;
}
auto const& crossKvCacheManager = getCrossKVCacheManager();
if (crossKvCacheManager)
{
executor::KvCacheStats kvStats{};
auto kvCacheStats = crossKvCacheManager->getKvCacheStats();
kvStats.maxNumBlocks = kvCacheStats.maxNumBlocks;
kvStats.freeNumBlocks = kvCacheStats.freeNumBlocks;
kvStats.usedNumBlocks = kvCacheStats.usedNumBlocks;
kvStats.tokensPerBlock = kvCacheStats.toksPerBlock;
kvStats.allocTotalBlocks = kvCacheStats.allocTotalBlocks;
kvStats.allocNewBlocks = kvCacheStats.allocNewBlocks;
kvStats.reusedBlocks = kvCacheStats.reusedBlocks;
kvStats.missedBlocks = kvCacheStats.missedBlocks;
kvStats.cacheHitRate = kvCacheStats.cacheHitRate;
stats.crossKvCacheStats = kvStats;
}
executor::InflightBatchingStats modelStats{};
modelStats.numScheduledRequests = mLastIterationStatsIFB.scheduledRequests.size();
modelStats.numContextRequests = mLastIterationStatsIFB.numCtxRequests;
modelStats.numGenRequests = mLastIterationStatsIFB.numGenRequests;
modelStats.numPausedRequests = mLastIterationStatsIFB.pausedRequests.size();
modelStats.avgNumDecodedTokensPerIter = mLastIterationStatsIFB.avgNumDecodedTokensPerIter;
modelStats.numCtxTokens = mLastIterationStatsIFB.numCtxTokens;
modelStats.microBatchId = mLastIterationStatsIFB.microBatchId;
stats.inflightBatchingStats = modelStats;
}
void TrtGptModelInflightBatching::getCurrentRequestStats(executor::RequestStatsPerIteration& stats) const
{
stats.iter = mIterCounter;
for (auto& requestStat : stats.requestStats)
{
requestStat.scheduled
= mLastIterationStatsIFB.scheduledRequests.count(static_cast<RequestIdType>(requestStat.id));
requestStat.paused = mLastIterationStatsIFB.pausedRequests.count(static_cast<RequestIdType>(requestStat.id));
}
}
executor::DebugTensorsPerIteration TrtGptModelInflightBatching::getCurrentDebugTensors() const
{
executor::DebugTensorsPerIteration debugTensors;
debugTensors.iter = mIterCounter;
for (auto const& [name, tensor] : mLastIterationDebugTensors)
{
debugTensors.debugTensors.emplace(name, executor::detail::ofITensor(tensor));
}
return debugTensors;
}
nvinfer1::DataType TrtGptModelInflightBatching::getTensorDataType(std::string const& name) const
{
auto const& engine = mRuntime->getEngine();
return engine.getTensorDataType(name.c_str());
}
nvinfer1::Dims TrtGptModelInflightBatching::getTensorShape(std::string const& name) const
{
auto const& engine = mRuntime->getEngine();
return engine.getTensorShape(name.c_str());
}
SizeType32 TrtGptModelInflightBatching::getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const
{
return mKvCacheManager->getMaxCapacityBatchSize(inputLength, outputLength);
}
/*
* Manages prefetching of prompt table chunks using a double-buffer strategy
*
* Function Flow:
* 1. First Chunk Processing (isFirstChunk == true):
* - Uses blocking prefetch on main runtime stream
* - Ensures initial data is ready before computation starts
*
* 2. Subsequent Chunks (isFirstChunk == false):
* - Uses non-blocking prefetch on separate copy stream
* - Overlaps data transfer with computation
*
* Synchronization:
* - First prefetch: No wait needed (fresh start)
* - Later prefetches: Wait for previous copy to complete
* - Uses mPtableCopyDoneEvent to track completion
*
* Key Functions:
* 1. prefetchNextPromptTableChunk:
* - Dispatches to the blocking or non-blocking path depending on whether it is called before or after prepareBuffers()
* - Waits for previous copy to complete if not the first chunk
*
* 2. remapInputTokensForPromptTable:
* - Identifies tokens that need prompt table embeddings (token IDs greater than or equal to vocabSize)
* - Remaps IDs to match chunked prompt table layout
*
* 3. copyPromptTableToGpuInChunk:
* - Handles actual transfer from CPU pinned memory to GPU
* - Uses appropriate buffer manager based on isFirstChunk
*/
void TrtGptModelInflightBatching::prefetchNextPromptTableChunk(
RequestVector const& contextRequests, bool isFirstChunk, SizeType32 bufferId)
{
auto& promptTuningBuffers = mBuffers[bufferId]->promptTuningBuffers;
if (!isFirstChunk)
{
// Only switch the buffer after prepareBuffers()
promptTuningBuffers->switchChunkPtableBuffer();
}
SizeType32 contextId = 0;
for (auto const& llmReq : contextRequests)
{
if (llmReq->isFirstContextChunk() && isFirstChunk)
{
// For first chunk: Blocking prefetch on runtime stream to ensure data is ready
remapInputTokensForPromptTable(llmReq, true, bufferId, contextId);
}
else if (!isFirstChunk) // prefetching for subsequent chunks
{
// The first prefetched chunk does not need to wait for a previous prefetch to complete;
// subsequent chunks wait for the previous copy to finish.
if (!llmReq->isFirstContextChunk())
{
mRuntime->getBufferManager().getStream().wait(mPtableCopyDoneEvent);
}
// Non-blocking prefetch on copy stream to prepare next chunk in pong buffer
if (llmReq->getContextRemainingLength() > 0)
{
remapInputTokensForPromptTable(llmReq, false, bufferId, contextId);
}
}
++contextId;
}
}
void TrtGptModelInflightBatching::remapInputTokensForPromptTable(
std::shared_ptr<LlmRequest> const& llmReq, bool isFirstChunk, SizeType32 bufferId, SizeType32 contextId)
{
NVTX3_SCOPED_RANGE_WITH_NAME(range, "remapInputTokensForPromptTable");
auto& promptTuningBuffers = mBuffers[bufferId]->promptTuningBuffers;
auto const chunkSize = llmReq->getContextChunkSize();
auto& inputTokensMutable = llmReq->getTokensMutable(0);
auto vocabSize = mModelConfig.getVocabSize();
if (isFirstChunk)
{
promptTuningBuffers->initializeChunkPtableBuffers(
mRuntime->getBufferManager(), mModelConfig, chunkSize, llmReq);
}
size_t processChunkSize;
size_t beginPos;
if (!isFirstChunk)
{
processChunkSize = std::min(chunkSize, llmReq->getContextRemainingLength() - chunkSize);
}
else
{
processChunkSize = std::min(chunkSize, llmReq->getContextRemainingLength());
}
if (!isFirstChunk)
{
// For prefetching next chunk
if (llmReq->getContextRemainingLength() - chunkSize <= 0)
{
promptTuningBuffers->updateBufferStartPosition(promptTuningBuffers->getChunkPtableCurrentIndex(), 0);
return; // No more chunks to prefetch
}
beginPos = llmReq->getContextCurrentPosition() + chunkSize;
}
else
{
// For current chunk
beginPos = llmReq->getContextCurrentPosition();
}
TLLM_CHECK_WITH_INFO(beginPos + processChunkSize <= inputTokensMutable.size(),
"Invalid chunk access: beginPos(%zu) + processChunkSize(%zu) > totalSize(%zu)", beginPos, processChunkSize,
inputTokensMutable.size());
auto inputTokensChunk = inputTokensMutable.begin() + beginPos;
std::vector<SizeType32> outOfVocabTokens;
SizeType32 ptableTokenId = vocabSize;
for (size_t i = 0; i < processChunkSize; i++)
{
if (inputTokensChunk[i] >= vocabSize)
{
outOfVocabTokens.push_back(inputTokensChunk[i]);
inputTokensChunk[i] = ptableTokenId++;
}
}
copyPromptTableToGpuInChunk(llmReq, outOfVocabTokens, isFirstChunk, bufferId, contextId);
}
void TrtGptModelInflightBatching::copyPromptTableToGpuInChunk(std::shared_ptr<LlmRequest> const& llmReq,
std::vector<int32_t> const& outOfVocabTokens, bool isFirstChunk, SizeType32 bufferId, SizeType32 contextId)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE_WITH_NAME(range, "copyPromptTableToGpuInChunk");
auto& promptTuningBuffers = mBuffers[bufferId]->promptTuningBuffers;
if (outOfVocabTokens.empty())
{
return;
}
auto const& promptTable = llmReq->getPromptEmbeddingTable();
TLLM_CHECK_WITH_INFO(promptTable.has_value(), "promptTable is not set although the request contains fake prompt tokens");
TLLM_CHECK_WITH_INFO(promptTable.value() != nullptr, "promptTable is null although the request contains fake prompt tokens");
auto currentBufferManager = isFirstChunk ? mRuntime->getBufferManager() : mCopyBufferManager;
auto const hiddenSize = mModelConfig.getHiddenSize();
auto numRows = outOfVocabTokens.size();
std::size_t sliceSize = static_cast<size_t>(numRows * hiddenSize);
auto currentIndex = promptTuningBuffers->getChunkPtableCurrentIndex();
// Calculate the offset based on current position
size_t srcOffset = llmReq->mPtableCurrentPosition * hiddenSize;
size_t dstOffset = promptTuningBuffers->getChunkPtableBufferStartPosition(currentIndex, contextId);
auto gpuBuffer = promptTuningBuffers->getChunkPtableBuffer(currentIndex);
// First view as 1D tensor of elements
auto totalElements = promptTable.value()->getSize();
auto table1D = runtime::ITensor::view(
promptTable.value(), runtime::ITensor::makeShape({static_cast<int64_t>(totalElements)}));
TLLM_CHECK_WITH_INFO(srcOffset + sliceSize <= totalElements,
"Buffer bounds violation: Trying to access up to %zu elements but buffer only has %zu elements (offset: %zu, "
"slice size: %zu)",
srcOffset + sliceSize, totalElements, srcOffset, sliceSize);
auto table1DShared = runtime::ITensor::SharedPtr(table1D.release());
auto pTableView = runtime::ITensor::slice(table1DShared, srcOffset, sliceSize);
auto gpuBufferSlice = runtime::ITensor::slice(gpuBuffer, dstOffset, numRows);
currentBufferManager.copy(*pTableView, *gpuBufferSlice);
promptTuningBuffers->updateBufferStartPosition(currentIndex, outOfVocabTokens.size());
llmReq->mPtableCurrentPosition += outOfVocabTokens.size();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::batch_manager