TensorRT-LLMs/cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h
Robin Kobus 1bd84c6d8c
feat: Allow individual gatherContext for each additional output (#3374)
* refactor: Update ExecutorConfig to use AdditionalModelOutput type

- Changed function signatures and member variables across multiple files, replacing std::optional<std::vector<std::string>> with std::optional<std::vector<executor::AdditionalModelOutput>> so that each additional output carries its own gatherContext flag (sketched below).
- Updated related serialization and deserialization methods to accommodate the new type.
- Adjusted tests to reflect the changes in the output handling structure.

This refactor enhances the flexibility and maintainability of the output configuration in the executor and batch manager components.
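A minimal usage sketch of the new type, assuming the two-field shape described above (a tensor name plus a gatherContext flag); the exact constructor signature may differ, and the tensor names are only illustrative:

    #include "tensorrt_llm/executor/executor.h"

    namespace texec = tensorrt_llm::executor;

    // Request two extra engine outputs; only the first is also gathered
    // during the context phase.
    std::vector<texec::AdditionalModelOutput> additionalOutputs{
        texec::AdditionalModelOutput{"hidden_states_output", /*gatherContext=*/true},
        texec::AdditionalModelOutput{"mrope_position_deltas", /*gatherContext=*/false}};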

* refactor: Remove equality operator from TrtGptModelOptionalParams

- Deleted the operator== implementation from TrtGptModelOptionalParams to simplify the class.
- Updated the pybind11 bindings to remove the exposure of the equality operator to Python.

This change streamlines the class definition and reduces unnecessary complexity in the bindings.

* refactor: Enhance copyAdditionalOutputs to utilize AdditionalModelOutput

- Updated the copyAdditionalOutputs function to accept a vector of AdditionalModelOutput, allowing for the inclusion of the gatherContext flag.
- Adjusted the logic to handle context and non-context outputs separately, as sketched below.
- Modified related unit tests to incorporate the new gatherContext parameter, ensuring comprehensive testing of the updated functionality.

This refactor improves the flexibility and clarity of output management in the batch processing workflow.
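An illustrative sketch of the split, not the verbatim implementation; copyContextOutput and copyGenerationOutput are hypothetical helpers standing in for the actual copy logic:

    for (auto const& output : additionalModelOutputs)
    {
        if (output.gatherContext)
        {
            // Also copy the tokens produced during the context phase.
            copyContextOutput(output.name, contextRequests);
        }
        // Generation-phase tokens are copied for every requested output.
        copyGenerationOutput(output.name, genRequests);
    }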

* refactor: Introduce findOutputTensor utility function for output tensor retrieval

- Added a new utility function, findOutputTensor, to encapsulate the logic for finding output tensors and checking their validity (sketched below).
- Refactored copyAdditionalOutputs to utilize findOutputTensor, reducing code duplication and improving clarity.
- Enhanced error checking for additional context and generation output tensors.

This change streamlines the output tensor retrieval process, enhancing maintainability and readability in the batch processing workflow.
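A sketch of the utility's likely shape, using the TensorMap and TensorPtr aliases from runtimeBuffers.h and the TLLM_CHECK_WITH_INFO assertion macro; the repository's exact signature may differ:

    #include "tensorrt_llm/common/assert.h"
    #include "tensorrt_llm/runtime/iTensor.h"

    using TensorMap = tensorrt_llm::runtime::ITensor::TensorMap;
    using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;

    TensorPtr findOutputTensor(TensorMap const& outputMap, std::string const& name)
    {
        auto const it = outputMap.find(name);
        TLLM_CHECK_WITH_INFO(it != outputMap.end(), "Additional output tensor '%s' not found.", name.c_str());
        return it->second;
    }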

* refactor: Check final indices of additional output tensors and update tests

- Added checks to verify the final indices of additional output tensors for context and generation outputs (see the sketch below).
- Updated unit tests to verify the changes.
  - Added a lastTokenIds input tensor to the test engines.
  - The logits output now depends on the gatherContextLogits parameter.
- Removed the gatherContextOutputs parameter from the validate method in LlmRequest.
  - Context outputs do not depend on the computeContextLogits parameter.
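A hedged sketch of the kind of final-index check described above; the function name and tensor layout are assumptions, not the repository's exact code:

    // After copying, the write position into an additional output tensor
    // should have reached the end of its first dimension.
    void checkFinalIndex(tensorrt_llm::runtime::ITensor::SharedPtr const& tensor, std::int64_t finalIndex)
    {
        TLLM_CHECK_WITH_INFO(finalIndex == tensor->getShape().d[0],
            "Additional output tensor not filled completely (%ld of %ld rows).", finalIndex,
            tensor->getShape().d[0]);
    }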

* fixup! refactor: Check final indices of additional output tensors and update tests

* fixup! refactor: Update ExecutorConfig to use AdditionalModelOutput type

* fixup! refactor: Remove equality operator from TrtGptModelOptionalParams

* docs: Update executor.md

* chore: Clean up includes

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-04-12 17:00:36 +08:00


/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
#pragma once

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/runtime/eagleBuffers.h"
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/lookaheadBuffers.h"
#include "tensorrt_llm/runtime/loraManager.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <cstddef>
#include <memory>
#include <optional>
#include <vector>
namespace tensorrt_llm::runtime
{
class TllmRuntime;
} // namespace tensorrt_llm::runtime
namespace tensorrt_llm::batch_manager
{
namespace kv_cache_manager
{
class BaseKVCacheManager;
} // namespace kv_cache_manager
class LlmRequest;
class EncoderBuffers;
class LoraBuffers;
class MedusaBuffers;
class PromptTuningBuffers;
class RnnStateBuffers;
class TransformerBuffers;
class RuntimeBuffers
{
public:
    static constexpr auto kLogitsTensorName = "logits";
    static constexpr auto kHiddenStatesOutputTensorName = "hidden_states_output";
    static constexpr auto kHiddenStatesInputTensorName = "hidden_states_input";
    static constexpr auto kInputIdsTensorName = "input_ids";
    static constexpr auto kLastTokenIdsTensorName = "last_token_ids";
    static constexpr auto kHostRequestTypesTensorName = "host_request_types";
    static constexpr auto kContextLengthsTensorName = "context_lengths";
    static constexpr auto kHostContextLengthsTensorName = "host_context_lengths";
    static constexpr auto kSequenceLengthsTensorName = "sequence_length";
    static constexpr auto kPromptEmbeddingTableTensorName = "prompt_embedding_table";
    static constexpr auto kTasksTensorName = "tasks";
    static constexpr auto kPromptVocabSizeTensorName = "prompt_vocab_size";
    static constexpr auto kMRopeRotaryCosSinTensorName = "mrope_rotary_cos_sin";
    static constexpr auto kMRopePositionDeltasTensorName = "mrope_position_deltas";

    using SizeType32 = runtime::SizeType32;
    using TensorPtr = runtime::ITensor::SharedPtr;
    using TensorMap = runtime::ITensor::TensorMap;
    using PeftTable = runtime::LoraManager::PeftTable;

    [[nodiscard]] SizeType32 constexpr getContextIndex() const noexcept
    {
        return contextIndex;
    };

    void constexpr setContextIndex(SizeType32 index) noexcept
    {
        contextIndex = index;
    };

    [[nodiscard]] SizeType32 constexpr getNumContextTokens() const noexcept
    {
        return numContextTokens;
    };

    [[nodiscard]] BatchState getBatchState() const noexcept
    {
        return {numContextRequests, numGenRequests, getNumTokens(), maxKvCacheLengthRounded};
    };
private:
    [[nodiscard]] SizeType32 constexpr getNumRequests() const noexcept
    {
        return numContextRequests + numGenRequests;
    };

    [[nodiscard]] SizeType32 constexpr getNumSequences() const noexcept
    {
        return numContextRequests + numGenSequences;
    };

    [[nodiscard]] SizeType32 constexpr getNumTokens() const noexcept
    {
        return numContextTokens + numGenTokens;
    };

    //! Sizes
    SizeType32 numContextRequests{};
    SizeType32 numGenRequests{};
    SizeType32 numGenSequences{};
    SizeType32 numContextTokens{};
    SizeType32 numGenTokens{};
    SizeType32 numLogits{};
    SizeType32 maxKvCacheLengthRounded{};

    //! General
    TensorPtr inputsIds;
    TensorPtr contextLengthsHost;
    TensorPtr contextLengthsDevice;
    TensorPtr sequenceLengthsHost;
    //! Index of selected runtime context.
    SizeType32 contextIndex{};
    SizeType32 maxContextLength{};
public:
    TensorPtr sequenceLengthsDevice;
private:
    //! Runtime
    //! Type of host tensor: 0 for context, 1 for generation
    TensorPtr requestTypes;
    TensorPtr lastTokenIdsHost;
    TensorPtr lastTokenIdsDevice;
    TensorPtr logitsIdsHost;
    TensorPtr logitsIdsDevice;

    //! Pipeline-Parallelism
    TensorPtr hiddenStates;

    //! Prompt-Tuning
    std::unique_ptr<PromptTuningBuffers> promptTuningBuffers;

    //! Mrope
    TensorPtr mropeRotaryCosSin;
    TensorPtr mropePositionDeltas;

    //! LoRA
    std::unique_ptr<LoraBuffers> loraBuffers;
public:
    //! Additional buffers depending on model type
    std::unique_ptr<TransformerBuffers> transformerBuffers;
    std::unique_ptr<RnnStateBuffers> rnnStateBuffers;

    //! Encoder-Decoder
    std::unique_ptr<EncoderBuffers> encoderBuffers;

    //! Medusa
    std::unique_ptr<MedusaBuffers> medusaBuffers;

    //! Lookahead decoding
    std::optional<runtime::LookaheadRuntimeBuffers> lookaheadBuffers;

    //! Explicit draft tokens decoding
    std::optional<runtime::ExplicitDraftTokensBuffers> explicitDraftTokensBuffers;

    //! Eagle decoding
    std::optional<runtime::EagleBuffers> eagleBuffers;
    //! Language adapter routing information if a language adapter is present.
    TensorPtr languageAdapterRoutings; // [numTokens, numLanguages]
    TensorPtr cacheIndirDecoderIOBatchedCopySrcOffsets;
    TensorPtr cacheIndirDecoderIOBatchedCopyDstOffsets;
    TensorPtr cacheIndirDecoderIOBatchedCopySizes;

    //! Logits
    std::vector<SizeType32> numContextLogits;
    TensorPtr logits;
    //! Helper cache for storing generation logits
    struct GenerationLogitsCache
    {
        static constexpr auto kCACHE_LENGTH = 8;

        //! Buffer holding logits between steps to prevent them from being overwritten
        //! [kCACHE_LENGTH, maxBatchSize * maxBeamWidth, vocabSizePadded]
        TensorPtr logits;
        //! Usage offset into the cacheGenerationLogits buffer
        SizeType32 offset{0};

        //! Temporary store for the transposed results of multiple fragment logits, [maxBeamWidth, kCACHE_LENGTH]
        TensorPtr transposedLogits;
        //! Temporary store for logits buffer addresses during transposing, [kCACHE_LENGTH]
        TensorPtr fragmentPointerDevice;
        //! Temporary store for logits buffer addresses during transposing, [maxBatchSize, kCACHE_LENGTH]
        TensorPtr fragmentPointerHost;
        //! Cycling index for the workspace
        size_t workIdx{0};

        void cycleWorkIdx()
        {
            workIdx = (workIdx + 1) % (fragmentPointerHost->getShape().d[0]);
        }

        [[nodiscard]] TensorPtr getFragmentPointerHost()
        {
            TensorPtr slice = runtime::ITensor::slice(fragmentPointerHost, workIdx, 1);
            cycleWorkIdx();
            return slice;
        };
    };
    GenerationLogitsCache generationLogitsCache;

    //! Helper for KV cache rewind
    TensorPtr seqSlots;
    TensorPtr seqSlotsDevice;
    TensorPtr sortedSeqSlots;

    //! TODO: move into decoderBuffers.DraftBuffers
    TensorPtr seqSlotRemappingHost;   // [numSequences]
    TensorPtr seqSlotRemappingDevice; // [numSequences]
    //! Src offsets for the batched copy kernel, kept on device to reduce warp stalls at kernel invocation.
    //! [mMaxNumRequests], on gpu
    TensorPtr mCacheIndirDecoderIOBatchedCopySrcOffsetsSliceDevice;
    //! Dst offsets for the batched copy kernel, kept on device to reduce warp stalls at kernel invocation.
    //! [mMaxNumRequests], on gpu
    TensorPtr mCacheIndirDecoderIOBatchedCopyDstOffsetsSliceDevice;
    //! Copy sizes for the batched copy kernel, kept on device to reduce warp stalls at kernel invocation.
    //! [mMaxNumRequests], on gpu
    TensorPtr mCacheIndirDecoderIOBatchedCopyCopySizesDevice;
private:
    //! Re-capture the CUDA graph when the max KV cache length of the batch, rounded up to a multiple of
    //! kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE, has changed.
    static SizeType32 constexpr kKV_CACHE_LEN_CUDA_GRAPH_ROUND_SIZE{256};

    TensorMap mAdditionalOutputTensors; // Additional output tensors requested for the model.

    //! Engine I/O
    TensorMap inputMap;
    TensorMap outputMap;
public:
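    //! @brief Construct and pre-allocate buffers for the given engine, model and world configuration.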
    RuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
        std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen,
        runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig,
        bool gatherGenerationLogits, std::optional<SizeType32> maxNumTokens = std::nullopt,
        std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs = std::nullopt);

    RuntimeBuffers(RuntimeBuffers const& other) = delete;
    RuntimeBuffers& operator=(RuntimeBuffers const& other) = delete;
    RuntimeBuffers(RuntimeBuffers&& other) = delete;
    RuntimeBuffers& operator=(RuntimeBuffers&& other) = delete;

    ~RuntimeBuffers();
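    //! @brief Prepare buffers for the next engine step from the scheduled requests and return the engine I/O tensor maps.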
    std::tuple<SizeType32, TensorMap const&, TensorMap&> prepareStep(RequestVector const& contextRequests,
        RequestVector const& genRequests, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
        DecoderBuffers& decoderBuffers, kv_cache_manager::BaseKVCacheManager* kvCacheManager,
        kv_cache_manager::BaseKVCacheManager* crossKvCacheManager, rnn_state_manager::RnnStateManager* rnnStateManager,
        PeftTable const& peftTable, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits);

    void prepareBuffersForCudaGraph(SizeType32 maxSequenceLength);

    void prepareExplicitDraftTokenBuffers(DecoderBuffers& decoderBuffers, runtime::TllmRuntime const& runtime,
        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);

    void prepareEagleBuffers(RequestVector const& contextRequests, RequestVector const& genRequests,
        DecoderBuffers& decoderBuffers, runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig);
private:
    void create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
        SizeType32 maxAttentionWindow, SizeType32 sinkTokenLen, runtime::TllmRuntime const& runtime,
        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
        executor::DecodingConfig const& decodingConfig, bool gatherGenerationLogits,
        std::optional<std::vector<executor::AdditionalModelOutput>> const& additionalModelOutputs = std::nullopt);

    void reshape(runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig, bool gatherGenerationLogits);

    //! @brief Set max sizes for pre-allocation.
    void setMaxBufferSizes(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::ModelConfig const& modelConfig,
        std::optional<SizeType32> maxNumRuntimeTokens);

    //! @brief Set sizes depending on scheduled requests.
    void setBufferSizes(RequestVector const& contextRequests, RequestVector const& genRequests);
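    //! @brief Fill buffers with data from the scheduled context and generation requests.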
    void setFromInputs(RequestVector const& contextRequests, RequestVector const& genRequests, SizeType32 maxBeamWidth,
        SizeType32 maxAttentionWindow, DecoderBuffers& decoderBuffers,
        kv_cache_manager::BaseKVCacheManager* kvCacheManagerPtr,
        kv_cache_manager::BaseKVCacheManager* crossKvCacheManagerPtr,
        rnn_state_manager::RnnStateManager* rnnStateManagerPtr, PeftTable const& peftTable,
        runtime::TllmRuntime const& runtime, runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig);
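    //! @brief Populate inputMap and outputMap with the engine input and output tensors.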
    void fillIOMaps(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
};
} // namespace tensorrt_llm::batch_manager