TensorRT-LLMs/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Iman Tabrizian 48b09e5a25
[https://nvbugs/5689235][fix] Fix cancellation+chunked prefill+disagg (#10111)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
2026-01-12 18:23:26 -05:00

2332 lines
93 KiB
C++

/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <list>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
namespace tensorrt_llm::batch_manager
{
/**
* @brief The state of the request.
*
* Enum order must follow chronological order for state dependency check, @see hasReachedState().
*/
enum class LlmRequestState : int32_t
{
kUNKNOWN = 0, ///< Unknown state
kENCODER_INIT = 1, ///< Encoder phase starts (for encoder-decoder models)
kDISAGG_GENERATION_INIT = 8, ///< New Generation request arrived at generation model
kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< Transmitting the kv cache
// schedulable states starts
kCONTEXT_INIT = 10, ///< Context phase starts
kDISAGG_CONTEXT_INIT_AND_TRANS = 11, ///< Context phase starts and cache transmission is in progress,
/// used in layer-wise transmission
kDISAGG_GENERATION_TRANS_COMPLETE = 12, ///< Kv cache transmission are finished
kGENERATION_IN_PROGRESS = 13, ///< Generation phase is in progress
// schedulable states ends
kGENERATION_TO_COMPLETE = 14, ///< Generation phase is to be completed
kGENERATION_COMPLETE = 20, ///< Generation phase completed
kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Waiting context-only request transmitting the kv cache,
/// after computation finished
kDISAGG_CONTEXT_COMPLETE = 22, ///< Context-only request finished kv cache transmission.
// error states
kDISAGG_TRANS_ERROR = -1, ///< Error occurred during kv cache transmission
};
enum LlmRequestType
{
LLMREQUEST_TYPE_CONTEXT_AND_GENERATION = 0, // Normal request will inference both context phase and generation phase
LLMREQUEST_TYPE_CONTEXT_ONLY = 1, // Only inference context phase
LLMREQUEST_TYPE_GENERATION_ONLY = 2 // only inference generation phase
};
class ContextProgress;
template <typename TTensor, typename TStream = runtime::BufferManager::CudaStreamPtr>
class GenericLlmRequest
{
using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
public:
using SizeType32 = runtime::SizeType32;
using TokenIdType = runtime::TokenIdType;
using RequestIdType = std::uint64_t;
using LoraTaskIdType = runtime::LoraTaskIdType;
using VecTokens = std::vector<TokenIdType>;
using TokenExtraIdType = runtime::TokenExtraIdType;
using VecTokenExtraIds = runtime::VecTokenExtraIds;
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using UniqueToken = runtime::UniqueToken;
using VecUniqueTokens = runtime::VecUniqueTokens;
using BeamUniqueTokens = std::vector<VecUniqueTokens>;
using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<void(
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
using MillisecondsType = std::chrono::milliseconds;
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
using Duration = std::chrono::time_point<std::chrono::steady_clock>::duration;
using CacheSaltIDType = runtime::CacheSaltIDType;
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> multimodalHashes = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalPositions = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalLengths = std::nullopt,
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
std::optional<executor::KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<std::shared_ptr<VecTokens>> const& draftTokens = std::nullopt,
std::optional<TensorPtr> draftLogits = std::nullopt, bool excludeInputFromOutput = false,
std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false,
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false,
std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority,
std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
std::optional<SizeType32> encoderOutputLength = std::nullopt,
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<std::shared_ptr<VecTokenExtraIds>> inputTokenExtraIds = std::nullopt,
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false,
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(std::move(logitsPostProcessor))
, mApplyLogitsPostProcessorBatched(applyLogitsPostProcessorBatched)
, mClientId(clientId)
, mIsStreaming(isStreaming)
, mOrigPromptLen(mPromptLen)
, mNumPreDecodedTokens(samplingConfig.beamWidth, 0)
, mMaxSentTokenLen(mPromptLen)
, mEmbeddingBias(std::move(embeddingBias))
, mBadWordsList(std::move(badWordsList))
, mStopWordsList(std::move(stopWordsList))
, mPositionIds(std::move(positionIds))
, mPromptEmbeddingTable(std::move(promptEmbeddingTable))
, mPromptVocabSize(promptVocabSize)
, mMultimodalHashes(std::move(multimodalHashes))
, mMultimodalPositions(std::move(multimodalPositions))
, mMultimodalLengths(std::move(multimodalLengths))
, mMultimodalEmbedding(std::move(multimodalEmbedding))
, mMropeRotaryCosSin(std::move(mropeRotaryCosSin))
, mMropePositionDeltas(mropePositionDeltas)
, mLoraTaskId(loraTaskId)
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mLookaheadConfig(std::move(lookaheadConfig))
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
, mContextChunkSize{mPromptLen}
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
, mDraftLogits(std::move(draftLogits))
, mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1))
, mReturnContextLogits(returnContextLogits)
, mReturnGenerationLogits(returnGenerationLogits)
, mExcludeInputFromOutput(excludeInputFromOutput)
, mEncoderTokens(std::move(encoderInputTokens))
, mReturnEncoderOutput(returnEncoderOutput)
, mPriority(priority)
, mFinishReasons(samplingConfig.beamWidth)
, mEncoderInputFeatures(std::move(encoderInputFeatures))
, mEncoderOutputLength(encoderOutputLength)
, mCrossAttentionMask(std::move(crossAttentionMask))
, mLlmRequestType(llmRequestType)
, mContextPhaseParams(contextPhaseParams)
, mInputTokenExtraIds(std::move(inputTokenExtraIds))
, mNumReturnSequences(numReturnSequences)
, mEagleConfig(std::move(eagleConfig))
, mSkipCrossAttnBlocks(std::move(skipCrossAttnBlocks))
, mReturnPerfMetrics(returnPerfMetrics)
, mGuidedDecodingParams(std::move(guidedDecodingParams))
, mLanguageAdapterUid(languageAdapterUid)
, mAllottedTimeMs(allottedTimeMs)
, mCacheSaltID(cacheSaltID)
{
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
{
mState = LlmRequestState::kENCODER_INIT;
}
initialize(*inputTokens, returnLogProbs, arrivalTime);
}
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt, bool returnLogProbs = false,
bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false, std::optional<VecTokens> encoderInputTokens = std::nullopt,
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens.size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mEndId(endId)
, mPadId(padId)
, mLogitsPostProcessor(logitsPostProcessor)
, mApplyLogitsPostProcessorBatched(applyLogitsPostProcessorBatched)
, mClientId(clientId)
, mIsStreaming(isStreaming)
, mOrigPromptLen(mPromptLen)
, mNumPreDecodedTokens(samplingConfig.beamWidth, 0)
, mMaxSentTokenLen(mPromptLen)
, mEmbeddingBias(std::move(embeddingBias))
, mBadWordsList(std::move(badWordsList))
, mStopWordsList(std::move(stopWordsList))
, mPositionIds(std::move(positionIds))
, mPromptEmbeddingTable(std::move(promptEmbeddingTable))
, mPromptVocabSize(promptVocabSize)
, mLoraTaskId(loraTaskId)
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mLookaheadConfig(lookaheadConfig)
, mContextChunkSize(mPromptLen)
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(std::make_shared<VecTokens>(draftTokens.value_or(VecTokens())))
, mDraftLogits(draftLogits)
, mReturnAllGeneratedTokens(isStreaming && (samplingConfig.beamWidth > 1))
, mReturnContextLogits(returnContextLogits)
, mReturnGenerationLogits(returnGenerationLogits)
, mExcludeInputFromOutput(excludeInputFromOutput)
, mEncoderTokens(std::make_shared<VecTokens>(encoderInputTokens.value_or(VecTokens())))
, mReturnEncoderOutput(returnEncoderOutput)
, mPriority(priority)
, mFinishReasons(samplingConfig.beamWidth)
, mContextPhaseParams(contextPhaseParams)
, mNumReturnSequences(numReturnSequences)
, mLanguageAdapterUid(languageAdapterUid)
, mCacheSaltID(cacheSaltID)
{
if (mEncoderTokens.has_value())
{
mState = LlmRequestState::kENCODER_INIT;
}
initialize(inputTokens, returnLogProbs);
}
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
: mRequestId(requestId)
, mPromptLen(req.getInputTokenIds().size())
, mMaxNewTokens(req.getMaxTokens())
, mSamplingConfig(req.getSamplingConfig(), req.getExternalDraftTokensConfig())
, mEndId(req.getEndId())
, mPadId(req.getPadId())
, mClientId(req.getClientId())
, mIsStreaming(req.getStreaming())
, mOrigPromptLen(mPromptLen)
, mNumPreDecodedTokens(mSamplingConfig.beamWidth, 0)
, mMaxSentTokenLen(mPromptLen)
, mContextChunkSize{mPromptLen}
, mLogProbs(mSamplingConfig.beamWidth)
, mCumLogProbs(mSamplingConfig.beamWidth)
, mDraftTokens(std::make_shared<VecTokens>())
, mReturnAllGeneratedTokens(req.getReturnAllGeneratedTokens())
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
, mReturnEncoderOutput(req.getOutputConfig().returnEncoderOutput)
, mPriority(req.getPriority())
, mFinishReasons(mSamplingConfig.beamWidth)
, mEncoderOutputLength(req.getEncoderOutputLength())
, mContextPhaseParams(req.getContextPhaseParams())
, mEagleConfig(req.getEagleConfig())
, mReturnPerfMetrics(req.getOutputConfig().returnPerfMetrics)
, mGuidedDecodingParams(req.getGuidedDecodingParams())
, mLanguageAdapterUid(req.getLanguageAdapterUid())
, mAllottedTimeMs(req.getAllottedTimeMs())
, mCacheSaltID(req.getCacheSaltID())
{
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
{
mState = LlmRequestState::kDISAGG_GENERATION_INIT;
}
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && !mReturnAllGeneratedTokens)
{
TLLM_LOG_WARNING(
"Setting mReturnAllGeneratedTokens to True since streaming AND beam search are done simultaneously. "
"Returning the full beams at each streaming step is needed because beam search + streaming can change "
"previous outputs. Initialize request with mReturnAllGeneratedTokens = True to dismiss this error. "
"WARNING: using this option may increase network usage significantly (quadratically w.r.t output "
"length).");
mReturnAllGeneratedTokens = true;
}
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && mReturnGenerationLogits)
{
TLLM_LOG_WARNING(
"Returning generation logits when streaming is enabled and beamWidth > 1 is not allowed. "
"This is because the logits may appear in irrelevant order when the beams are gathered, "
"since logits are not. Disabling returnGenerationLogits.");
mReturnGenerationLogits = false;
}
if (req.getEncoderInputTokenIds().has_value() || req.getEncoderInputFeatures().has_value())
{
mState = LlmRequestState::kENCODER_INIT;
if (req.getEncoderInputTokenIds().has_value())
{
mEncoderTokens = std::make_shared<VecTokens>(req.getEncoderInputTokenIds().value());
}
}
if (req.getEmbeddingBias())
{
mEmbeddingBias
= tensorrt_llm::runtime::ITensor::view(executor::detail::toITensor(req.getEmbeddingBias().value()));
// Add leading 1 dimension since that's what IFB code expects
mEmbeddingBias.value()->unsqueeze(0);
}
if (req.getBadWords())
{
mBadWordsList = createListTensor(req.getBadWords().value());
}
if (req.getStopWords())
{
mStopWordsList = createListTensor(req.getStopWords().value());
}
if (req.getPositionIds())
{
mPositionIds = std::make_shared<std::vector<SizeType32>>(req.getPositionIds().value());
}
auto pTuningConfig = req.getPromptTuningConfig();
if (pTuningConfig)
{
mPromptEmbeddingTable = tensorrt_llm::runtime::ITensor::view(
executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable()));
TLLM_CHECK(mPromptEmbeddingTable.value()->getShape().nbDims == 2);
mPromptVocabSize = mPromptEmbeddingTable.value()->getShape().d[0];
mPromptEmbeddingTable.value()->unsqueeze(0);
if (pTuningConfig->getInputTokenExtraIds())
{
mInputTokenExtraIds
= std::make_shared<VecTokenExtraIds>(pTuningConfig->getInputTokenExtraIds().value());
}
}
auto mRopeConfig = req.getMropeConfig();
if (mRopeConfig)
{
mMropeRotaryCosSin = executor::detail::toITensor(mRopeConfig.value().getMRopeRotaryCosSin());
mMropePositionDeltas = mRopeConfig.value().getMRopePositionDeltas();
}
auto loraConfig = req.getLoraConfig();
if (loraConfig)
{
mLoraTaskId = loraConfig->getTaskId();
if (loraConfig.value().getWeights())
{
mLoraWeights = tensorrt_llm::runtime::ITensor::view(
executor::detail::toITensor(loraConfig.value().getWeights().value()));
mLoraWeights.value()->unsqueeze(0);
}
if (loraConfig.value().getConfig())
{
mLoraConfig = tensorrt_llm::runtime::ITensor::view(
executor::detail::toITensor(loraConfig.value().getConfig().value()));
mLoraConfig.value()->unsqueeze(0);
}
}
auto externalDraftTokensConfig = req.getExternalDraftTokensConfig();
if (externalDraftTokensConfig)
{
mDraftTokens = std::make_shared<VecTokens>(externalDraftTokensConfig.value().getTokens());
if (externalDraftTokensConfig.value().getLogits())
{
mDraftLogits = executor::detail::toITensor(externalDraftTokensConfig.value().getLogits().value());
}
// NOTE: Draft acceptance threshold is stored in mSamplingConfig
}
if (req.getOutputConfig().additionalModelOutputs.has_value())
{
auto const& outputConfig = req.getOutputConfig();
auto const& additionalModelOutputs = outputConfig.additionalModelOutputs.value();
for (auto const& modelOutput : additionalModelOutputs)
{
if (modelOutput.gatherContext)
{
mAdditionalContextOutputTensors.emplace(modelOutput.name, TensorPtr{});
}
mAdditionalGenerationOutputTensors.emplace(modelOutput.name, TensorPtr{});
}
}
auto const& encoderInputFeatures = req.getEncoderInputFeatures();
if (encoderInputFeatures.has_value())
{
mEncoderInputFeatures = executor::detail::toITensor(encoderInputFeatures.value());
}
else
{
mEncoderInputFeatures = std::nullopt;
}
auto const& crossAttentionMask = req.getCrossAttentionMask();
if (crossAttentionMask.has_value())
{
mCrossAttentionMask = executor::detail::toITensor(crossAttentionMask.value());
}
else
{
mCrossAttentionMask = std::nullopt;
}
auto const& skipCrossAttnBlocks = req.getSkipCrossAttnBlocks();
if (skipCrossAttnBlocks.has_value())
{
mSkipCrossAttnBlocks = executor::detail::toITensor(skipCrossAttnBlocks.value());
}
else
{
mSkipCrossAttnBlocks = std::nullopt;
}
switch (req.getRequestType())
{
case executor::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION:
mLlmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION;
break;
case executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY:
mLlmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY;
break;
case executor::RequestType::REQUEST_TYPE_GENERATION_ONLY:
mLlmRequestType = LlmRequestType::LLMREQUEST_TYPE_GENERATION_ONLY;
break;
default: throw std::runtime_error("Unsupported request type found.");
}
initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
}
GenericLlmRequest(GenericLlmRequest&& request) = default;
GenericLlmRequest(GenericLlmRequest const& request) = default;
void setExcludeInputFromOutput(bool exclude)
{
mExcludeInputFromOutput = exclude;
}
/// @brief Get the params of the context
/// @return The params of the context
[[nodiscard]] std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
{
return mContextPhaseParams;
}
void setContextPhaseParams(executor::ContextPhaseParams contextPhaseParams)
{
mContextPhaseParams = std::move(contextPhaseParams);
}
/// @brief Get the state params of the context
/// @return The state params of the context
[[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const
{
TLLM_CHECK(mContextPhaseParams.has_value());
return *static_cast<executor::DataTransceiverState const*>(mContextPhaseParams.value().getState());
}
[[nodiscard]] std::shared_ptr<ContextProgress> const& getContextProgress() const noexcept
{
return mContextProgress;
}
void setContextProgress(std::shared_ptr<ContextProgress> const& progress)
{
mContextProgress = progress;
}
/// @brief Get total number of tokens for this req (prompt + generated)
/// @param beam The beam index
/// @return The number of tokens
[[nodiscard]] SizeType32 getNumTokens(SizeType32 beam) const
{
return mTokens.at(beam).size() - mNumPreDecodedTokens[beam];
}
/// @brief Get the number of subrequests, the expected number of responses under non-streaming mode. In sampling
/// mode, it will be equal to mSamplingConfig.numReturnSequences, while it will be equal to 1 in beam search.
/// @return The number of subrequests in total request size.
[[nodiscard]] SizeType32 getNumSubRequests() const
{
return mSamplingConfig.beamWidth == 1 ? mSamplingConfig.numReturnSequences.value_or(1) : 1;
}
/// @brief Get child requests spawned by this req.
/// @return A vector of child requests.
[[nodiscard]] std::vector<RequestPtr> const& getChildRequests() const
{
return mChildRequests;
}
/// @brief Get max number of tokens across all beams
/// @return The number of tokens
[[nodiscard]] SizeType32 getMaxBeamNumTokens() const
{
SizeType32 maxTokens = 0;
for (SizeType32 beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
{
maxTokens = std::max(maxTokens, getNumTokens(beam));
}
return maxTokens;
}
/// @brief Get a token at a given position and beam index
/// @param beam The beam index
/// @param pos The position of the token relative to beginning of the prompt
/// @return The token index
[[nodiscard]] TokenIdType getToken(SizeType32 beam, SizeType32 pos) const
{
return mTokens.at(beam).at(pos);
}
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
[[nodiscard]] VecTokens const& getTokens(SizeType32 beam) const
{
return mTokens.at(beam);
}
/// @brief Get mutable reference to tokens for a specific beam
/// @param beam The beam index
/// @return Mutable reference to the tokens vector
[[nodiscard]] VecTokens& getTokensMutable(SizeType32 beam)
{
return mTokens.at(beam);
}
/// @brief Get all tokens (input+output) for all beams
/// @return A vector of vector of tokens.
[[nodiscard]] BeamTokens const& getTokens() const
{
return mTokens;
}
/// @brief Get the unique tokens at a given beam index
/// @param beam The beam index
/// @return A vector of UniqueTokens for this beam index, includes the prompt
[[nodiscard]] VecUniqueTokens const& getUniqueTokens(SizeType32 beam) const
{
return mUniqueTokens.at(beam);
}
/// @brief Get all unique tokens (input+output) for all beams
/// @return A vector of vector of UniqueTokens.
[[nodiscard]] BeamUniqueTokens const& getUniqueTokens() const
{
return mUniqueTokens;
}
/// @brief Get all extra input token ids
/// @return A optional shared pointer to a vector of extra ids.
[[nodiscard]] std::optional<std::shared_ptr<VecTokenExtraIds>> const& getInputTokensExtraIds() const
{
return mInputTokenExtraIds;
}
/// @brief Get input tokens to encoder
/// @return A vector of tokens.
[[nodiscard]] std::optional<std::shared_ptr<VecTokens>> const& getEncoderTokens() const
{
return mEncoderTokens;
}
/// @brief Get the unique tokens to encoder
/// @return A vector of UniqueTokens for encoder
[[nodiscard]] std::optional<std::shared_ptr<VecUniqueTokens>> const& getEncoderUniqueTokens() const
{
return mEncoderUniqueTokens;
}
/// @brief Get length of encoder input (could be tokens or features length)
/// @return An integer.
[[nodiscard]] SizeType32 getEncoderInputLen() const
{
if (mEncoderInputFeatures.has_value())
{
return getEncoderInputFeatures()->getShape().d[0];
}
if (getEncoderTokens().has_value())
{
return getEncoderTokens().value()->size();
}
TLLM_THROW("GenericLlmRequest::getEncoderInputLen - Do not have encoder length!");
}
/// @brief Get length of encoder output. Fall back to encoder input length if not present
/// @return An integer.
[[nodiscard]] SizeType32 getEncoderOutputLen() const
{
if (mEncoderOutputLength.has_value())
{
return mEncoderOutputLength.value();
}
return getEncoderInputLen();
}
[[nodiscard]] std::optional<std::shared_ptr<std::vector<SizeType32>>> getPositionIds() const
{
return mPositionIds;
}
/// @brief Get the draft tokens
/// @return shared_ptr to vector of draft tokens
[[nodiscard]] std::shared_ptr<VecTokens> const& getDraftTokens() const
{
return mDraftTokens;
}
/// @brief Get the logits for the draft tokens
/// @return Tensor of draft logits
[[nodiscard]] std::optional<TensorPtr> getDraftLogits() const
{
return mDraftLogits;
}
/// @brief Returns true if request has draft tokens
/// @return flag
[[nodiscard]] bool hasDraftTokens() const
{
return mDraftTokens && !mDraftTokens->empty();
}
/// @brief Get the maximum number of generated tokens among all rays in beam
/// @return The number of generated tokens (doesn't include the prompt tokens)
[[nodiscard]] SizeType32 getMaxNumGeneratedTokens() const
{
return getMaxBeamNumTokens() - mPromptLen;
}
/// @brief Returns true if request reaches max number of tokens in the next iteration.
[[nodiscard]] bool willCompleteNextIteration() const
{
return getMaxNumGeneratedTokens() + mNumTokensPerIteration >= mMaxNewTokens;
}
[[nodiscard]] LlmRequestType getLlmRequestType() const
{
return mLlmRequestType;
}
/// @brief Add new generated tokens to the vector of tokens and set mLastTokens
/// @param token The token to add
/// @param beam The beam to which to add the new token
/// @return The number of tokens after the new token is added
SizeType32 addNewToken(TokenIdType token, SizeType32 beam)
{
mLastTokens[beam] = token;
mTokens.at(beam).push_back(token);
// New token's extra id is 0
mUniqueTokens.at(beam).push_back({token, 0});
return getNumTokens(beam);
}
/// @brief Add new generated tokens to the vector of tokens and set mLastTokens
/// @param beamTokens A vector containing the tokens to add for each beam index
/// beamTokens is expected to be of size beamWidth
void addNewTokens(VecTokens const& beamTokens)
{
assert(static_cast<size_t>(mSamplingConfig.beamWidth) == beamTokens.size());
mLastTokens = beamTokens;
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
auto const outputId = beamTokens[beam];
mTokens.at(beam).push_back(outputId);
// New token's extra id is 0
mUniqueTokens.at(beam).push_back({outputId, 0});
}
}
/// @brief Set the number of pre-decoded tokens
/// @param num_tokens The number of pre-decoded tokens
/// @param beam The beam to which to set the number of pre-decoded tokens
void setNumPreDecodedTokens(SizeType32 num_tokens, SizeType32 beam)
{
mNumPreDecodedTokens[beam] = num_tokens;
}
/// @brief Erases all previous generated tokens, only leaving the prompt.
void clearGeneratedTokens()
{
TLLM_LOG_DEBUG("Clearing generated tokens for request %ld with promptlen %d", mRequestId, mPromptLen);
for (auto& beam : mTokens)
{
beam.resize(mPromptLen);
}
}
/// @brief Sets the generated tokens for all beams after gatherTree. Erases all previous generated tokens.
/// @param generatedBeamTokens The generated tokens for all beams (vector of vector of tokens)
void setGeneratedTokens(BeamTokens const& generatedBeamTokens)
{
TLLM_LOG_DEBUG("Setting generated tokens for request %ld", mRequestId);
assert(generatedBeamTokens.size() == static_cast<size_t>(mSamplingConfig.beamWidth));
for (size_t beamId = 0; beamId < generatedBeamTokens.size(); ++beamId)
{
auto& beamTokens = mTokens[beamId];
beamTokens.resize(mPromptLen);
beamTokens.insert(beamTokens.end(), generatedBeamTokens[beamId].begin(), generatedBeamTokens[beamId].end());
auto& beamUniqueTokens = mUniqueTokens[beamId];
beamUniqueTokens.resize(mPromptLen);
for (auto const token : generatedBeamTokens[beamId])
{
beamUniqueTokens.push_back({token, 0});
}
}
}
/// @brief Sets the number of return sequences.
/// @param numReturnSequences The number of return sequences.
void setNumReturnSequences(SizeType32 const& numReturnSequences)
{
TLLM_CHECK_WITH_INFO(!isChild(), "A child request cannot change numReturnSequences.");
TLLM_CHECK_WITH_INFO(
numReturnSequences > 0, "numReturnSequences should be a positive integer, got %d.", numReturnSequences);
TLLM_CHECK_WITH_INFO(mChildRequests.size() <= static_cast<size_t>(numReturnSequences),
"Cannot set numReturnSequences %d smaller than the number %ld of child requests that have already created.",
numReturnSequences, mChildRequests.size());
mSamplingConfig.numReturnSequences = numReturnSequences;
mSequenceFinalVec->resize(numReturnSequences);
}
[[nodiscard]] bool constexpr isChild() const noexcept
{
return mSequenceIndex > 0;
}
[[nodiscard]] RequestIdType getParentRequestId() const
{
return mParentRequestId;
}
/// @brief Return a vector of the last-generated tokens of shape [num_beams]
[[nodiscard]] VecTokens const& getLastTokens()
{
return mLastTokens;
}
/// @brief Return the last-generated token of from a particular beam
[[nodiscard]] TokenIdType const& getLastTokens(SizeType32 beam)
{
return mLastTokens[beam];
}
/// @brief Pause a request by moving the generated tokens to the prompt
/// @param maxInputLen The maximum prompt len.
void pause(SizeType32 maxInputLen)
{
// TODO: For beamWidth > 1, we would need to support swapping to avoid
// recomputing from the start
// As a temporary solution, we currently reset the tokens to the prompt
if (mSamplingConfig.beamWidth > 1)
{
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(mPromptLen);
auto& beamUniqueTokens = mUniqueTokens.at(beam);
beamUniqueTokens.resize(mPromptLen);
if (returnLogProbs())
{
mLogProbs.at(beam).clear();
}
}
}
else
{
SizeType32 newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(newPromptLen);
auto& beamUniqueTokens = mUniqueTokens.at(beam);
beamUniqueTokens.resize(newPromptLen);
if (returnLogProbs())
{
auto& logProb = mLogProbs.at(beam);
logProb.resize(newPromptLen - mPromptLen);
}
}
mMaxNewTokens -= (newPromptLen - mPromptLen);
mPromptLen = newPromptLen;
}
// for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
: LlmRequestState::kCONTEXT_INIT;
mContextCurrentPositionTarget = 0;
mContextCurrentPositionDraft = 0;
mPrepopulatedPromptLenTarget = 0;
mPrepopulatedPromptLenDraft = 0;
mContextChunkSize = mPromptLen;
mSeqSlot.reset();
}
/// @brief Get the maximum length of tokens returned to the client. Use to ensure we don't return to
/// client duplicated tokens.
/// @return The maximum length of the tokens sent to the client.
[[nodiscard]] SizeType32 getMaxSentTokenLen() const
{
return mMaxSentTokenLen;
}
/// @brief Sets the maximum length of tokens returned to the client. Use to ensure we don't return to
/// client duplicated tokens.
/// @param maxSentLength The new maximum length.
void setMaxSentTokenLen(SizeType32 maxSentLength)
{
mMaxSentTokenLen = maxSentLength;
}
[[nodiscard]] std::optional<TensorPtr> getPromptEmbeddingTable() const
{
return mPromptEmbeddingTable;
}
[[nodiscard]] std::optional<TensorPtr>& getPromptEmbeddingTableMutable()
{
return mPromptEmbeddingTable;
}
[[nodiscard]] std::optional<SizeType32> getPromptVocabSize() const
{
return mPromptVocabSize;
}
[[nodiscard]] std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> getMultimodalHashes() const
{
return mMultimodalHashes;
}
[[nodiscard]] std::optional<std::shared_ptr<std::vector<SizeType32>>> getMultimodalPositions() const
{
return mMultimodalPositions;
}
[[nodiscard]] std::optional<std::shared_ptr<std::vector<SizeType32>>> getMultimodalLengths() const
{
return mMultimodalLengths;
}
[[nodiscard]] std::optional<TensorPtr> getMultimodalEmbedding() const
{
return mMultimodalEmbedding;
}
[[nodiscard]] std::optional<TensorPtr> getMropeRotaryCosSin() const
{
return mMropeRotaryCosSin;
}
[[nodiscard]] std::optional<SizeType32> getMropePositionDeltas() const
{
return mMropePositionDeltas;
}
[[nodiscard]] std::optional<LoraTaskIdType> getLoraTaskId() const
{
return mLoraTaskId;
}
void setLoraTaskId(LoraTaskIdType taskId)
{
mLoraTaskId = taskId;
}
[[nodiscard]] std::optional<TensorPtr> getLoraWeights() const
{
return mLoraWeights;
}
void setLoraWeights(TensorPtr weights)
{
mLoraWeights = weights;
}
void setPromptVocabSize(SizeType32 size)
{
mPromptVocabSize = size;
}
void clearLoraWeights()
{
mLoraWeights = std::nullopt;
}
[[nodiscard]] std::optional<TensorPtr> getLoraConfig() const
{
return mLoraConfig;
}
void setLoraConfig(TensorPtr config)
{
mLoraConfig = config;
}
void clearLoraConfig()
{
mLoraConfig = std::nullopt;
}
[[nodiscard]] std::optional<executor::LookaheadDecodingConfig> getLookaheadConfig() const
{
return mLookaheadConfig;
}
void setLookaheadConfig(executor::LookaheadDecodingConfig config)
{
mLookaheadConfig = config;
}
void clearLookaheadConfig()
{
mLookaheadConfig = std::nullopt;
}
[[nodiscard]] std::optional<executor::KvCacheRetentionConfig> getKvCacheRetentionConfig() const
{
return mKvCacheRetentionConfig;
}
void setKvCacheRetentionConfig(executor::KvCacheRetentionConfig config)
{
mKvCacheRetentionConfig = config;
}
[[nodiscard]] std::optional<executor::EagleConfig> getEagleConfig() const
{
return mEagleConfig;
}
void setEagleConfig(executor::EagleConfig config)
{
mEagleConfig = config;
}
[[nodiscard]] std::optional<executor::GuidedDecodingParams> getGuidedDecodingParams() const
{
return mGuidedDecodingParams;
}
void setGuidedDecodingParams(executor::GuidedDecodingParams guidedDecodingParams)
{
mGuidedDecodingParams = guidedDecodingParams;
}
[[nodiscard]] std::optional<TensorPtr> getEmbeddingBias() const
{
return mEmbeddingBias;
}
[[nodiscard]] std::optional<TensorPtr> getBadWordsList() const
{
return mBadWordsList;
}
[[nodiscard]] std::optional<TensorPtr> getStopWordsList() const
{
return mStopWordsList;
}
[[nodiscard]] bool returnLogProbs() const
{
return mSamplingConfig.outputLogProbs.has_value() ? mSamplingConfig.outputLogProbs->at(0) : false;
}
void setReturnLogProbs(bool returnLogProbs)
{
mSamplingConfig.outputLogProbs = {{returnLogProbs}};
mSamplingConfig.cumLogProbs = {{returnLogProbs}};
}
[[nodiscard]] std::vector<VecLogProbs> const& getLogProbs() const
{
return mLogProbs;
}
[[nodiscard]] VecLogProbs const& getLogProbs(SizeType32 beam) const
{
return mLogProbs.at(beam);
}
void setLogProbs(VecLogProbs const& logProbs, SizeType32 beam)
{
mLogProbs.at(beam).resize(mPromptLen - mOrigPromptLen);
mLogProbs.at(beam).insert(mLogProbs.at(beam).end(), logProbs.begin(), logProbs.end());
}
[[nodiscard]] VecLogProbs const& getCumLogProbs() const
{
return mCumLogProbs;
}
void setCumLogProb(float cumLogProb, SizeType32 beam)
{
mCumLogProbs.at(beam) = cumLogProb;
}
[[nodiscard]] SizeType32 getOrigPromptLen() const
{
return mOrigPromptLen;
}
[[nodiscard]] SizeType32 getPromptLen() const
{
return mPromptLen;
}
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
{
return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
}
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
{
// Add debug log for prepopulatedPromptLen
TLLM_LOG_DEBUG("Setting pre-populated prompt length for request %lu to %i (promptLen=%i).", mRequestId,
prepopulatedPromptLen, getPromptLen());
auto const promptLen = getPromptLen();
// This check is make sure prepopulated prompt length (tokens already cached in KV cache) is less than prompt
// length (total tokens in the prompt)
TLLM_CHECK_WITH_INFO(prepopulatedPromptLen < promptLen,
"Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
promptLen, mRequestId);
auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
prePromptLen = prepopulatedPromptLen;
if (prepopulatedPromptLen > 0)
{
// Currently, the runtime process is to apply for cache first and then determine prepopulation.
// Use the prepopulated length to advance the context position and decrease chunk size if necessary.
auto chunkSize = getContextChunkSize();
if (prepopulatedPromptLen + chunkSize < promptLen)
{
// make sure to end at block boundary after current chunk
auto const flooredEndPosition
= (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
chunkSize = flooredEndPosition - prepopulatedPromptLen;
TLLM_CHECK(chunkSize <= getContextChunkSize());
}
contextCurrentPosition = prepopulatedPromptLen;
setContextChunkSize(chunkSize);
if (!isLastContextChunk())
{
TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0,
"To prevent cache fragmentation, the context position after current chunk should be divisible "
"by the number of tokens per block, except for the last chunk.");
}
}
}
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
{
mDraftTokens = draftTokens;
}
void setDraftLogits(std::optional<TensorPtr> const& draftLogits)
{
mDraftLogits = draftLogits;
}
[[nodiscard]] SizeType32 getNumDraftTokens() const noexcept
{
return hasDraftTokens() ? static_cast<SizeType32>(mDraftTokens->size()) : 0;
}
void discardDraftTokens(SizeType32 numTokensToDiscard)
{
TLLM_CHECK_WITH_INFO(
numTokensToDiscard > 0, "Can only discard a positive amount of draft tokens, got %d", numTokensToDiscard);
TLLM_CHECK_WITH_INFO(numTokensToDiscard <= getNumDraftTokens(),
"Can't discard more draft tokens (%d) than exists (%d).", numTokensToDiscard, getNumDraftTokens());
mDraftTokens->resize(getNumDraftTokens() - numTokensToDiscard);
if (mDraftLogits)
{
auto shape = mDraftLogits.value()->getShape();
shape.d[0] = getNumDraftTokens();
mDraftLogits.value()->reshape(shape);
}
}
void updateNumTokensPerIteration(SizeType32 numTokensPerIteration, runtime::ModelConfig const& modelConfig)
{
mNumTokensPerIteration = std::max(1, numTokensPerIteration);
if (modelConfig.hasSpeculativeDecodingModule() && getReturnPerfMetrics() && hasDraftTokens())
{
auto& specDecMetrics = mPerfMetrics.speculativeDecoding;
specDecMetrics.totalAcceptedDraftTokens += mNumTokensPerIteration - 1;
auto const maxAcceptedDraftTokens = modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen();
specDecMetrics.totalDraftTokens += std::min(getNumDraftTokens(), maxAcceptedDraftTokens);
}
}
[[nodiscard]] SizeType32 getNumTokensPerIteration() const
{
return mNumTokensPerIteration;
}
void setReturnEncoderOutput(bool const returnEncoderOutput)
{
mReturnEncoderOutput = returnEncoderOutput;
}
[[nodiscard]] bool getReturnEncoderOutput() const
{
return mReturnEncoderOutput;
}
[[nodiscard]] TensorPtr const& getEncoderOutputHost() const
{
return mEncoderOutputHost;
}
[[nodiscard]] TensorPtr getEncoderInputFeatures() const
{
return mEncoderInputFeatures.value_or(nullptr);
}
void setEncoderOutputHost(TensorPtr encoderOutputHost)
{
mEncoderOutputHost = std::move(encoderOutputHost);
}
void setEncoderOutput(TensorPtr encoderOutput)
{
mEncoderOutput = std::move(encoderOutput);
}
void allocEncoderOutputHost(SizeType32 encoderHiddenSize, nvinfer1::DataType dataType)
{
mEncoderOutputHost = runtime::BufferManager::pinned(
runtime::ITensor::makeShape({getEncoderOutputLen(), encoderHiddenSize}), dataType);
}
[[nodiscard]] TensorPtr const& getEncoderOutput() const noexcept
{
return mEncoderOutput;
}
[[nodiscard]] TensorPtr const& getEncoderHiddenStates() const noexcept
{
return mEncoderHiddenStates;
}
void allocEncoderOutput(runtime::BufferManager const& manager, nvinfer1::DataType dataType)
{
// unique_ptr --> shared_ptr ownership move
mEncoderOutput = std::move(manager.emptyTensor(runtime::MemoryType::kGPU, dataType));
}
void allocEncoderHiddenStates(runtime::BufferManager const& manager, nvinfer1::DataType dataType)
{
// unique_ptr --> shared_ptr ownership move
mEncoderHiddenStates = std::move(manager.emptyTensor(runtime::MemoryType::kGPU, dataType));
}
void freeEncoderOutputBuffers()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG(
"Encoder output buffers use count: %u, %u", mEncoderOutput.use_count(), mEncoderHiddenStates.use_count());
// TODO: better ways to free shared_ptr buffers
mEncoderOutput.reset();
mEncoderHiddenStates.reset();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
[[nodiscard]] TensorPtr getCrossAttentionMask() const
{
return mCrossAttentionMask.value_or(nullptr);
}
[[nodiscard]] TensorPtr getSkipCrossAttnBlocks() const
{
return mSkipCrossAttnBlocks.value_or(nullptr);
}
[[nodiscard]] bool constexpr getReturnPerfMetrics() const noexcept
{
return mReturnPerfMetrics;
}
void constexpr setReturnPerfMetrics(bool returnPerfMetrics) noexcept
{
mReturnPerfMetrics = returnPerfMetrics;
}
[[nodiscard]] executor::RequestPerfMetrics const& getPerfMetrics() const noexcept
{
return mPerfMetrics;
}
void setFirstScheduledTime()
{
if (mPerfMetrics.timingMetrics.firstScheduledTime == executor::RequestPerfMetrics::TimePoint{})
{
mPerfMetrics.timingMetrics.firstScheduledTime = getSteadyClockNow();
}
}
[[nodiscard]] bool constexpr isStreaming() const noexcept
{
return mIsStreaming;
}
void constexpr setStreaming(bool isStreaming) noexcept
{
mIsStreaming = isStreaming;
}
void setPriority(executor::PriorityType priority) noexcept
{
mPriority = priority;
}
void setReturnAllGeneratedTokens(bool const returnAllGeneratedTokens)
{
TLLM_CHECK_WITH_INFO(!mIsStreaming || mSamplingConfig.beamWidth == 1 || returnAllGeneratedTokens,
"returnAllGeneratedTokens must be true if streaming AND beam search are used.");
mReturnAllGeneratedTokens = returnAllGeneratedTokens;
}
[[nodiscard]] bool getReturnAllGeneratedTokens()
{
return mReturnAllGeneratedTokens;
}
void setAllottedTimeMs(MillisecondsType allottedTimeMs)
{
mAllottedTimeMs = allottedTimeMs;
}
void setReturnContextLogits(bool const returnContextLogits)
{
mReturnContextLogits = returnContextLogits;
}
[[nodiscard]] bool getReturnContextLogits() const
{
return mReturnContextLogits;
}
void setReturnGenerationLogits(bool const returnGenerationLogits)
{
TLLM_CHECK_WITH_INFO(!(mIsStreaming && mSamplingConfig.beamWidth > 1 && returnGenerationLogits),
"returnGenerationLogits must be false if streaming AND beam search are used.");
mReturnGenerationLogits = returnGenerationLogits;
}
[[nodiscard]] bool getReturnGenerationLogits() const
{
return mReturnGenerationLogits;
}
[[nodiscard]] TensorPtr const& getContextLogitsHost() const
{
return mContextLogitsHost;
}
/// @param contextLogitsHost Expected shape [promtLen, vocabSizePadded]
void setContextLogitsHost(TensorPtr contextLogitsHost)
{
mContextLogitsHost = std::move(contextLogitsHost);
}
void allocContextLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mContextLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mPromptLen, vocabSizePadded}), logitsDataType);
}
[[nodiscard]] TensorPtr const& getGenerationLogitsHost() const
{
return mGenerationLogitsHost;
}
/// @param generationLogitsHost Expected shape
/// * [beamWidth, maxNewTokens, vocabSizePadded] for non-speculative decoding
/// * [1, numDraftTokens + 1, vocabSizePadded] for speculative decoding
void setGenerationLogitsHost(TensorPtr generationLogitsHost)
{
mGenerationLogitsHost = std::move(generationLogitsHost);
}
void allocGenerationLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
if (mIsStreaming)
{
// If streaming mode, the complete generation logits shape will be [1, beamWidth, vocabSizePadded],
// or [allGeneratedTokens, beamWidth, vocabSizePadded] if mReturnAllGeneratedTokens is True.
// This could reduce unnecessary format conversions and allows the data to be returned directly.
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mMaxNewTokens, mSamplingConfig.beamWidth, vocabSizePadded}),
logitsDataType);
}
else
{
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({mSamplingConfig.beamWidth, mMaxNewTokens, vocabSizePadded}),
logitsDataType);
}
}
void allocTargetModelAcceptedTokenLogitsHost(SizeType32 vocabSizePadded, nvinfer1::DataType logitsDataType)
{
mGenerationLogitsHost = runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({1, getNumDraftTokens() + 1, vocabSizePadded}), logitsDataType);
}
[[nodiscard]] std::vector<TensorPtr> const& getGenerationLogitsFragments() const
{
return mGenerationLogitsFragments;
}
void addGenerationLogitsFragment(TensorPtr& genLogits)
{
mGenerationLogitsFragments.push_back(genLogits);
}
[[nodiscard]] SizeType32 getGenerationLogitsFragmentsSize() const noexcept
{
return static_cast<SizeType32>(mGenerationLogitsFragments.size());
}
void clearGenerationLogitsFragments() noexcept
{
mGenerationLogitsFragments.clear();
}
[[nodiscard]] bool hasAdditionalOutputs() const noexcept
{
return !mAdditionalContextOutputTensors.empty() || !mAdditionalGenerationOutputTensors.empty();
}
[[nodiscard]] TensorMap const& getAdditionalContextOutputs() const
{
return mAdditionalContextOutputTensors;
}
[[nodiscard]] TensorMap const& getAdditionalGenerationOutputs() const
{
return mAdditionalGenerationOutputTensors;
}
template <typename TypeFunc, typename ShapeFunc>
void allocAdditionalOutputs(TypeFunc getTensorDataType, ShapeFunc getTensorShape)
{
for (auto& outputTensor : mAdditionalContextOutputTensors)
{
auto const& outputTensorName = outputTensor.first;
auto const dataType = getTensorDataType(outputTensorName);
auto shape = getTensorShape(outputTensorName);
TLLM_CHECK_WITH_INFO(shape.d[0] == -1, "First dimension of additional output tensor '%s' must be dynamic",
outputTensorName.c_str());
shape.d[0] = mPromptLen;
auto tensor = runtime::BufferManager::pinnedPool(shape, dataType);
outputTensor.second = std::move(tensor);
}
for (auto& outputTensor : mAdditionalGenerationOutputTensors)
{
auto const& outputTensorName = outputTensor.first;
auto const dataType = getTensorDataType(outputTensorName);
auto shape = getTensorShape(outputTensorName);
TLLM_CHECK_WITH_INFO(shape.d[0] == -1, "First dimension of additional output tensor '%s' must be dynamic",
outputTensorName.c_str());
shape.d[0] = mMaxNewTokens;
shape = runtime::ITensor::unsqueeze(shape, 0);
shape.d[0] = mSamplingConfig.beamWidth;
auto tensor = runtime::BufferManager::pinnedPool(shape, dataType);
outputTensor.second = std::move(tensor);
}
}
void setState(LlmRequestState state)
{
TLLM_LOG_DEBUG("Set request %lu from state %d to %d", mRequestId, mState, state);
mState = state;
}
[[nodiscard]] LlmRequestState getState() const noexcept
{
return mState;
}
[[nodiscard]] bool hasReachedState(LlmRequestState state) const noexcept
{
return mState >= state;
}
[[nodiscard]] bool isEncoderInitState() const noexcept
{
return mState == LlmRequestState::kENCODER_INIT;
}
[[nodiscard]] bool isContextInitState() const noexcept
{
return mState == LlmRequestState::kCONTEXT_INIT || mState == LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS;
}
[[nodiscard]] bool isContextFinished() const noexcept
{
return isGenerationInProgressState() || mState == LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS;
}
[[nodiscard]] bool isGenerationInProgressState() const noexcept
{
return mState == LlmRequestState::kGENERATION_IN_PROGRESS || mState == LlmRequestState::kGENERATION_TO_COMPLETE
|| mState == LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE;
}
[[nodiscard]] bool isGenerationToCompleteState() const noexcept
{
return mState == LlmRequestState::kGENERATION_TO_COMPLETE;
}
[[nodiscard]] bool isGenerationCompleteState() const noexcept
{
return mState == LlmRequestState::kGENERATION_COMPLETE;
}
[[nodiscard]] bool isDisaggGenerationInitState() const noexcept
{
return mState == LlmRequestState::kDISAGG_GENERATION_INIT;
}
[[nodiscard]] bool isDisaggGenerationTransmissionComplete() const noexcept
{
return mState == LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE;
}
[[nodiscard]] bool isDisaggGenerationTransmissionInProgress() const noexcept
{
return mState == LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS;
}
[[nodiscard]] bool isDisaggContextTransmissionState() const noexcept
{
return mState == LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS
|| mState == LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS;
}
[[nodiscard]] bool isDisaggContextCompleteState() const noexcept
{
return mState == LlmRequestState::kDISAGG_CONTEXT_COMPLETE;
}
[[nodiscard]] executor::RequestStage getRequestStage() const
{
switch (mState)
{
case batch_manager::LlmRequestState::kENCODER_INIT: return executor::RequestStage::kENCODER_IN_PROGRESS; break;
case batch_manager::LlmRequestState::kCONTEXT_INIT: return executor::RequestStage::kCONTEXT_IN_PROGRESS; break;
case batch_manager::LlmRequestState::kGENERATION_IN_PROGRESS:
case batch_manager::LlmRequestState::kGENERATION_TO_COMPLETE:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_INIT:
case batch_manager::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS:
return executor::RequestStage::kGENERATION_IN_PROGRESS;
break;
default: TLLM_LOG_ERROR("Unexpected request state."); return executor::RequestStage::kGENERATION_COMPLETE;
}
}
[[nodiscard]] bool isContextOnlyRequest() const noexcept
{
return mLlmRequestType == LlmRequestType::LLMREQUEST_TYPE_CONTEXT_ONLY;
}
[[nodiscard]] bool isGenerationOnlyRequest() const noexcept
{
return mLlmRequestType == LlmRequestType::LLMREQUEST_TYPE_GENERATION_ONLY;
}
void setContextCurrentPosition(SizeType32 contextCurrentPosition)
{
mContextCurrentPositionDraft = contextCurrentPosition;
mContextCurrentPositionTarget = contextCurrentPosition;
}
/// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
/// or end of the context is returned.
[[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
{
return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
}
/// Return the length of the context that has not yet been processed.
[[nodiscard]] SizeType32 getContextRemainingLength() const noexcept
{
return mPromptLen - getContextCurrentPosition();
}
[[nodiscard]] SizeType32 getContextChunkSize() const
{
TLLM_CHECK_WITH_INFO(
isContextInitState() || isDisaggGenerationInitState() || isDisaggGenerationTransmissionComplete(),
"getContextChunkSize is only possible during the context phase or generation init phase.");
return mContextChunkSize;
}
/// To set the context chunk size, throw an exception when the chunk size is negative. If the chunk
/// size is greater than the remaining length of the context, the size will be reduced to fit the
/// remaining length.
void setContextChunkSize(SizeType32 size)
{
TLLM_CHECK_WITH_INFO(
isContextInitState() || isDisaggGenerationInitState() || isDisaggGenerationTransmissionComplete(),
"setContextChunkSize is only possible during the context phase or generation init phase.");
TLLM_CHECK_WITH_INFO(size >= 0, "The chunk size of context (%d) can't be negative.", size);
mContextChunkSize = std::min(size, getContextRemainingLength());
}
/// Determines whether the current position is only one chunk away from the end of the context.
[[nodiscard]] bool isLastContextChunk() const
{
return isDisaggGenerationInitState() || isDisaggGenerationTransmissionComplete()
|| getContextCurrentPosition() + getContextChunkSize() == mPromptLen;
}
/// Returns whether the position is at the beginning of the context.
[[nodiscard]] bool isFirstContextChunk() const noexcept
{
// The number of cached token is encountered in mContextCurrentPosition,
// so the start position of the context is mPrepopulatedPromptLen.
return getContextCurrentPosition() == getPrepopulatedPromptLen();
}
/// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
void moveToNextContextChunk()
{
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
mContextCurrentPositionDraft += getContextChunkSize();
mContextCurrentPositionTarget += getContextChunkSize();
setContextChunkSize(0);
}
[[nodiscard]] executor::PriorityType priority() const noexcept
{
return mPriority;
}
/// Get the counter of decoding iterations.
SizeType32 getDecodingIter()
{
return mDecodingIter;
}
/// Increment the counter of decoding iterations.
void advanceDecodingIter()
{
mDecodingIter++;
}
/// @brief Return the average number of decoded tokens per iteration. For standard model it is 1.
/// For speculative decoding model >= 1 -- number of draft tokens accepted per step + 1.
[[nodiscard]] float getAvgDecodedTokensPerIter() const noexcept
{
if (mDecodingIter == 0)
{
return 0.F;
}
return static_cast<float>(getMaxNumGeneratedTokens()) / mDecodingIter;
}
/// @brief Get the beam width of the current decoding step.
/// @details Return `mSamplingConfig.beamWidth` in decoding modes beside Variable-Beam-Width-Search (VBWS).
/// Or returns a scalar value from `mSamplingConfig.beamWidthArray` indexing by `mDecodingIter` in VBWS.
///
/// Calling in context phase, it returns the beam width of the first generation step, which is used for copying
/// logits (function `copyGenerationLogits` as example).
///
/// Calling in generation phase, it returns the beam width of the input tokens in the current generation step, which
/// is used for computing I/O tensor shapes for TRT engine (function `RuntimeBuffers::setBufferSizes` as example).
///
/// For example, we have a request with beamWidthArray = [2,3,4], the generation process can be:
///
/// input_ids[1,inputLength] --->
/// ---> [Forward, step == 0] ---> logits[1, 1, vocabSize] ---> [BeamSearchDecoder] ---> tokens[1, 2]
/// Context Phase, getBeamWidthByIter() returns 2 for copying logits
/// Decoder uses beamWidthIn=2, beamWidthOut=2 to get top 2 tokens
/// ---> [Forward, step == 1] ---> logits[1, 2, vocabSize] ---> [BeamSearchDecoder] ---> tokens[1, 3]
/// Generation phase, getBeamWidthByIter() returns 2 for computing tensor shapes
/// Decoder uses beamWidthIn=2, beamWidthOut=3 to get top 3 tokens
/// ---> [Forward, step == 2] ---> logits[1, 3, vocabSize] ---> [BeamSearchDecoder] ---> tokens[1, 4]
/// Generation phase, getBeamWidthByIter() returns 3 for computing tensor shapes
/// Decoder uses beamWidthIn=3, beamWidthOut=4 to get top 4 tokens
/// ---> [Forward, step == 3] ---> logits[1, 4, vocabSize] ---> [BeamSearchDecoder] ---> tokens[1, 4]
/// Generation phase, getBeamWidthByIter() returns 4 for computing tensor shapes
/// Decoder uses beamWidthIn=4, beamWidthOut=4 to get top 4 tokens
/// i.e. the same as normal Beam Search of `beamWidth==4`
/// @param: forNextIteration: get beam width for next step rather than current beam width.
[[nodiscard]] SizeType32 getBeamWidthByIter(bool forNextIteration = false);
[[nodiscard]] bool isFinished() const noexcept
{
return isGenerationCompleteState() || mState == LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS;
}
/// Returns true if finished_reason is length for all beams
[[nodiscard]] bool isFinishedDueToLength() const noexcept
{
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
[](auto reason) { return reason == executor::FinishReason::kLENGTH; });
}
[[nodiscard]] bool isFinishedDueToCancellation() const noexcept
{
return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
[](auto reason) { return reason == executor::FinishReason::kCANCELLED; });
}
[[nodiscard]] bool isTimedOut() const
{
if (!mAllottedTimeMs.has_value())
{
return false;
}
auto const currentTime = std::chrono::steady_clock::now();
auto const elapsed = (std::chrono::duration_cast<std::chrono::milliseconds>(currentTime - mStartTime));
TLLM_LOG_DEBUG("Checked timeOut for request %ld with allotted Time %ld after time %ld and got %d", mRequestId,
mAllottedTimeMs->count(), elapsed.count(), (elapsed >= mAllottedTimeMs));
return elapsed >= *mAllottedTimeMs;
}
void setFinishedReason(executor::FinishReason reason, SizeType32 beam)
{
mFinishReasons.at(beam) = reason;
}
void setDecodingIter(SizeType32 iter)
{
mDecodingIter = iter;
}
void setKvCacheTransferStart(TimePoint time) const
{
mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time);
}
void setKvCacheTransferEnd(TimePoint time) const
{
mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time);
}
TimePoint getKvCacheTransferStart() const
{
return mPerfMetrics.timingMetrics.kvCacheTransferStart;
}
TimePoint getKvCacheTransferEnd() const
{
return mPerfMetrics.timingMetrics.kvCacheTransferEnd;
}
[[nodiscard]] double getKvCacheTransferTimeMS() const
{
// get max with 0 in case this function is called while end time is not recorded
return std::max(0.0,
std::chrono::duration<double, std::milli>(
mPerfMetrics.timingMetrics.kvCacheTransferEnd - mPerfMetrics.timingMetrics.kvCacheTransferStart)
.count());
}
void updateKvCacheSize(size_t targetBufferSize) const
{
mPerfMetrics.timingMetrics.kvCacheSize += targetBufferSize;
}
void setKvCacheSize(size_t targetBufferSize) const
{
mPerfMetrics.timingMetrics.kvCacheSize = targetBufferSize;
}
[[nodiscard]] size_t getKvCacheSize() const
{
return mPerfMetrics.timingMetrics.kvCacheSize;
}
void updateAllocTotalBlocksPerRequest(SizeType32 allocTotalBlocksPerRequest)
{
mPerfMetrics.kvCacheMetrics.numTotalAllocatedBlocks += allocTotalBlocksPerRequest;
}
[[nodiscard]] SizeType32 getAllocTotalBlocksPerRequest() const
{
return mPerfMetrics.kvCacheMetrics.numTotalAllocatedBlocks;
}
void updateAllocNewBlocksPerRequest(SizeType32 allocNewBlocksPerRequest)
{
mPerfMetrics.kvCacheMetrics.numNewAllocatedBlocks += allocNewBlocksPerRequest;
}
[[nodiscard]] SizeType32 getAllocNewBlocksPerRequest() const
{
return mPerfMetrics.kvCacheMetrics.numNewAllocatedBlocks;
}
void updateReusedBlocksPerRequest(SizeType32 reusedBlocksPerRequest)
{
mPerfMetrics.kvCacheMetrics.numReusedBlocks += reusedBlocksPerRequest;
}
[[nodiscard]] SizeType32 getReusedBlocksPerRequest() const
{
return mPerfMetrics.kvCacheMetrics.numReusedBlocks;
}
[[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const
{
return mLanguageAdapterUid;
}
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
{
return mCacheSaltID;
}
std::vector<SizeType32> getLanguageAdapterRouting(
SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
{
auto const reqLanguageAdapterUid = getLanguageAdapterUid().value();
TLLM_CHECK_WITH_INFO(reqLanguageAdapterUid < reqNumLanguages, "Language adapter uid is out of range.\n");
// Copy the same routing info for all the tokens in this request
return std::vector<SizeType32>(inputLength, reqLanguageAdapterUid);
}
/// @brief mark all beams as finished by the given reason. Marks only unfinished beams.
void finishByReason(executor::FinishReason finishReason)
{
if (finishReason == executor::FinishReason::kTIMED_OUT)
{
TLLM_LOG_DEBUG("Request %ld finished by timeout after %f sec", mRequestId,
std::chrono::duration<float>(getSteadyClockNow() - mStartTime).count());
}
if (finishReason == executor::FinishReason::kCANCELLED)
{
TLLM_LOG_DEBUG("Request %ld finished by cancel", mRequestId);
}
for (int beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
{
if (mFinishReasons.at(beam) == executor::FinishReason::kNOT_FINISHED)
{
setFinishedReason(finishReason, beam);
}
}
mState = LlmRequestState::kGENERATION_COMPLETE;
}
void updateMissedBlocksPerRequest(SizeType32 missedBlocksPerRequest)
{
mPerfMetrics.kvCacheMetrics.numMissedBlocks += missedBlocksPerRequest;
}
[[nodiscard]] SizeType32 getMissedBlocksPerRequest() const
{
return mPerfMetrics.kvCacheMetrics.numMissedBlocks;
}
[[nodiscard]] float getKVCacheHitRatePerRequest() const
{
return mPerfMetrics.kvCacheMetrics.numReusedBlocks == 0
? 0
: static_cast<float>(mPerfMetrics.kvCacheMetrics.numReusedBlocks)
/ (static_cast<float>(
mPerfMetrics.kvCacheMetrics.numReusedBlocks + mPerfMetrics.kvCacheMetrics.numMissedBlocks));
}
void updatePerfMetrics(executor::IterationType iter)
{
auto const currentTokenTime = getSteadyClockNow();
if (!mPerfMetrics.firstIter)
{
mPerfMetrics.firstIter = iter;
mPerfMetrics.timingMetrics.firstTokenTime = currentTokenTime;
}
mPerfMetrics.iter = iter;
if (isFinished())
{
mPerfMetrics.lastIter = iter;
mPerfMetrics.timingMetrics.lastTokenTime = currentTokenTime;
}
}
void setIsDummyRequest(bool isDummyRequest)
{
mIsDummyRequest = isDummyRequest;
}
[[nodiscard]] bool isDummyRequest() const
{
return mIsDummyRequest;
}
void setUseDraftModel(bool useDraftModel)
{
mUseDraftModel = useDraftModel;
}
[[nodiscard]] bool useDraftModel() const
{
return mUseDraftModel;
}
// If sGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock
// time point
[[nodiscard]] static TimePoint getSteadyClockNow()
{
return maybeToGlobalSteadyClock(std::chrono::steady_clock::now());
}
RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
runtime::SamplingConfig mSamplingConfig;
std::optional<TokenIdType> mEndId{std::nullopt};
std::optional<TokenIdType> mPadId{std::nullopt};
std::optional<SizeType32> mSeqSlot{std::nullopt};
std::optional<LogitsPostProcessor> mLogitsPostProcessor{std::nullopt};
bool mApplyLogitsPostProcessorBatched{false};
std::optional<RequestIdType> mClientId{std::nullopt};
// Position of mask token in GLM model inputs
SizeType32 mMaskPosition{0};
LlmRequestState mState{LlmRequestState::kCONTEXT_INIT};
// current position of the prompt tuning table (only used in chunked prefill mode)
SizeType32 mPtableCurrentPosition{0};
// The offset between local steady clock and global steady clock (at rank 0)
inline static std::optional<Duration> sGlobalSteadyClockOffset{std::nullopt};
protected:
bool mIsStreaming;
// List of tokens generated at the current step, used as the input tokens to the next step, [beamSize]
// `mLastTokens[beam]` is not equal to `mTokens.back()[beam]` in "Streaming + Beam Search" mode
// since `mTokens` will be overwritten by the gathered tokens.
VecTokens mLastTokens;
// List of tokens including input prompt and generated part, [beamSize, mPromptLen + getMaxNumGeneratedTokens()]
BeamTokens mTokens;
// Length of input prompt tokens, never changes during generation process.
SizeType32 mOrigPromptLen;
// List of numbers of pre-deocded tokens on the last PP rank when using pipeline parallelism.
// It is introduced as a WAR to solve the hanging problem caused by overestimating the used KV cache on the last PP
// rank (because new tokens are decoded earlier). By excluding the numbers of pre-decoded tokens, the used KV cache
// can be estimated correctly.
std::vector<SizeType32> mNumPreDecodedTokens;
// Number of tokens already in KV cache before context phase.
// A value > 0 indicates cached KV cache blocks were reused.
// Up to inputLen - 1 tokens can be reused.
SizeType32 mPrepopulatedPromptLenTarget{0};
SizeType32 mPrepopulatedPromptLenDraft{0};
SizeType32 mMaxSentTokenLen;
std::optional<TensorPtr> mEmbeddingBias{std::nullopt};
std::optional<TensorPtr> mBadWordsList{std::nullopt};
std::optional<TensorPtr> mStopWordsList{std::nullopt};
std::optional<std::shared_ptr<std::vector<SizeType32>>> mPositionIds{std::nullopt};
std::optional<TensorPtr> mPromptEmbeddingTable{std::nullopt};
std::optional<SizeType32> mPromptVocabSize{std::nullopt};
std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>> mMultimodalHashes{std::nullopt};
std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalPositions{std::nullopt};
std::optional<std::shared_ptr<std::vector<SizeType32>>> mMultimodalLengths{std::nullopt};
std::optional<TensorPtr> mMultimodalEmbedding{std::nullopt};
std::optional<TensorPtr> mMropeRotaryCosSin{std::nullopt};
std::optional<SizeType32> mMropePositionDeltas{std::nullopt};
std::optional<LoraTaskIdType> mLoraTaskId{std::nullopt};
std::optional<TensorPtr> mLoraWeights{std::nullopt};
std::optional<TensorPtr> mLoraConfig{std::nullopt};
std::optional<executor::LookaheadDecodingConfig> mLookaheadConfig{std::nullopt};
std::optional<executor::KvCacheRetentionConfig> mKvCacheRetentionConfig{std::nullopt};
// Paged-KV-Cache must be enabled while enabling Chunked-Context.
// The size of the context chunk must be multiple of the KV-Cache block size except the last one.
// Value `0` means Chunked-Context is disabled.
SizeType32 mContextChunkSize{0};
SizeType32 mContextCurrentPositionTarget{0};
SizeType32 mContextCurrentPositionDraft{0};
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
std::shared_ptr<VecTokens> mDraftTokens{nullptr};
std::optional<TensorPtr> mDraftLogits{std::nullopt};
SizeType32 mNumTokensPerIteration{1};
// whether to return the full beams on each iteration. True when doing streaming + beamsearch
bool mReturnAllGeneratedTokens;
// Save logits
bool mReturnContextLogits;
bool mReturnGenerationLogits;
bool mReturnLogProbs;
TensorPtr mContextLogitsHost; // [mPromptLen, vocabSizePadded]
TensorPtr mGenerationLogitsHost; // [beamSize, mMaxNewTokens, vocabSizePadded]
std::vector<TensorPtr> mGenerationLogitsFragments;
bool mExcludeInputFromOutput;
// Encoder-only and Encoder-Decoder models
// Encoder input tokens
std::optional<std::shared_ptr<VecTokens>> mEncoderTokens{std::nullopt};
bool mReturnEncoderOutput;
// Encoder output, used to compute cross attention KV-Cache.
TensorPtr mEncoderOutput; // [numTokens, hidden_size]
TensorPtr mEncoderHiddenStates; // [numTokens, hiddenSize] for for Pipeline-Parallelism
TensorPtr mEncoderOutputHost; // [mEncoderOutputLength, encoderHiddenSize]
SizeType32 mDecodingIter{0};
executor::PriorityType mPriority;
std::vector<executor::FinishReason> mFinishReasons;
// Input features of encoder for multimodal models.
std::optional<TensorPtr> mEncoderInputFeatures{std::nullopt};
// Setting buffer sizes correctly for models like Whisper,
// which encoder output shape cannot be inferred from encoder input shape due to downsampling.
std::optional<SizeType32> mEncoderOutputLength{std::nullopt};
// Input cross attention mask.
std::optional<TensorPtr> mCrossAttentionMask{std::nullopt};
LlmRequestType mLlmRequestType;
std::optional<executor::ContextPhaseParams> mContextPhaseParams{std::nullopt};
std::shared_ptr<ContextProgress> mContextProgress{nullptr};
std::optional<std::shared_ptr<VecTokenExtraIds>> mInputTokenExtraIds{std::nullopt};
BeamUniqueTokens mUniqueTokens;
// TODO: add real extra id for encoder tokens.
std::optional<std::shared_ptr<VecUniqueTokens>> mEncoderUniqueTokens{std::nullopt};
SizeType32 mNumReturnSequences{1};
// Config for Eagle speculative decoding.
std::optional<executor::EagleConfig> mEagleConfig{std::nullopt};
SizeType32 mSequenceIndex{0};
std::vector<RequestPtr> mChildRequests;
RequestIdType mParentRequestId;
// Indicators whether each sibling completes generation.
std::shared_ptr<std::vector<bool>> mSequenceFinalVec;
std::optional<TensorPtr> mSkipCrossAttnBlocks{std::nullopt};
// Performance metrics. Should be updatable even from a const LlmRequest reference.
bool mReturnPerfMetrics{false};
mutable executor::RequestPerfMetrics mPerfMetrics;
// Guided decoding params.
std::optional<executor::GuidedDecodingParams> mGuidedDecodingParams{std::nullopt};
std::optional<SizeType32> mLanguageAdapterUid{std::nullopt};
// Timepoint at which the request started. Used for tracking the timeout
std::chrono::steady_clock::time_point mStartTime;
// Time in milliseconds after which the model is finished with a `timeout` finishReason.
std::optional<MillisecondsType> mAllottedTimeMs{std::nullopt};
// Tensors containing the additional context output.
TensorMap mAdditionalContextOutputTensors;
// Tensors containing the additional generation output.
TensorMap mAdditionalGenerationOutputTensors;
bool mIsDummyRequest{false};
bool mUseDraftModel{false};
// Cache salt id for each request.
std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
private:
void initialize(
VecTokens const& inputTokens, bool outputLogProbs, std::optional<TimePoint> arrivalTime = std::nullopt)
{
if (mLlmRequestType == LlmRequestType::LLMREQUEST_TYPE_GENERATION_ONLY)
{
mState = LlmRequestState::kDISAGG_GENERATION_INIT;
}
// Scatter the input tokens to other beam
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back());
// Init mUniqueTokens
VecUniqueTokens uniqueTokens{inputTokens.size()};
if (mInputTokenExtraIds.has_value() && mInputTokenExtraIds.value())
{
if (mInputTokenExtraIds.value()->size() != inputTokens.size())
{
TLLM_THROW("inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
mInputTokenExtraIds.value()->size(), inputTokens.size());
}
std::transform(inputTokens.cbegin(), inputTokens.cend(), mInputTokenExtraIds.value()->cbegin(),
uniqueTokens.begin(),
[](auto const inputToken, auto const tokenExtraId) {
return UniqueToken{inputToken, tokenExtraId};
});
}
else
{
// Default extra id is 0
std::transform(inputTokens.cbegin(), inputTokens.cend(), uniqueTokens.begin(),
[](auto const inputToken) {
return UniqueToken{inputToken, 0};
});
}
mUniqueTokens = BeamUniqueTokens(mSamplingConfig.beamWidth, uniqueTokens);
// Init mEncoderUniqueTokens
// TODO: use real extra id instead of default zero value
if (mEncoderTokens.has_value() && mEncoderTokens.value())
{
auto const& encoderTokens = *(mEncoderTokens.value());
auto encoderUniqueTokens = std::make_shared<VecUniqueTokens>(encoderTokens.size());
std::transform(encoderTokens.cbegin(), encoderTokens.cend(), encoderUniqueTokens->begin(),
[](auto const encoderToken) {
return UniqueToken{encoderToken, 0};
});
mEncoderUniqueTokens = encoderUniqueTokens;
}
if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
{
std::string errStr
= "Prompt embedding table and prompt vocab size tensors must both be provided for requests with "
"prompt "
"tuning enabled.";
TLLM_THROW(errStr);
}
if (mDraftLogits.has_value() && mDraftTokens->empty())
{
TLLM_THROW("Draft tokens must be specified when draft logits are given.");
}
setReturnLogProbs(outputLogProbs);
// Handling the backward compatibility of numReturnSequences.
if (mNumReturnSequences > 1)
{
if (!mSamplingConfig.numReturnSequences)
{
TLLM_LOG_WARNING(
"In the Executor class, mNumReturnSequences is deprecated. Please set numReturnSequences in "
"SamplingConfig directly.");
}
else if (mSamplingConfig.numReturnSequences
&& mSamplingConfig.numReturnSequences.value() != mNumReturnSequences)
{
TLLM_THROW(
"In the Executor class, both mSamplingConfig.numReturnSequences (%d) and mNumReturnSequences (%d) "
"are provided but unmatched. Please use numReturnSequences in SamplingConfig directly.",
mSamplingConfig.numReturnSequences.value(), mNumReturnSequences);
}
mSamplingConfig.numReturnSequences = mNumReturnSequences;
}
if (!isChild())
{
// Initialize result states unless it is a child and a child request should share parent's one.
mSequenceFinalVec = std::make_shared<std::vector<bool>>(getNumSubRequests(), false);
}
if (mReturnPerfMetrics)
{
// arrivalTime is assumed to be recorded at the rank 0, so no need to convert it to global clock
mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow());
}
mStartTime = std::chrono::steady_clock::now();
}
TensorPtr createListTensor(std::list<VecTokens> const& wordsList)
{
std::vector<SizeType32> offsets;
VecTokens words;
SizeType32 offsetCnt = 0;
for (auto const& tokens : wordsList)
{
offsetCnt += tokens.size();
offsets.push_back(offsetCnt);
words.insert(words.end(), tokens.begin(), tokens.end());
}
offsets.resize(words.size(), -1);
auto const numWords = static_cast<SizeType32>(words.size());
auto const shape = runtime::ITensor::makeShape({2, numWords});
auto tensor = runtime::BufferManager::pinnedPool(shape, nvinfer1::DataType::kINT32);
auto* data = runtime::bufferCast<int32_t>(*tensor);
std::memcpy(data, words.data(), numWords * sizeof(int32_t));
std::memcpy(data + numWords, offsets.data(), numWords * sizeof(int32_t));
// Add leading dim of 1
tensor->unsqueeze(0);
return tensor;
}
static TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point)
{
if (sGlobalSteadyClockOffset.has_value())
{
return time_point + *sGlobalSteadyClockOffset;
}
return time_point;
}
};
class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
{
friend class LlmRequestBindings;
public:
using Base = GenericLlmRequest<runtime::ITensor::SharedPtr>;
using TensorPtr = Base::TensorPtr;
using SizeType32 = Base::SizeType32;
using TokenIdType = Base::TokenIdType;
using RequestIdType = Base::RequestIdType;
using VecLogProbs = Base::VecLogProbs;
using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens;
using LoraTaskIdType = Base::LoraTaskIdType;
using TokenExtraIdType = Base::TokenExtraIdType;
using VecTokenExtraIds = Base::VecTokenExtraIds;
// inherit constructors
using Base::Base;
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
std::optional<executor::KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false, std::optional<VecTokens> encoderInputTokens = std::nullopt,
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority,
std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
std::optional<SizeType32> encoderOutputLength = std::nullopt,
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1,
std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false,
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
std::move(stopWordsList),
positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
std::move(promptEmbeddingTable), promptVocabSize,
multimodalHashes.has_value()
? std::make_shared<std::vector<std::vector<SizeType32>>>(std::move(multimodalHashes.value()))
: std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>>(std::nullopt),
multimodalPositions.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalPositions.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
multimodalLengths.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
std::move(multimodalEmbedding), std::move(mropeRotaryCosSin), mropePositionDeltas, loraTaskId,
std::move(loraWeights), std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig),
returnLogProbs, returnContextLogits, returnGenerationLogits,
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
: std::make_shared<VecTokens>(),
std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
applyLogitsPostProcessorBatched,
encoderInputTokens ? std::make_optional(std::make_shared<VecTokens>(std::move(*encoderInputTokens)))
: std::optional<std::shared_ptr<VecTokens>>(std::nullopt),
returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures), encoderOutputLength,
std::move(crossAttentionMask), llmRequestType,
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID,
arrivalTime)
{
}
LlmRequest(RequestIdType requestId, executor::Request const& request,
std::optional<Base::LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false)
: Base(requestId, request)
{
mLogitsPostProcessor = std::move(logitsPostProcessor);
mApplyLogitsPostProcessorBatched = applyLogitsPostProcessorBatched;
mLookaheadConfig = request.getLookaheadConfig();
mKvCacheRetentionConfig = request.getKvCacheRetentionConfig();
}
LlmRequest(LlmRequest&& request) = default;
LlmRequest(LlmRequest const& request) = default;
/// @brief Create a Response from the current state of the request
/// @details Note that there is some dependency on the order of operations in this method. Modify with care!
/// @return An optional Response
std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0);
std::optional<executor::Result> createResult(bool useFastLogits = false, int32_t mpiWorldRank = 0);
void createSerializedResult(
std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits = false, int32_t mpiWorldRank = 0);
/// @brief Check if the (user-provided) tokens fall within the vocabulary range.
/// @details Currently only supports invocation before context phase is completed.
/// @return True if tokens are within range.
bool checkTokenIdRange(SizeType32 vocabSize);
void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, SizeType32 vocabSizePadded,
std::optional<SizeType32> maxEncoderInputLen = std::nullopt, bool enableKVCacheReuse = false);
std::shared_ptr<LlmRequest> createChildRequest(RequestIdType requestId);
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);
void moveLoraWeightsToGpu(runtime::BufferManager const& manager);
// Remove LoRA weights and LoRA config tensors
void removeLoraTensors();
};
} // namespace tensorrt_llm::batch_manager