/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include <istream>
#include <ostream>

namespace tensorrt_llm::executor
{

namespace kv_cache
{
class CommState;
class CacheState;
struct SocketState;
} // namespace kv_cache

class Serialization
{
public:
    // TimePoint
    [[nodiscard]] static RequestPerfMetrics::TimePoint deserializeTimePoint(std::istream& is);
    static void serialize(RequestPerfMetrics::TimePoint const& tp, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(RequestPerfMetrics::TimePoint const&);

    // RequestPerfMetrics
    [[nodiscard]] static RequestPerfMetrics deserializeRequestPerfMetrics(std::istream& is);
    static void serialize(RequestPerfMetrics const& metrics, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(RequestPerfMetrics const& metrics);

    // SamplingConfig
    [[nodiscard]] static SamplingConfig deserializeSamplingConfig(std::istream& is);
    static void serialize(SamplingConfig const& config, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(SamplingConfig const& config);

    // OutputConfig
    [[nodiscard]] static OutputConfig deserializeOutputConfig(std::istream& is);
    static void serialize(OutputConfig const& config, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(OutputConfig const& config);

    // OutputConfig::AdditionalModelOutput
    [[nodiscard]] static AdditionalModelOutput deserializeAdditionalModelOutput(std::istream& is);
    static void serialize(AdditionalModelOutput const& additionalModelOutput, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(AdditionalModelOutput const& additionalModelOutput);

    // ExternalDraftTokensConfig
    [[nodiscard]] static ExternalDraftTokensConfig deserializeExternalDraftTokensConfig(std::istream& is);
    static void serialize(ExternalDraftTokensConfig const& config, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(ExternalDraftTokensConfig const& config);

    // PromptTuningConfig
    [[nodiscard]] static PromptTuningConfig deserializePromptTuningConfig(std::istream& is);
    static void serialize(PromptTuningConfig const& config, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(PromptTuningConfig const& config);

    // MropeConfig
    [[nodiscard]] static MropeConfig deserializeMropeConfig(std::istream& is);
    static void serialize(MropeConfig const& config, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(MropeConfig const& config);

    // LoraConfig
    [[nodiscard]] static LoraConfig deserializeLoraConfig(std::istream& is);
    static void serialize(LoraConfig const& config, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(LoraConfig const& config);

    // CommState
    [[nodiscard]] static kv_cache::CommState deserializeCommState(std::istream& is);
    static void serialize(kv_cache::CommState const& state, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(kv_cache::CommState const& state);

    // SocketState
    [[nodiscard]] static kv_cache::SocketState deserializeSocketState(std::istream& is);
    static void serialize(kv_cache::SocketState const& state, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(kv_cache::SocketState const& state);

    // CacheState
    [[nodiscard]] static kv_cache::CacheState deserializeCacheState(std::istream& is);
    static void serialize(kv_cache::CacheState const& state, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(kv_cache::CacheState const& state);

    // DataTransceiverState
    [[nodiscard]] static DataTransceiverState deserializeDataTransceiverState(std::istream& is);
    [[nodiscard]] static DataTransceiverState deserializeDataTransceiverState(std::vector<char>& buffer);
    static void serialize(DataTransceiverState const& dataTransceiverState, std::ostream& os);
    static std::vector<char> serialize(DataTransceiverState const& dataTransceiverState);
    [[nodiscard]] static size_t serializedSize(DataTransceiverState const& dataTransceiverState);

    // ContextPhaseParams
    [[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is);
    static void serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(ContextPhaseParams const& contextPhaseParams);

    // Request
    [[nodiscard]] static Request deserializeRequest(std::istream& is);
    static void serialize(Request const& request, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(Request const& request);

    // Tensor
    [[nodiscard]] static Tensor deserializeTensor(std::istream& is);
    static void serialize(Tensor const& tensor, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(Tensor const& tensor);

    // SpeculativeDecodingFastLogitsInfo
    [[nodiscard]] static SpeculativeDecodingFastLogitsInfo deserializeSpecDecFastLogitsInfo(std::istream& is);
    static void serialize(SpeculativeDecodingFastLogitsInfo const& info, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(SpeculativeDecodingFastLogitsInfo const& info);

    // Result
    [[nodiscard]] static Result deserializeResult(std::istream& is);
    static void serialize(Result const& result, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(Result const& result);

    // AdditionalOutput
    [[nodiscard]] static AdditionalOutput deserializeAdditionalOutput(std::istream& is);
    static void serialize(AdditionalOutput const& additionalOutput, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(AdditionalOutput const& additionalOutput);

    // Response
    [[nodiscard]] static Response deserializeResponse(std::istream& is);
    static void serialize(Response const& response, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(Response const& response);

    // Vector of responses
    static std::vector<Response> deserializeResponses(std::vector<char>& buffer);
    static std::vector<char> serialize(std::vector<Response> const& responses);

    // KvCacheConfig
    static KvCacheConfig deserializeKvCacheConfig(std::istream& is);
    static void serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os);
    static size_t serializedSize(KvCacheConfig const& kvCacheConfig);

    // DynamicBatchConfig
    static DynamicBatchConfig deserializeDynamicBatchConfig(std::istream& is);
    static void serialize(DynamicBatchConfig const& dynamicBatchConfig, std::ostream& os);
    static size_t serializedSize(DynamicBatchConfig const& dynamicBatchConfig);

    // SchedulerConfig
    static SchedulerConfig deserializeSchedulerConfig(std::istream& is);
    static void serialize(SchedulerConfig const& schedulerConfig, std::ostream& os);
    static size_t serializedSize(SchedulerConfig const& schedulerConfig);

    // ExtendedRuntimePerfKnobConfig
    static ExtendedRuntimePerfKnobConfig deserializeExtendedRuntimePerfKnobConfig(std::istream& is);
    static void serialize(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig, std::ostream& os);
    static size_t serializedSize(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig);

    // ParallelConfig
    static ParallelConfig deserializeParallelConfig(std::istream& is);
    static void serialize(ParallelConfig const& parallelConfig, std::ostream& os);
    static size_t serializedSize(ParallelConfig const& parallelConfig);

    // PeftCacheConfig
    static PeftCacheConfig deserializePeftCacheConfig(std::istream& is);
    static void serialize(PeftCacheConfig const& peftCacheConfig, std::ostream& os);
    static size_t serializedSize(PeftCacheConfig const& peftCacheConfig);

    // OrchestratorConfig
    static OrchestratorConfig deserializeOrchestratorConfig(std::istream& is);
    static void serialize(OrchestratorConfig const& orchestratorConfig, std::ostream& os);
    static size_t serializedSize(OrchestratorConfig const& orchestratorConfig);

    // DecodingMode
    static DecodingMode deserializeDecodingMode(std::istream& is);
    static void serialize(DecodingMode const& decodingMode, std::ostream& os);
    static size_t serializedSize(DecodingMode const& decodingMode);

    // LookaheadDecodingConfig
    static LookaheadDecodingConfig deserializeLookaheadDecodingConfig(std::istream& is);
    static void serialize(LookaheadDecodingConfig const& lookaheadDecodingConfig, std::ostream& os);
    static size_t serializedSize(LookaheadDecodingConfig const& lookaheadDecodingConfig);

    // EagleConfig
    static EagleConfig deserializeEagleConfig(std::istream& is);
    static void serialize(EagleConfig const& eagleConfig, std::ostream& os);
    static size_t serializedSize(EagleConfig const& eagleConfig);

    // SpeculativeDecodingConfig
    static SpeculativeDecodingConfig deserializeSpeculativeDecodingConfig(std::istream& is);
    static void serialize(SpeculativeDecodingConfig const& specDecConfig, std::ostream& os);
    static size_t serializedSize(SpeculativeDecodingConfig const& specDecConfig);

    // GuidedDecodingConfig
    static GuidedDecodingConfig deserializeGuidedDecodingConfig(std::istream& is);
    static void serialize(GuidedDecodingConfig const& guidedDecodingConfig, std::ostream& os);
    static size_t serializedSize(GuidedDecodingConfig const& guidedDecodingConfig);

    // GuidedDecodingParams
    static GuidedDecodingParams deserializeGuidedDecodingParams(std::istream& is);
    static void serialize(GuidedDecodingParams const& guidedDecodingParams, std::ostream& os);
    static size_t serializedSize(GuidedDecodingParams const& guidedDecodingParams);

    // KvCacheRetentionConfig
    static KvCacheRetentionConfig deserializeKvCacheRetentionConfig(std::istream& is);
    static void serialize(KvCacheRetentionConfig const& kvCacheRetentionConfig, std::ostream& os);
    static size_t serializedSize(KvCacheRetentionConfig const& kvCacheRetentionConfig);

    // TokenRangeRetentionConfig
    static KvCacheRetentionConfig::TokenRangeRetentionConfig deserializeTokenRangeRetentionConfig(std::istream& is);
    static void serialize(
        KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig, std::ostream& os);
    static size_t serializedSize(KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig);

    // DecodingConfig
    static DecodingConfig deserializeDecodingConfig(std::istream& is);
    static void serialize(DecodingConfig const& decodingConfig, std::ostream& os);
    static size_t serializedSize(DecodingConfig const& decodingConfig);

    // DebugConfig
    static DebugConfig deserializeDebugConfig(std::istream& is);
    static void serialize(DebugConfig const& debugConfig, std::ostream& os);
    static size_t serializedSize(DebugConfig const& debugConfig);

    // ExecutorConfig
    static ExecutorConfig deserializeExecutorConfig(std::istream& is);
    static void serialize(ExecutorConfig const& executorConfig, std::ostream& os);
    static size_t serializedSize(ExecutorConfig const& executorConfig);

    // KvCacheStats
    static KvCacheStats deserializeKvCacheStats(std::istream& is);
    static void serialize(KvCacheStats const& kvCacheStats, std::ostream& os);
    static size_t serializedSize(KvCacheStats const& kvCacheStats);

    // StaticBatchingStats
    static StaticBatchingStats deserializeStaticBatchingStats(std::istream& is);
    static void serialize(StaticBatchingStats const& staticBatchingStats, std::ostream& os);
    static size_t serializedSize(StaticBatchingStats const& staticBatchingStats);

    // InflightBatchingStats
    static InflightBatchingStats deserializeInflightBatchingStats(std::istream& is);
    static void serialize(InflightBatchingStats const& inflightBatchingStats, std::ostream& os);
    static size_t serializedSize(InflightBatchingStats const& inflightBatchingStats);

    // IterationStats
    static IterationStats deserializeIterationStats(std::vector<char>& buffer);
    static IterationStats deserializeIterationStats(std::istream& is);
    static void serialize(IterationStats const& iterStats, std::ostream& os);
    static std::vector<char> serialize(IterationStats const& iterStats);
    static size_t serializedSize(IterationStats const& iterStats);
    static std::vector<char> serialize(std::vector<IterationStats> const& iterStatsVec);
    static std::vector<IterationStats> deserializeIterationStatsVec(std::vector<char>& buffer);

    // DisServingStats
    [[nodiscard]] static DisServingRequestStats deserializeDisServingRequestStats(std::istream& is);
    static void serialize(DisServingRequestStats const& stats, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(DisServingRequestStats const& disServingRequestStats);

    // RequestStage
    [[nodiscard]] static RequestStage deserializeRequestStage(std::istream& is);
    static void serialize(RequestStage const& requestStage, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(RequestStage const& requestStage);

    // RequestStats
    [[nodiscard]] static RequestStats deserializeRequestStats(std::istream& is);
    static void serialize(RequestStats const& state, std::ostream& os);
    [[nodiscard]] static size_t serializedSize(RequestStats const& state);

    // RequestStatsPerIteration
    [[nodiscard]] static RequestStatsPerIteration deserializeRequestStatsPerIteration(std::istream& is);
    [[nodiscard]] static RequestStatsPerIteration deserializeRequestStatsPerIteration(std::vector<char>& buffer);
    static void serialize(RequestStatsPerIteration const& state, std::ostream& os);
    [[nodiscard]] static std::vector<char> serialize(RequestStatsPerIteration const& state);
    [[nodiscard]] static size_t serializedSize(RequestStatsPerIteration const& state);
    [[nodiscard]] static std::vector<char> serialize(std::vector<RequestStatsPerIteration> const& requestStatsVec);
    [[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
        std::vector<char>& buffer);

    // String
    static std::string deserializeString(std::istream& is);

    // Bool
    static bool deserializeBool(std::istream& is);

    // ModelType
    static ModelType deserializeModelType(std::istream& is);
};

} // namespace tensorrt_llm::executor
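
Most types in the header above follow the same three-function pattern: a `deserializeX(std::istream&)` factory, a `serialize(X const&, std::ostream&)` writer, and a `serializedSize(X const&)` query. The header does not show the implementations, so the following is only a minimal sketch of how such a triplet might fit together. `ExampleOutput`, `ExampleSerialization`, `roundTrip`, and the wire layout (a 64-bit length prefix, the raw string bytes, then a one-byte flag) are all hypothetical and are not taken from TensorRT-LLM.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <istream>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical payload type used only for illustration; it is not a TensorRT-LLM type.
struct ExampleOutput
{
    std::string name;
    bool gatherContext = false;
};

// Hypothetical helper mirroring the deserialize / serialize / serializedSize triplet.
struct ExampleSerialization
{
    // Number of bytes serialize() writes for this value (assumed layout, see lead-in).
    static std::size_t serializedSize(ExampleOutput const& value)
    {
        return sizeof(std::uint64_t) + value.name.size() + sizeof(std::uint8_t);
    }

    // Writes the length prefix, the string bytes, and the flag byte to the stream.
    static void serialize(ExampleOutput const& value, std::ostream& os)
    {
        auto const length = static_cast<std::uint64_t>(value.name.size());
        os.write(reinterpret_cast<char const*>(&length), sizeof(length));
        os.write(value.name.data(), static_cast<std::streamsize>(value.name.size()));
        std::uint8_t const flag = value.gatherContext ? 1 : 0;
        os.write(reinterpret_cast<char const*>(&flag), sizeof(flag));
    }

    // Reads the fields back in exactly the order serialize() wrote them.
    static ExampleOutput deserializeExampleOutput(std::istream& is)
    {
        std::uint64_t length = 0;
        is.read(reinterpret_cast<char*>(&length), sizeof(length));
        ExampleOutput value;
        value.name.resize(static_cast<std::size_t>(length));
        is.read(&value.name[0], static_cast<std::streamsize>(length));
        std::uint8_t flag = 0;
        is.read(reinterpret_cast<char*>(&flag), sizeof(flag));
        value.gatherContext = (flag != 0);
        return value;
    }
};

// Serializes into an in-memory binary stream and reads the value back.
inline ExampleOutput roundTrip(ExampleOutput const& value)
{
    std::stringstream stream(std::ios::in | std::ios::out | std::ios::binary);
    ExampleSerialization::serialize(value, stream);
    // The size query is expected to agree with the bytes actually written.
    assert(stream.str().size() == ExampleSerialization::serializedSize(value));
    return ExampleSerialization::deserializeExampleOutput(stream);
}
```

Keeping `serializedSize` in lockstep with `serialize` lets a caller preallocate a buffer of the right size before writing, which is presumably why the header pairs the two for nearly every type. The sketch writes integers in host byte order, so it is only suitable when both ends share the same endianness.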