Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
* refactor: Restructure DecoderBuffers and DecoderStepAsyncSend

  - Move communication logic from `DecoderBuffers` to `DecoderStepAsyncSend`.
  - Updated `DecoderStepAsyncSend` constructor to utilize the `DecoderBuffers`, enhancing clarity and reducing parameter complexity.
  - Refactored related methods to align with the new class structure, improving maintainability and readability of the code.

  These changes streamline the handling of decoding buffers and improve the overall architecture of the batch manager.

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Restructure SlotDecoderBuffers and DecoderSlotAsyncSend

  - Move communication logic from `SlotDecoderBuffers` to `DecoderSlotAsyncSend`.
  - Updated `DecoderSlotAsyncSend` constructor to utilize the `SlotDecoderBuffers`, enhancing clarity and reducing parameter complexity.
  - Refactored related methods to align with the new class structure, improving maintainability and readability of the code.

  These changes enhance the structure and readability of the batch manager's decoding process.

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* chore: Log DecodingMode

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Introduce DecoderOutputBuffers and update related classes

  - Moved buffers from `DecoderBuffers` to `DecoderOutputBuffers` to better reflect its purpose.
  - Updated the `DecoderStepAsyncSend` class to utilize `DecoderOutputBuffers`, enhancing clarity in the communication logic.
  - Refactored the constructor and methods in `DecoderBuffers` to accommodate the new structure, improving maintainability.
  - Added Python bindings for `DecoderOutputBuffers` to ensure compatibility with existing interfaces.

  These changes streamline the handling of output buffers in the decoding process, improving the overall architecture of the batch manager.

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Update MPI communicator handling

  - Changed the `commSession` parameter type from `std::shared_ptr<mpi::MpiComm>` to `mpi::MpiComm` in the `DecoderStepAsyncSend` and `DecoderSlotAsyncSend` classes for improved clarity and reduced complexity.
  - Updated related methods and constructors to reflect the new parameter type, enhancing maintainability.
  - Refactored the `TrtGptModelInflightBatching` class to accommodate these changes, ensuring consistent usage of `MpiComm`.

  These modifications streamline the communication logic in the decoding process, improving the overall architecture of the batch manager.

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Replace shared_ptr with unique_ptr for buffer management

  - Updated the `TrtGptModelInflightBatching` class to use `std::unique_ptr` instead of `std::shared_ptr` for various buffer types, including `AllReduceBuffers`, `RuntimeBuffers`, `DecoderBuffers`, and `SlotDecoderBuffers`.
  - This change enhances memory management and ownership semantics, reducing overhead and improving performance.

  These modifications contribute to a cleaner and more efficient architecture in the batch manager.

  Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
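To illustrate the ownership and parameter changes described above, here is a minimal sketch with simplified, hypothetical declarations (not the actual TensorRT-LLM signatures):

    // Before: shared ownership of the buffers and of the MPI communicator handle.
    //     std::shared_ptr<DecoderBuffers> mDecoderBuffers;
    //     DecoderStepAsyncSend(/* many individual buffer tensors */, std::shared_ptr<mpi::MpiComm> commSession);
    //
    // After: the model exclusively owns its buffers, and the async-send helper takes the
    // grouped buffer struct plus the communicator directly.
    std::unique_ptr<DecoderBuffers> mDecoderBuffers; // owned by TrtGptModelInflightBatching
    DecoderStepAsyncSend(DecoderOutputBuffers const& outputBuffers, mpi::MpiComm const& commSession);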
580 lines
25 KiB
C++
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include "trtGptModel.h"

#include <NvInferRuntime.h>

namespace tensorrt_llm::runtime
{
class TllmRuntime;
class GptDecoderBatched;
class AllReduceBuffers;
class NcclCommunicator;
class SpeculativeDecodingMode;
} // namespace tensorrt_llm::runtime

namespace tensorrt_llm::mpi
{
class MpiWaitThread;
} // namespace tensorrt_llm::mpi

namespace tensorrt_llm::batch_manager
{

namespace kv_cache_manager
{
class KVCacheManager;
struct OffsetTableDimensions;
} // namespace kv_cache_manager

namespace rnn_state_manager
{
class RnnStateManager;
} // namespace rnn_state_manager

class SequenceSlotManager;
class DecoderStepAsyncSend;
class DecoderSlotAsyncSend;
class DecoderInputBuffers;
class DecoderOutputBuffers;
class DecoderBuffers;
class SlotDecoderBuffers;
class LlmRequest;
class RuntimeBuffers;
class BasePeftCacheManager;
class GuidedDecoder;
class BaseCacheTransceiver;

// Algorithms
class CapacityScheduler;
class MicroBatchScheduler;
class PauseRequests;
class AssignReqSeqSlots;
class AllocateKvCache;
class HandleContextLogits;
class HandleGenerationLogits;
class GenerateRequestOptions;
class LogitsPostProcessor;
class MakeDecodingBatchInputOutput;
class CreateNewDecoderRequests;

namespace utils
{
class CudaGraphExecutorCache;
} // namespace utils

struct RewindInputs
{
    SizeType32 maxBlocksPerSeq;
    bool isUseOneMoreBlock;
    SizeType32 numKvHeads;
};

class TrtGptModelInflightBatching : public TrtGptModel
{
    using BaseKVCacheManager = kv_cache_manager::BaseKVCacheManager;
    using OffsetTableDimensions = kv_cache_manager::OffsetTableDimensions;
    using KVCacheManager = kv_cache_manager::KVCacheManager;
    using KvCacheType = kv_cache_manager::CacheType;
    using KvCacheConfig = kv_cache_manager::KvCacheConfig;
    using RnnStateManager = rnn_state_manager::RnnStateManager;
    using LlmRequestPtr = std::shared_ptr<batch_manager::LlmRequest>;

public:
    class IterationStatsIFB
    {
    public:
        explicit IterationStatsIFB(SizeType32 microBatchId)
            : microBatchId{microBatchId}
        {
        }

        SizeType32 microBatchId;
        SizeType32 numCtxRequests{};
        SizeType32 numGenRequests{};
        SizeType32 numCtxTokens{};
        float avgNumDecodedTokensPerIter{};
        ReqIdsSet scheduledRequests;
        ReqIdsSet pausedRequests;
    };

    using SizeType32 = tensorrt_llm::runtime::SizeType32;
    using TokenIdType = tensorrt_llm::runtime::TokenIdType;
    using BufferManager = tensorrt_llm::runtime::BufferManager;
    using PeftTable = PeftCacheManager::PeftTable;
    using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
    using TensorPtr = runtime::ITensor::SharedPtr;

    TrtGptModelInflightBatching(std::shared_ptr<nvinfer1::ILogger> logger, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig, runtime::RawEngine const& rawEngine, bool ctxGenFusion,
        TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams());

    ~TrtGptModelInflightBatching() override;

    void terminateRequest(LlmRequestPtr const& llmRequest, bool pause = false) override;

    /// @brief Waits for the decoding of requests in flight.
    /// When a request has finished, or when speculative decoding is used, its state becomes
    /// LlmRequestState::kGENERATION_COMPLETE. Otherwise, it is set to
    /// LlmRequestState::kGENERATION_IN_PROGRESS.
    void forwardSync() override;

    /// @brief Tries to advance the active requests.
    /// Depending on the available resources, not all requests may be advanced.
    /// Requests in state LlmRequestState::kCONTEXT_INIT become
    /// LlmRequestState::kGENERATION_IN_PROGRESS or LlmRequestState::kGENERATION_TO_COMPLETE.
    /// @param activeRequests The list of requests to try to advance.
    void forwardAsync(RequestList const& activeRequests) override;
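    // Illustrative usage sketch (not part of this header's contract): a driver loop typically
    // schedules one step with forwardAsync() and then waits for it with forwardSync(), e.g.
    //     model.forwardAsync(activeRequests); // enqueue engine execution and decoding
    //     model.forwardSync();                // wait for decoding and update request states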

    /// @brief Override the runtime batch size for the model
    void setRuntimeBatchSize(SizeType32 runtimeMaxBatchSize) override;

    /// @brief Get the runtime batch size for the model
    [[nodiscard]] SizeType32 getRuntimeBatchSize() const override;

    /// @brief Override the runtime max num tokens for the model
    void setRuntimeMaxNumTokens(SizeType32 runtimeMaxNumTokens) override;

    void updatePeftCache(std::shared_ptr<LlmRequest> const& llmRequest) override;

    [[nodiscard]] IterationStatsIFB getLastIterationStats() const
    {
        return mLastIterationStatsIFB;
    }

    [[nodiscard]] TrtGptModelType getModelType() const override
    {
        return mCtxGenFusion ? TrtGptModelType::InflightFusedBatching : TrtGptModelType::InflightBatching;
    }

    [[nodiscard]] runtime::BufferManager const& getBufferManager() const override;
    [[nodiscard]] runtime::BufferManager::CudaStreamPtr getRuntimeStreamPtr() const override;

    void getCurrentIterationStats(executor::IterationStats& stats) const override;
    void getCurrentRequestStats(executor::RequestStatsPerIteration& stats) const override;
    [[nodiscard]] executor::DebugTensorsPerIteration getCurrentDebugTensors() const override;

    [[nodiscard]] executor::IterationType getIterCounter() const noexcept override
    {
        return mIterCounter;
    }

    [[nodiscard]] static bool optionalParamsAreValid(
        runtime::ModelConfig const& modelConfig, TrtGptModelOptionalParams const& optionalParams);
    [[nodiscard]] static TrtGptModelOptionalParams fixOptionalParams(
        runtime::ModelConfig const& modelConfig, TrtGptModelOptionalParams const& optionalParams);

    void prepareDisaggGenInitRequests(RequestList const& activeRequests, RequestVector& newGenRequests);
    void checkDisaggGenTransferStatus(RequestList const& activeRequests);
    void prepareDistGenBufferAndDecoder(RequestVector const& generationRequests);

    void resetIterationStats() override;

    runtime::SpeculativeDecodingMode getSpeculativeDecodingMode() const noexcept
    {
        return mModelConfig.getSpeculativeDecodingMode();
    }

private:
    [[nodiscard]] SizeType32 getContextBufferId() const
    {
        return mMicroBatchId;
    }

    [[nodiscard]] SizeType32 getGenerationBufferId() const
    {
        return mNumMicroBatches + mMicroBatchId;
    }

    [[nodiscard]] SizeType32 getFusedBufferId() const
    {
        return mMicroBatchId;
    }

    [[nodiscard]] SizeType32 getNextMicroBatchId(SizeType32 bufferId) const
    {
        return (bufferId + 1) % mNumMicroBatches;
    }

    [[nodiscard]] SizeType32 getPrevMicroBatchId(SizeType32 bufferId) const
    {
        return (bufferId + mNumMicroBatches - 1) % mNumMicroBatches;
    }
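
    // Worked example (illustrative only): with mNumMicroBatches == 3 the micro batch ids cycle
    // 0 -> 1 -> 2 -> 0, so getNextMicroBatchId(2) == 0 and getPrevMicroBatchId(0) == 2. In the
    // unfused path, micro batch i uses context buffer i and generation buffer mNumMicroBatches + i,
    // so mBuffers holds 2 * mNumMicroBatches entries (see mNumBuffers below).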

    //! @brief Store the full kv cache blocks contributed by req.
    //! These blocks become reusable from the next step.
    void storeContextBlocks(std::shared_ptr<LlmRequest> const& req);

    //! @brief Set LayerProfiler to collect performance per layer.
    void setLayerProfiler() override;

    //! @brief Print profile information per layer.
    std::string getLayerProfileInfo() const override;

    std::tuple<SizeType32, TensorMap const&, TensorMap&> prepareBuffers(
        RequestVector const& contextRequests, RequestVector const& generationRequests, SizeType32 bufferId);

    //! @brief Capture a CUDA graph of the current batch state during engine execution.
    //! This is based on the assumptions that
    //! a) the CPU graph capture can be hidden behind the GPU engine execution, and
    //! b) the batch size won't change in the next iterations, so the graph can be reused multiple times.
    void prepareGraph(SizeType32 bufferId, SizeType32 optProfileId);
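    // For reference, the general CUDA graph capture-and-replay pattern this relies on (sketch only;
    // the actual capture is handled by the runtime and utils::CudaGraphExecutorCache):
    //     cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
    //     /* enqueue the engine's kernels for the current batch size */
    //     cudaGraph_t graph;
    //     cudaStreamEndCapture(stream, &graph);
    //     cudaGraphExec_t graphExec;
    //     cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);
    //     cudaGraphLaunch(graphExec, stream); // replay in later iterations with the same batch size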

    void executeContext(SizeType32 runtimeContextId, SizeType32 bufferId);
    void executeBatch(ScheduledRequests const& scheduledRequests);
    void executeStep(
        RequestVector const& contextRequests, RequestVector const& generationRequests, SizeType32 bufferId);

    void debugIOTensors(RequestVector const& contextRequests, RequestVector const& generationRequests,
        TensorMap const& inputMap, TensorMap const& outputMap);

    void createRuntimeContexts();
    void createDecoder(std::optional<executor::DecodingMode> const& decodingModeOpt);
    void createBuffers(executor::DecodingConfig const& decodingConfig,
        std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs);
    std::shared_ptr<KVCacheManager> createKvCacheManager(KvCacheConfig const& kvCacheConfig,
        SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, KvCacheType kvCacheType = KvCacheType::kSELF);
    void createRnnStateManager();
    void createCustomAllReduceWorkspace();
    void createRuntimePerfKnobsTensor(executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);

    /// @brief Verify the draft token length and beam width of all active requests.
    /// May change the operating beam width if all requests agree on the same beam width.
    void verifyRequests(RequestList const& activeRequests);

    /// @brief Change the operating beam width.
    /// Only possible if no requests are currently in flight.
    /// @param beamWidth New operating beam width. Must be smaller than the initial maxBeamWidth.
    void changeBeamWidth(SizeType32 beamWidth);

    SizeType32 getOperatingBeamWidth() const override
    {
        return mOperatingBeamWidth;
    }

    /// @details Should be called after setting up the current batch in executeBatch to get the correct number of
    /// context tokens.
    IterationStatsIFB fillIterationStats(
        ScheduledRequests const& scheduledRequests, RequestVector const& requestsToPause);

    /// @brief Sets up the TensorRT execution context to be used for execution. If the engine was built with
    /// multiple TensorRT optimization profiles, it selects the corresponding context and prepares the input and
    /// output tensors so that both the buffers and the context are ready for execution.
    void setupContext(
        RequestVector const& contextRequests, RequestVector const& generationRequests, SizeType32 bufferId);

    void setupDecoderStep(
        RequestVector const& contextRequests, RuntimeBuffers const& buffers, DecoderInputBuffers const& inputBuffers);
    runtime::CudaEvent decoderStepAsync(ScheduledRequests const& scheduledRequests);
    std::vector<std::unique_ptr<DecoderStepAsyncSend>> decoderSync(
        ScheduledRequests const& scheduledRequests, std::optional<runtime::CudaEvent> const& decoderFinishEvent);

    runtime::CudaEvent updateDecoderBuffers(bool returnLogProbs, runtime::CudaEvent decoderFinishEvent);
    std::vector<std::unique_ptr<DecoderStepAsyncSend>> communicateDecoderBuffers(bool returnLogProbs);
    void updateRequests(ScheduledRequests const& scheduledRequests);

    /// @brief Gathers the logits if they need to be returned, calls getDecoderSlotHostOutputs,
    /// and overwrites the llmRequest tokens buffer.
    /// Called either when a request finishes, or at every step when doing beam search and streaming.
    void postProcessRequest(LlmRequest& llmReq, std::vector<SizeType32> const& numDroppedTokens);
    /// @brief Calls gatherTree (via finalize) and transmits the received data across ranks if PP > 1.
    void getDecoderSlotHostOutputs(
        SizeType32 seqSlot, bool returnLogProbs, runtime::SamplingConfig const& samplingConfig, bool streaming);
    void rewindKVCacheBlocks(SizeType32 numSequences);
    void setupSpeculativeDecodingModule(executor::DecodingConfig const& decodingConfig);

    /// @brief Copies the content of the cache indirection outputs to the cache indirection inputs.
    /// @param[in] scheduledRequests The requests to copy the cache indirections for.
    /// @param[in] genBufferId The id of the generation buffers for those requests.
    void copyCacheIndirectionFromOutputsToInputs(ScheduledRequests const& scheduledRequests, SizeType32 genBufferId);

    [[nodiscard]] bool getGatherGenerationLogits() const override
    {
        return getModelConfig().computeGenerationLogits() || mGatherGenerationLogits;
    }

    [[nodiscard]] runtime::ModelConfig const& getModelConfig() const override
    {
        return mModelConfig;
    }

    [[nodiscard]] runtime::WorldConfig const& getWorldConfig() const override
    {
        return mWorldConfig;
    }

    [[nodiscard]] SizeType32 getNumMicroBatches() const override
    {
        return mNumMicroBatches;
    }

    [[nodiscard]] nvinfer1::DataType getLogitDataType() const override;

    [[nodiscard]] nvinfer1::DataType getTensorDataType(std::string const& name) const override;

    [[nodiscard]] nvinfer1::Dims getTensorShape(std::string const& name) const override;

    void reshapeKvTensors(OffsetTableDimensions const& dims);

    [[nodiscard]] bool hasSpeculativeDecodingFastLogits() const noexcept override
    {
        return mSpeculativeDecodingFastLogits;
    }

    [[nodiscard]] bool hasGuidedDecoder() const noexcept override
    {
        return static_cast<bool>(mGuidedDecoder);
    }

    /// @brief Based on the KV-cache manager's capacity and configuration, we adjust the maximum supported attention
    /// window.
    ///
    /// @param numPrimaryBlocks The number of blocks in the kv-cache's primary pool.
    /// @param numTokensPerBlock The number of tokens per kv-cache block.
    void adjustMaxAttentionWindow(SizeType32 numPrimaryBlocks, SizeType32 numTokensPerBlock);
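    // Illustrative arithmetic (not the exact implementation): the primary pool can hold at most
    // numPrimaryBlocks * numTokensPerBlock tokens, e.g. 1024 blocks * 64 tokens/block = 65536 tokens.
    // If the configured attention window cannot be served from that capacity for the required number
    // of sequences, the maximum attention window is reduced accordingly.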

    /// @brief Change the speculative decoding mode.
    void changeSpecDecMode(ScheduledRequests const& scheduledRequests);

    void prefetchNextPromptTableChunk(RequestVector const& contextRequests, bool isFirstChunk, SizeType32 bufferId);

    void remapInputTokensForPromptTable(
        std::shared_ptr<LlmRequest> const& llmReq, bool isCurrentChunk, SizeType32 bufferId, SizeType32 contextId);

    void copyPromptTableToGpuInChunk(std::shared_ptr<LlmRequest> const& llmReq,
        std::vector<int32_t> const& outOfVocabTokens, bool useCurrentBuffer, SizeType32 bufferId, SizeType32 contextId);

protected:
    std::shared_ptr<BaseKVCacheManager> getKVCacheManager() override
    {
        return mKvCacheManager;
    }

    [[nodiscard]] std::shared_ptr<BaseKVCacheManager const> getKVCacheManager() const override
    {
        return mKvCacheManager;
    }

    std::shared_ptr<BaseKVCacheManager> getCrossKVCacheManager()
    {
        return mCrossKvCacheManager;
    }

    [[nodiscard]] std::shared_ptr<BaseKVCacheManager const> getCrossKVCacheManager() const
    {
        return mCrossKvCacheManager;
    }

    [[nodiscard]] std::shared_ptr<BasePeftCacheManager> getPeftCacheManager() override
    {
        return mPeftCacheManager;
    }

    [[nodiscard]] std::shared_ptr<BasePeftCacheManager const> getPeftCacheManager() const override
    {
        return mPeftCacheManager;
    }

    void setLogitsPostProcessorBatched(std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) override
    {
        mLogitsPostProcessorBatched = logitsPostProcessorBatched;
    }

    void setReplicateLogitsPostProcessor(bool replicateLogitsPostProcessor) override
    {
        mReplicateLogitsPostProcessor = replicateLogitsPostProcessor;
    }

    [[nodiscard]] bool getReplicateLogitsPostProcessor() const override
    {
        return mReplicateLogitsPostProcessor;
    }

    SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override;

private:
    /******************** Configs ********************/
    // Parameters of the model (TRT engine)
    runtime::ModelConfig mModelConfig;
    // Parameters of the execution environment
    runtime::WorldConfig mWorldConfig;
    // Device ID of this instance
    int mDevice{-1};
    // Config for (speculative) decoding
    executor::DecodingConfig mDecodingConfig;
    // Performance knobs for the engine.
    executor::ExtendedRuntimePerfKnobConfig mExtendedRuntimePerfKnobConfig;
    TensorPtr mExtendedRuntimePerfKnobsHost;
    // Config for debugging output
    std::optional<executor::DebugConfig> mDebugConfig;
    // List of additional outputs for each request
    std::optional<std::vector<executor::AdditionalModelOutput>> mAdditionalModelOutputs;

    /******************** Components ********************/
    std::shared_ptr<nvinfer1::ILogger> mLogger;
    // Runner for the TRT engine. The engine produces logits.
    std::shared_ptr<runtime::TllmRuntime> mRuntime;
    // Decoder that generates new tokens from the logits.
    std::shared_ptr<runtime::GptDecoderBatched> mDecoder;
    // Synchronization handles for the decoder
    std::vector<std::optional<runtime::CudaEvent>> mDecoderFinishedEvents;

    // Manager that maps requests to slots
    std::shared_ptr<SequenceSlotManager> mSeqSlotManager;
    // KV cache manager for attention layers (optional)
    std::shared_ptr<BaseKVCacheManager> mKvCacheManager;
    // KV cache manager for cross attention in enc-dec models (optional)
    std::shared_ptr<BaseKVCacheManager> mCrossKvCacheManager = nullptr;
    // RNN state manager for recurrent layers (optional)
    std::shared_ptr<RnnStateManager> mRnnStateManager;
    // PEFT cache manager for LoRA tasks (optional)
    std::shared_ptr<BasePeftCacheManager> mPeftCacheManager;
    // BufferManager using a separate stream for async copy operations.
    runtime::BufferManager mCopyBufferManager;
    // Event for async data transfers
    runtime::CudaEvent mPtableCopyDoneEvent;

    /******************** Logits Post-Processor ********************/
    std::optional<LogitsPostProcessorBatched> mLogitsPostProcessorBatched;
    bool mReplicateLogitsPostProcessor{true};
    // Set if any request invoked a logits processor in the current step
    bool mLogitsPostProcessorIsApplied{false};

    constexpr bool broadcastPostDecoder()
    {
        return mWorldConfig.isTensorParallel() && !mReplicateLogitsPostProcessor && mLogitsPostProcessorIsApplied;
    }

    std::unique_ptr<tensorrt_llm::batch_manager::GuidedDecoder> mGuidedDecoder;

    /******************** Pipeline parallelism ********************/
    std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiCommPipelinePara;
    std::vector<std::unique_ptr<DecoderStepAsyncSend>> mDecStepAsyncSndHdls;
    std::vector<std::unique_ptr<DecoderSlotAsyncSend>> mDecSlotAsyncSndHdls;
    std::unique_ptr<tensorrt_llm::mpi::MpiWaitThread> mAsyncSendWaitThread;

    /******************** Tensor parallelism ********************/
    std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiCommTensorPara;
    std::unique_ptr<runtime::AllReduceBuffers> mAllReduceBuffers;

    /******************** Runtime parameters ********************/
    // Flag to select fused or unfused context+generation execution
    bool mCtxGenFusion;
    // ID of the current micro batch, changes after each iteration
    SizeType32 mMicroBatchId{0};
    // Number of micro batches. Multiple batches are used for overlapping setup and execution,
    // and in pipeline parallelism.
    SizeType32 mNumMicroBatches;
    // Number of buffers to be added to mBuffers.
    SizeType32 mNumBuffers;
    // Current operating beam width. Can be changed with the changeBeamWidth function.
    SizeType32 mOperatingBeamWidth;
    // Runtime batch size optimized during execution for the microBatchScheduler:
    /// The max batch size recommended by the dynamic tuner
    SizeType32 mMaxBatchSizeTunerRecommended;
    /// The min of mMaxBatchSize and mMaxBatchSizeTunerRecommended
    SizeType32 mMaxBatchSizeRuntime;
    // Runtime max num tokens optimized during execution for the microBatchScheduler:
    /// Build time max num tokens
    std::optional<SizeType32> mMaxNumTokensStatic;
    /// The max num tokens recommended by the dynamic tuner
    SizeType32 mMaxNumTokensTunerRecommended;
    /// The min of mMaxNumTokens and mMaxNumTokensTunerRecommended
    std::optional<SizeType32> mMaxNumTokensRuntime;
    // Controls whether generation logits should be gathered, so that returnGenerationLogits can be requested.
    bool mGatherGenerationLogits{false};
    // Offloading and prefetching of the prompt tuning table (only effective in chunked prefill mode)
    bool mPromptTableOffloading;

    /******************** Buffers ********************/
    // Buffers for each micro batch. The unfused path (mCtxGenFusion == false) uses twice as many buffers.
    std::vector<std::unique_ptr<RuntimeBuffers>> mBuffers;
    // Decoder input buffers for each micro batch.
    std::vector<DecoderInputBuffers> mDecoderInputBuffers;
    // Decoder output buffers for each micro batch.
    std::vector<DecoderOutputBuffers> mDecoderOutputBuffers;
    // Global buffer to interface with the decoder. Slots in this buffer are selected by mSeqSlotManager.
    std::unique_ptr<DecoderBuffers> mDecoderBuffers;
    // Buffers for each slot in the decoder
    std::vector<std::unique_ptr<SlotDecoderBuffers>> mSlotDecoderBuffers;
    // PEFT table for each micro batch
    std::vector<PeftTable> mPeftTables;
    // Decoder input for each micro batch.
    std::vector<std::unique_ptr<runtime::decoder_batch::Input>> mDecodingInputs;
    std::unique_ptr<runtime::decoder_batch::Output> mDecodingOutput;

    /******************** Book keeping ********************/
    // List of requests in each micro batch
    std::vector<ScheduledRequests> mMicroBatchScheduledRequests;
    // Set of in-flight requests of *all* micro batches
    ReqIdsSet mInflightReqIds;
    // Requests that the scheduler selected to be paused
    ReqIdsSet mReqIdsToPause;
    // Stats collected in the last iteration
    IterationStatsIFB mLastIterationStatsIFB{-1};
    // Iteration counter used to distinguish debug output
    executor::IterationType mIterCounter{0};
    // Debug tensors of the last iteration
    TensorMap mLastIterationDebugTensors;
    // CUDA graph instances for each micro batch.
    std::vector<utils::CudaGraphExecutorCache> mCudaGraphExecutorCaches;

    /******************** Cache transceiver ********************/
    std::unique_ptr<BaseCacheTransceiver> mCacheTransceiver;

    /******************** Spec dec ********************/
    std::unique_ptr<std::thread> mDraftModelSendLogitsThread;
    bool mSpeculativeDecodingFastLogits;
    std::atomic<bool> mDraftModelThreadShouldExit{false};
    bool mIsLeaderInOrchMode{false};
    // List of completed draft requests whose logits still need to be sent to the target model.
    RequestVector mDraftRequestsWaitingToSendLogits;
    SizeType32 mSeamlessLADMaxDraftLen{0};
    bool mUseSeamlessLookahead{false};
    RewindInputs mRewindInputs;

    /******************** Algorithms ********************/
    // Algorithms are reentrant: their state is assigned at construction time and is not modified
    // during execution, hence they are const.
    // Schedulers that select which requests to run in each iteration
    std::unique_ptr<tensorrt_llm::batch_manager::CapacityScheduler const> mCapacityScheduler;
    std::unique_ptr<tensorrt_llm::batch_manager::MicroBatchScheduler const> mMicroBatchScheduler;
    std::unique_ptr<tensorrt_llm::batch_manager::PauseRequests const> mPauseRequests;
    std::unique_ptr<tensorrt_llm::batch_manager::AssignReqSeqSlots const> mAssignReqSeqSlots;
    std::unique_ptr<tensorrt_llm::batch_manager::AllocateKvCache const> mAllocateKvCache;
    std::unique_ptr<tensorrt_llm::batch_manager::HandleContextLogits const> mHandleContextLogits;
    std::unique_ptr<tensorrt_llm::batch_manager::HandleGenerationLogits const> mHandleGenerationLogits;
    std::unique_ptr<tensorrt_llm::batch_manager::GenerateRequestOptions const> mGenerateRequestOptions;
    std::unique_ptr<tensorrt_llm::batch_manager::LogitsPostProcessor const> mLogitsPostProcessor;
    std::unique_ptr<tensorrt_llm::batch_manager::MakeDecodingBatchInputOutput const> mMakeDecodingBatchInputOutput;
    std::unique_ptr<tensorrt_llm::batch_manager::CreateNewDecoderRequests const> mCreateNewDecoderRequests;
};

} // namespace tensorrt_llm::batch_manager