mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* refactor: Update ExecutorConfig to use AdditionalModelOutput type - Changed function signatures and member variables across multiple files to replace std::optional<std::vector<std::string>> with std::optional<std::vector<executor::AdditionalModelOutput>> to include gatherContext flag for each additional output. - Updated related serialization and deserialization methods to accommodate the new type. - Adjusted tests to reflect the changes in the output handling structure. This refactor enhances the flexibility and maintainability of the output configuration in the executor and batch manager components. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * refactor: Remove equality operator from TrtGptModelOptionalParams - Deleted the operator== implementation from TrtGptModelOptionalParams to simplify the class. - Updated the pybind11 bindings to remove the exposure of the equality operator to Python. This change streamlines the class definition and reduces unnecessary complexity in the bindings. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * refactor: Enhance copyAdditionalOutputs to utilize AdditionalModelOutput - Updated the copyAdditionalOutputs function to accept a vector of AdditionalModelOutput, allowing for the inclusion of the gatherContext flag. - Adjusted the logic to handle context and non-context outputs separately, improving the output handling mechanism. - Modified related unit tests to incorporate the new gatherContext parameter, ensuring comprehensive testing of the updated functionality. This refactor improves the flexibility and clarity of output management in the batch processing workflow. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * refactor: Introduce findOutputTensor utility function for output tensor retrieval - Added a new utility function, findOutputTensor, to encapsulate the logic for finding output tensors and checking their validity. 
- Refactored copyAdditionalOutputs to utilize findOutputTensor, reducing code duplication and improving clarity. - Enhanced error checking for additional context and generation output tensors. This change streamlines the output tensor retrieval process, enhancing maintainability and readability in the batch processing workflow. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * refactor: Check final indices of additional output tensors and update tests - Added checks to verify the final indices of additional output tensors for context and generation outputs. - Updated unit tests to verify the changes. - Add lastTokenIds input tensor to test engines. - Logits output depends on gatherContextLogits parameter. - Removed gatherContextOutputs parameter from the validate method in LlmRequest. - Context outputs do not depend on computeContextLogits parameter. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * fixup! refactor: Check final indices of additional output tensors and update tests Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * fixup! refactor: Update ExecutorConfig to use AdditionalModelOutput type Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * fixup! refactor: Remove equality operator from TrtGptModelOptionalParams Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * docs: Update executor.md Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> * chore: Clean up includes Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --------- Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2040 lines
87 KiB
C++
2040 lines
87 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "tensorrt_llm/executor/serialization.h"
|
|
#include "tensorrt_llm/executor/dataTransceiverState.h"
|
|
#include "tensorrt_llm/executor/executor.h"
|
|
#include "tensorrt_llm/executor/requestImpl.h"
|
|
#include "tensorrt_llm/executor/responseImpl.h"
|
|
#include "tensorrt_llm/executor/serializeUtils.h"
|
|
#include "tensorrt_llm/executor/types.h"
|
|
#include "tensorrt_llm/runtime/cudaStream.h"
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <type_traits>
|
|
|
|
namespace su = tensorrt_llm::executor::serialize_utils;
|
|
|
|
namespace tensorrt_llm::executor
|
|
{
|
|
|
|
// TimePoint
|
|
// Reads a time point that is stored in the stream as a uint64_t count of microseconds.
RequestPerfMetrics::TimePoint Serialization::deserializeTimePoint(std::istream& is)
{
    auto const microsSinceEpoch = su::deserialize<uint64_t>(is);
    // Turn the raw tick count back into a duration, then into a steady_clock time point.
    std::chrono::microseconds const sinceEpoch(microsSinceEpoch);
    return std::chrono::steady_clock::time_point(sinceEpoch);
}
|
|
|
|
// Writes a time point as the number of microseconds since the clock's epoch.
// The count is converted to uint64_t explicitly so the type written here is exactly the type that
// deserializeTimePoint reads back, instead of relying on std::chrono::microseconds::rep happening to be
// a 64-bit integer.
void Serialization::serialize(RequestPerfMetrics::TimePoint const& tp, std::ostream& os)
{
    auto const sinceEpoch = std::chrono::duration_cast<std::chrono::microseconds>(tp.time_since_epoch());
    su::serialize(static_cast<uint64_t>(sinceEpoch.count()), os);
}
|
|
|
|
// Returns the number of bytes a serialized time point occupies.
// deserializeTimePoint reads a time point back as a single uint64_t microsecond count, so the size reported
// here is the size of that wire type rather than sizeof(TimePoint), which describes the in-memory object and
// only coincidentally matches.
size_t Serialization::serializedSize(RequestPerfMetrics::TimePoint const& /*unused*/)
{
    return sizeof(uint64_t);
}
|
|
|
|
// RequestPerfMetrics
|
|
// Reads a RequestPerfMetrics object from the stream.
// Values are consumed in the same order in which serialize(RequestPerfMetrics const&, std::ostream&) writes
// them: timing metrics, KV cache metrics, speculative decoding metrics, then the iteration counters.
RequestPerfMetrics Serialization::deserializeRequestPerfMetrics(std::istream& is)
{
    using TimePoint = RequestPerfMetrics::TimePoint;
    using OptIteration = std::optional<IterationType>;

    // Timing metrics: six time points followed by the KV cache size.
    auto arrival = su::deserialize<TimePoint>(is);
    auto firstScheduled = su::deserialize<TimePoint>(is);
    auto firstToken = su::deserialize<TimePoint>(is);
    auto lastToken = su::deserialize<TimePoint>(is);
    auto transferStart = su::deserialize<TimePoint>(is);
    auto transferEnd = su::deserialize<TimePoint>(is);
    auto cacheSize = su::deserialize<size_t>(is);
    RequestPerfMetrics::TimingMetrics timing{
        arrival, firstScheduled, firstToken, lastToken, transferStart, transferEnd, cacheSize};

    // KV cache metrics.
    auto totalAllocatedBlocks = su::deserialize<SizeType32>(is);
    auto newAllocatedBlocks = su::deserialize<SizeType32>(is);
    auto reusedBlocks = su::deserialize<SizeType32>(is);
    auto missedBlocks = su::deserialize<SizeType32>(is);
    auto hitRate = su::deserialize<FloatType>(is);
    RequestPerfMetrics::KvCacheMetrics kvCache{
        totalAllocatedBlocks, newAllocatedBlocks, reusedBlocks, missedBlocks, hitRate};

    // Speculative decoding metrics.
    auto draftAcceptanceRate = su::deserialize<FloatType>(is);
    auto acceptedDraftTokens = su::deserialize<SizeType32>(is);
    auto draftTokens = su::deserialize<SizeType32>(is);
    RequestPerfMetrics::SpeculativeDecodingMetrics specDec{draftAcceptanceRate, acceptedDraftTokens, draftTokens};

    // Iteration counters.
    auto firstIteration = su::deserialize<OptIteration>(is);
    auto lastIteration = su::deserialize<OptIteration>(is);
    auto currentIteration = su::deserialize<OptIteration>(is);

    return RequestPerfMetrics{timing, kvCache, specDec, firstIteration, lastIteration, currentIteration};
}
|
|
|
|
// Writes every field of RequestPerfMetrics to the stream, in the order deserializeRequestPerfMetrics reads them.
void Serialization::serialize(RequestPerfMetrics const& metrics, std::ostream& os)
{
    // Small helper so each field is listed exactly once.
    auto const write = [&os](auto const& field) { su::serialize(field, os); };

    auto const& timing = metrics.timingMetrics;
    write(timing.arrivalTime);
    write(timing.firstScheduledTime);
    write(timing.firstTokenTime);
    write(timing.lastTokenTime);
    write(timing.kvCacheTransferStart);
    write(timing.kvCacheTransferEnd);
    write(timing.kvCacheSize);

    auto const& kvCache = metrics.kvCacheMetrics;
    write(kvCache.numTotalAllocatedBlocks);
    write(kvCache.numNewAllocatedBlocks);
    write(kvCache.numReusedBlocks);
    write(kvCache.numMissedBlocks);
    write(kvCache.kvCacheHitRate);

    auto const& specDec = metrics.speculativeDecoding;
    write(specDec.acceptanceRate);
    write(specDec.totalAcceptedDraftTokens);
    write(specDec.totalDraftTokens);

    write(metrics.firstIter);
    write(metrics.lastIter);
    write(metrics.iter);
}
|
|
|
|
// Returns the sum of the serialized sizes of all RequestPerfMetrics fields.
size_t Serialization::serializedSize(RequestPerfMetrics const& metrics)
{
    size_t total{0};
    auto const add = [&total](auto const& field) { total += su::serializedSize(field); };

    auto const& timing = metrics.timingMetrics;
    add(timing.arrivalTime);
    add(timing.firstScheduledTime);
    add(timing.firstTokenTime);
    add(timing.lastTokenTime);
    add(timing.kvCacheTransferStart);
    add(timing.kvCacheTransferEnd);
    add(timing.kvCacheSize);

    auto const& kvCache = metrics.kvCacheMetrics;
    add(kvCache.numTotalAllocatedBlocks);
    add(kvCache.numNewAllocatedBlocks);
    add(kvCache.numReusedBlocks);
    add(kvCache.numMissedBlocks);
    add(kvCache.kvCacheHitRate);

    auto const& specDec = metrics.speculativeDecoding;
    add(specDec.acceptanceRate);
    add(specDec.totalAcceptedDraftTokens);
    add(specDec.totalDraftTokens);

    add(metrics.firstIter);
    add(metrics.lastIter);
    add(metrics.iter);

    return total;
}
|
|
|
|
// SamplingConfig
|
|
// Reads a SamplingConfig from the stream.
// The reads happen in the order in which serialize(SamplingConfig const&, std::ostream&) writes the members.
SamplingConfig Serialization::deserializeSamplingConfig(std::istream& is)
{
    using OptSize = std::optional<SizeType32>;
    using OptFloat = std::optional<FloatType>;

    auto beamWidth = su::deserialize<SizeType32>(is);

    // Top-k / top-p sampling controls.
    auto topK = su::deserialize<OptSize>(is);
    auto topP = su::deserialize<OptFloat>(is);
    auto topPMin = su::deserialize<OptFloat>(is);
    auto topPResetIds = su::deserialize<std::optional<TokenIdType>>(is);
    auto topPDecay = su::deserialize<OptFloat>(is);

    auto seed = su::deserialize<std::optional<RandomSeedType>>(is);
    auto temperature = su::deserialize<OptFloat>(is);
    auto minTokens = su::deserialize<OptSize>(is);
    auto diversityRate = su::deserialize<OptFloat>(is);

    // Penalties.
    auto repetitionPenalty = su::deserialize<OptFloat>(is);
    auto presencePenalty = su::deserialize<OptFloat>(is);
    auto frequencyPenalty = su::deserialize<OptFloat>(is);
    auto lengthPenalty = su::deserialize<OptFloat>(is);

    auto earlyStopping = su::deserialize<OptSize>(is);
    auto noRepeatNgramSize = su::deserialize<OptSize>(is);
    auto numReturnSequences = su::deserialize<OptSize>(is);
    auto minP = su::deserialize<OptFloat>(is);
    auto beamWidthArray = su::deserialize<std::optional<std::vector<SizeType32>>>(is);

    return SamplingConfig{beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, seed, temperature, minTokens,
        diversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping,
        noRepeatNgramSize, numReturnSequences, minP, beamWidthArray};
}
|
|
|
|
// Writes all SamplingConfig members to the stream, in the order deserializeSamplingConfig reads them.
void Serialization::serialize(SamplingConfig const& config, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(config.mBeamWidth);
    write(config.mTopK);
    write(config.mTopP);
    write(config.mTopPMin);
    write(config.mTopPResetIds);
    write(config.mTopPDecay);
    write(config.mSeed);
    write(config.mTemperature);
    write(config.mMinTokens);
    write(config.mBeamSearchDiversityRate);
    write(config.mRepetitionPenalty);
    write(config.mPresencePenalty);
    write(config.mFrequencyPenalty);
    write(config.mLengthPenalty);
    write(config.mEarlyStopping);
    write(config.mNoRepeatNgramSize);
    write(config.mNumReturnSequences);
    write(config.mMinP);
    write(config.mBeamWidthArray);
}
|
|
|
|
// Returns the sum of the serialized sizes of all SamplingConfig members.
size_t Serialization::serializedSize(SamplingConfig const& config)
{
    size_t total{0};
    auto const add = [&total](auto const& member) { total += su::serializedSize(member); };

    add(config.mBeamWidth);
    add(config.mTopK);
    add(config.mTopP);
    add(config.mTopPMin);
    add(config.mTopPResetIds);
    add(config.mTopPDecay);
    add(config.mSeed);
    add(config.mTemperature);
    add(config.mMinTokens);
    add(config.mBeamSearchDiversityRate);
    add(config.mRepetitionPenalty);
    add(config.mPresencePenalty);
    add(config.mFrequencyPenalty);
    add(config.mLengthPenalty);
    add(config.mEarlyStopping);
    add(config.mNoRepeatNgramSize);
    add(config.mNumReturnSequences);
    add(config.mMinP);
    add(config.mBeamWidthArray);

    return total;
}
|
|
|
|
// OutputConfig
|
|
// Reads an OutputConfig from the stream: six boolean flags followed by the optional list of additional
// model outputs, in the order serialize(OutputConfig const&, std::ostream&) writes them.
OutputConfig Serialization::deserializeOutputConfig(std::istream& is)
{
    using OptAdditionalOutputs = std::optional<std::vector<AdditionalModelOutput>>;

    auto logProbs = su::deserialize<bool>(is);
    auto contextLogits = su::deserialize<bool>(is);
    auto generationLogits = su::deserialize<bool>(is);
    auto excludeInput = su::deserialize<bool>(is);
    auto encoderOutput = su::deserialize<bool>(is);
    auto perfMetrics = su::deserialize<bool>(is);
    auto extraOutputs = su::deserialize<OptAdditionalOutputs>(is);

    return OutputConfig{
        logProbs, contextLogits, generationLogits, excludeInput, encoderOutput, perfMetrics, extraOutputs};
}
|
|
|
|
// Writes the OutputConfig flags and the additional model outputs, in the order deserializeOutputConfig reads them.
void Serialization::serialize(OutputConfig const& config, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(config.returnLogProbs);
    write(config.returnContextLogits);
    write(config.returnGenerationLogits);
    write(config.excludeInputFromOutput);
    write(config.returnEncoderOutput);
    write(config.returnPerfMetrics);
    write(config.additionalModelOutputs);
}
|
|
|
|
// Returns the sum of the serialized sizes of all OutputConfig members.
size_t Serialization::serializedSize(OutputConfig const& config)
{
    size_t total{0};
    auto const add = [&total](auto const& member) { total += su::serializedSize(member); };

    add(config.returnLogProbs);
    add(config.returnContextLogits);
    add(config.returnGenerationLogits);
    add(config.excludeInputFromOutput);
    add(config.returnEncoderOutput);
    add(config.returnPerfMetrics);
    add(config.additionalModelOutputs);

    return total;
}
|
|
|
|
// AdditionalModelOutput
|
|
// Reads an AdditionalModelOutput: the output name followed by its gatherContext flag.
AdditionalModelOutput Serialization::deserializeAdditionalModelOutput(std::istream& is)
{
    auto outputName = su::deserialize<std::string>(is);
    auto gatherContextFlag = su::deserialize<bool>(is);
    return AdditionalModelOutput{outputName, gatherContextFlag};
}
|
|
|
|
// Writes an AdditionalModelOutput as its name followed by its gatherContext flag.
void Serialization::serialize(AdditionalModelOutput const& additionalModelOutput, std::ostream& os)
{
    auto const& output = additionalModelOutput;
    su::serialize(output.name, os);
    su::serialize(output.gatherContext, os);
}
|
|
|
|
// Returns the serialized size of the name plus the serialized size of the gatherContext flag.
size_t Serialization::serializedSize(AdditionalModelOutput const& additionalModelOutput)
{
    size_t const nameSize = su::serializedSize(additionalModelOutput.name);
    size_t const flagSize = su::serializedSize(additionalModelOutput.gatherContext);
    return nameSize + flagSize;
}
|
|
|
|
// ExternalDraftTokensConfig
|
|
// Reads an ExternalDraftTokensConfig: draft tokens, optional logits, optional acceptance threshold and the
// optional fast-logits flag, in the order the matching serialize overload writes them.
ExternalDraftTokensConfig Serialization::deserializeExternalDraftTokensConfig(std::istream& is)
{
    auto draftTokens = su::deserialize<VecTokens>(is);
    auto draftLogits = su::deserialize<std::optional<Tensor>>(is);
    auto threshold = su::deserialize<std::optional<FloatType>>(is);
    auto useFastLogits = su::deserialize<std::optional<bool>>(is);

    // Tokens and logits are handed over by move; the two small optionals are passed as-is.
    return ExternalDraftTokensConfig{std::move(draftTokens), std::move(draftLogits), threshold, useFastLogits};
}
|
|
|
|
// Writes the ExternalDraftTokensConfig members, in the order deserializeExternalDraftTokensConfig reads them.
void Serialization::serialize(ExternalDraftTokensConfig const& config, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(config.mTokens);
    write(config.mLogits);
    write(config.mAcceptanceThreshold);
    write(config.mFastLogits);
}
|
|
|
|
// Returns the sum of the serialized sizes of all ExternalDraftTokensConfig members.
size_t Serialization::serializedSize(ExternalDraftTokensConfig const& config)
{
    size_t const tokensSize = su::serializedSize(config.mTokens);
    size_t const logitsSize = su::serializedSize(config.mLogits);
    size_t const thresholdSize = su::serializedSize(config.mAcceptanceThreshold);
    size_t const fastLogitsSize = su::serializedSize(config.mFastLogits);
    return tokensSize + logitsSize + thresholdSize + fastLogitsSize;
}
|
|
|
|
// PromptTuningConfig
|
|
// Reads a PromptTuningConfig: the embedding table tensor followed by the optional input token extra ids.
PromptTuningConfig Serialization::deserializePromptTuningConfig(std::istream& is)
{
    auto embeddingTable = su::deserialize<Tensor>(is);
    auto extraIds = su::deserialize<std::optional<VecTokenExtraIds>>(is);
    return PromptTuningConfig{std::move(embeddingTable), std::move(extraIds)};
}
|
|
|
|
// Writes the embedding table followed by the input token extra ids.
void Serialization::serialize(PromptTuningConfig const& config, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(config.mEmbeddingTable);
    write(config.mInputTokenExtraIds);
}
|
|
|
|
// Returns the serialized size of the embedding table plus that of the input token extra ids.
size_t Serialization::serializedSize(PromptTuningConfig const& config)
{
    size_t const tableSize = su::serializedSize(config.mEmbeddingTable);
    size_t const extraIdsSize = su::serializedSize(config.mInputTokenExtraIds);
    return tableSize + extraIdsSize;
}
|
|
|
|
// MropeConfig
|
|
// Reads an MropeConfig: the rotary cos/sin tensor followed by the position deltas value.
MropeConfig Serialization::deserializeMropeConfig(std::istream& is)
{
    auto rotaryCosSin = su::deserialize<Tensor>(is);
    auto positionDeltas = su::deserialize<SizeType32>(is);
    return MropeConfig{std::move(rotaryCosSin), std::move(positionDeltas)};
}
|
|
|
|
// Writes the rotary cos/sin tensor followed by the position deltas value.
void Serialization::serialize(MropeConfig const& config, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(config.mMRopeRotaryCosSin);
    write(config.mMRopePositionDeltas);
}
|
|
|
|
// Returns the serialized size of the rotary cos/sin tensor plus that of the position deltas value.
size_t Serialization::serializedSize(MropeConfig const& config)
{
    size_t const cosSinSize = su::serializedSize(config.mMRopeRotaryCosSin);
    size_t const deltasSize = su::serializedSize(config.mMRopePositionDeltas);
    return cosSinSize + deltasSize;
}
|
|
|
|
// LoraConfig
|
|
// Reads a LoraConfig: the task id followed by the optional weights tensor and the optional config tensor.
LoraConfig Serialization::deserializeLoraConfig(std::istream& is)
{
    using OptTensor = std::optional<Tensor>;

    auto loraTaskId = su::deserialize<IdType>(is);
    auto loraWeights = su::deserialize<OptTensor>(is);
    auto loraConfigTensor = su::deserialize<OptTensor>(is);
    return LoraConfig{loraTaskId, std::move(loraWeights), std::move(loraConfigTensor)};
}
|
|
|
|
// Writes the task id, the weights and the config tensor, in the order deserializeLoraConfig reads them.
void Serialization::serialize(LoraConfig const& config, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(config.mTaskId);
    write(config.mWeights);
    write(config.mConfig);
}
|
|
|
|
// Returns the sum of the serialized sizes of the task id, the weights and the config tensor.
size_t Serialization::serializedSize(LoraConfig const& config)
{
    size_t const taskIdSize = su::serializedSize(config.mTaskId);
    size_t const weightsSize = su::serializedSize(config.mWeights);
    size_t const configSize = su::serializedSize(config.mConfig);
    return taskIdSize + weightsSize + configSize;
}
|
|
|
|
// SocketState
|
|
// Reads a SocketState: the port followed by the ip, each with the type of the corresponding member.
kv_cache::SocketState Serialization::deserializeSocketState(std::istream& is)
{
    using SocketState = kv_cache::SocketState;

    auto socketPort = su::deserialize<decltype(SocketState::mPort)>(is);
    auto socketIp = su::deserialize<decltype(SocketState::mIp)>(is);
    return kv_cache::SocketState{socketPort, std::move(socketIp)};
}
|
|
|
|
// Writes the port followed by the ip.
void Serialization::serialize(kv_cache::SocketState const& state, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(state.mPort);
    write(state.mIp);
}
|
|
|
|
// Returns the serialized size of the port plus that of the ip.
size_t Serialization::serializedSize(kv_cache::SocketState const& state)
{
    size_t const portSize = su::serializedSize(state.mPort);
    size_t const ipSize = su::serializedSize(state.mIp);
    return portSize + ipSize;
}
|
|
|
|
// CommState
|
|
// Reads a CommState: the self index, the index of the active variant alternative, and then the payload of
// that alternative. For any index other than the MPI or socket alternative nothing more is read and a
// default-constructed CommState is returned.
kv_cache::CommState Serialization::deserializeCommState(std::istream& is)
{
    using StateVariant = decltype(kv_cache::CommState::mState);

    // Positions of the alternatives inside the variant; the static_asserts keep them in sync with its definition.
    constexpr std::size_t mpiIdx{1};
    constexpr std::size_t socketIdx{2};
    static_assert(std::is_same_v<kv_cache::MpiState, std::variant_alternative_t<mpiIdx, StateVariant>>);
    static_assert(
        std::is_same_v<std::vector<kv_cache::SocketState>, std::variant_alternative_t<socketIdx, StateVariant>>);

    auto selfIdx = su::deserialize<decltype(kv_cache::CommState::mSelfIdx)>(is);
    auto const variantIdx = su::deserialize<std::size_t>(is);

    switch (variantIdx)
    {
    case mpiIdx:
    {
        auto ranks = su::deserialize<decltype(kv_cache::MpiState::mRanks)>(is);
        return kv_cache::CommState{std::move(ranks), selfIdx};
    }
    case socketIdx:
    {
        auto sockets = su::deserialize<std::vector<kv_cache::SocketState>>(is);
        return kv_cache::CommState{std::move(sockets), selfIdx};
    }
    default: return {};
    }
}
|
|
|
|
// Writes a CommState: the self index, the index of the active variant alternative, then the payload.
// Throws if the state holds neither an MPI state nor a socket state; the self index and the variant index have
// already been written to the stream at that point.
void Serialization::serialize(kv_cache::CommState const& state, std::ostream& os)
{
    su::serialize(state.mSelfIdx, os);
    su::serialize(state.mState.index(), os);

    if (state.isMpiState())
    {
        su::serialize(state.getMpiState().mRanks, os);
        return;
    }
    if (state.isSocketState())
    {
        su::serialize(state.getSocketState(), os);
        return;
    }
    TLLM_THROW("Unknown context phase state communication type.");
}
|
|
|
|
// Returns the serialized size of a CommState: self index, variant index and the active payload.
// Throws if the state holds neither an MPI state nor a socket state.
size_t Serialization::serializedSize(kv_cache::CommState const& state)
{
    // Part that is always present, independent of the active alternative.
    size_t const headerSize = su::serializedSize(state.mSelfIdx) + su::serializedSize(state.mState.index());

    size_t payloadSize{0};
    if (state.isMpiState())
    {
        payloadSize = su::serializedSize(state.getMpiState().mRanks);
    }
    else if (state.isSocketState())
    {
        payloadSize = su::serializedSize(state.getSocketState());
    }
    else
    {
        TLLM_THROW("Unknown context phase state communication type.");
    }
    return headerSize + payloadSize;
}
|
|
|
|
// CacheState
|
|
// Reads a CacheState from the stream.
// The stream order (model config, parallel config incl. attention-DP fields, data type, attention config)
// differs from the constructor's argument order, so every value is read into a local first.
kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
{
    using CacheState = kv_cache::CacheState;
    using ModelConfig = CacheState::ModelConfig;
    using ParallelConfig = CacheState::ParallelConfig;
    using AttentionConfig = CacheState::AttentionConfig;

    // Model configuration.
    auto kvHeadsPerLayer = su::deserialize<decltype(ModelConfig::mNbKvHeadsPerLayer)>(is);
    auto headSize = su::deserialize<decltype(ModelConfig::mSizePerHead)>(is);
    auto blockTokens = su::deserialize<decltype(ModelConfig::mTokensPerBlock)>(is);

    // Parallel configuration.
    auto tpSize = su::deserialize<decltype(ParallelConfig::mTensorParallelism)>(is);
    auto ppSize = su::deserialize<decltype(ParallelConfig::mPipelineParallelism)>(is);
    auto attentionDpEnabled = su::deserialize<decltype(ParallelConfig::mEnableAttenionDP)>(is);
    auto dpRank = su::deserialize<decltype(ParallelConfig::mDPrank)>(is);
    auto dpSize = su::deserialize<decltype(ParallelConfig::mDPsize)>(is);

    auto cacheDataType = su::deserialize<decltype(CacheState::mDataType)>(is);

    // Attention configuration.
    auto attnType = su::deserialize<decltype(AttentionConfig::mAttentionType)>(is);
    auto kvFactor = su::deserialize<decltype(AttentionConfig::mKvFactor)>(is);

    return CacheState{kvHeadsPerLayer, headSize, blockTokens, tpSize, ppSize, cacheDataType, attnType, kvFactor,
        attentionDpEnabled, dpRank, dpSize};
}
|
|
|
|
// Writes all CacheState fields to the stream, in the order deserializeCacheState reads them.
void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    auto const& model = state.mModelConfig;
    write(model.mNbKvHeadsPerLayer);
    write(model.mSizePerHead);
    write(model.mTokensPerBlock);

    auto const& parallel = state.mParallelConfig;
    write(parallel.mTensorParallelism);
    write(parallel.mPipelineParallelism);
    write(parallel.mEnableAttenionDP);
    write(parallel.mDPrank);
    write(parallel.mDPsize);

    write(state.mDataType);

    auto const& attention = state.mAttentionConfig;
    write(attention.mAttentionType);
    write(attention.mKvFactor);
}
|
|
|
|
// Returns the sum of the serialized sizes of all CacheState fields.
size_t Serialization::serializedSize(kv_cache::CacheState const& state)
{
    size_t total{0};
    auto const add = [&total](auto const& member) { total += su::serializedSize(member); };

    auto const& model = state.mModelConfig;
    add(model.mNbKvHeadsPerLayer);
    add(model.mSizePerHead);
    add(model.mTokensPerBlock);

    auto const& parallel = state.mParallelConfig;
    add(parallel.mTensorParallelism);
    add(parallel.mPipelineParallelism);
    add(parallel.mEnableAttenionDP);
    add(parallel.mDPrank);
    add(parallel.mDPsize);

    add(state.mDataType);

    auto const& attention = state.mAttentionConfig;
    add(attention.mAttentionType);
    add(attention.mKvFactor);

    return total;
}
|
|
|
|
// DataTransceiverState
|
|
|
|
// Convenience overload: wraps the byte buffer in a stream and forwards to the std::istream overload.
DataTransceiverState Serialization::deserializeDataTransceiverState(std::vector<char>& buffer)
{
    su::VectorWrapBuf<char> wrappedBuffer(buffer);
    std::istream input(&wrappedBuffer);
    return deserializeDataTransceiverState(input);
}
|
|
|
|
// Reads a DataTransceiverState: the comm state followed by the cache state.
// Each part is only applied to the result when the deserialized value is engaged; otherwise the
// default-constructed state keeps its initial value for that part.
DataTransceiverState Serialization::deserializeDataTransceiverState(std::istream& is)
{
    DataTransceiverState result;

    if (auto commState = su::deserialize<decltype(DataTransceiverState::mCommState)>(is); commState)
    {
        result.setCommState(std::move(commState).value());
    }

    if (auto cacheState = su::deserialize<decltype(DataTransceiverState::mCacheState)>(is); cacheState)
    {
        result.setCacheState(std::move(cacheState).value());
    }

    return result;
}
|
|
|
|
// Writes the comm state followed by the cache state, in the order deserializeDataTransceiverState reads them.
void Serialization::serialize(DataTransceiverState const& state, std::ostream& os)
{
    auto const write = [&os](auto const& member) { su::serialize(member, os); };

    write(state.mCommState);
    write(state.mCacheState);
}
|
|
|
|
std::vector<char> Serialization::serialize(DataTransceiverState const& state)
|
|
{
|
|
std::vector<char> buffer(serializedSize(state));
|
|
|
|
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
|
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
|
std::ostream os(&strbuf);
|
|
serialize(state, os);
|
|
return buffer;
|
|
}
|
|
|
|
// Returns the serialized size of the comm state plus that of the cache state.
size_t Serialization::serializedSize(DataTransceiverState const& state)
{
    size_t const commStateSize = su::serializedSize(state.mCommState);
    size_t const cacheStateSize = su::serializedSize(state.mCacheState);
    return commStateSize + cacheStateSize;
}
|
|
|
|
// ContextPhaseParams
|
|
// Reads ContextPhaseParams: request id, first generated tokens, draft tokens, a flag telling whether a
// DataTransceiverState follows, and - only if the flag is set - that state.
ContextPhaseParams Serialization::deserializeContextPhaseParams(std::istream& is)
{
    auto requestId = su::deserialize<decltype(ContextPhaseParams::mReqId)>(is);
    auto firstTokens = su::deserialize<decltype(ContextPhaseParams::mFirstGenTokens)>(is);
    auto draft = su::deserialize<decltype(ContextPhaseParams::mDraftTokens)>(is);

    auto const stateFollows = su::deserialize<bool>(is);
    if (!stateFollows)
    {
        return ContextPhaseParams{std::move(firstTokens), requestId, nullptr, std::move(draft)};
    }

    auto transceiverState = std::make_unique<DataTransceiverState>();
    *transceiverState = deserializeDataTransceiverState(is);
    // The raw pointer is released to the ContextPhaseParams constructor.
    // NOTE(review): presumably the constructor takes ownership of it - confirm against ContextPhaseParams.
    return ContextPhaseParams{std::move(firstTokens), requestId, transceiverState.release(), std::move(draft)};
}
|
|
|
|
// Writes ContextPhaseParams: request id, first generated tokens, draft tokens, a presence flag for the
// state, and the DataTransceiverState itself when one is attached.
void Serialization::serialize(ContextPhaseParams const& contextPhaseParams, std::ostream& os)
{
    auto const& params = contextPhaseParams;

    su::serialize(params.mReqId, os);
    su::serialize(params.mFirstGenTokens, os);
    su::serialize(params.mDraftTokens, os);

    // The flag lets the reader know whether a state payload follows.
    su::serialize(static_cast<bool>(params.mState), os);
    if (!params.mState)
    {
        return;
    }
    serialize(*static_cast<DataTransceiverState*>(params.mState.get()), os);
}
|
|
|
|
// Returns the serialized size of ContextPhaseParams: the three always-present fields, the presence flag,
// and the DataTransceiverState when one is attached.
size_t Serialization::serializedSize(ContextPhaseParams const& contextPhaseParams)
{
    auto const& params = contextPhaseParams;

    size_t const fixedPart = su::serializedSize(params.mReqId) + su::serializedSize(params.mFirstGenTokens)
        + su::serializedSize(params.mDraftTokens) + su::serializedSize(bool{});

    if (!params.mState)
    {
        return fixedPart;
    }
    return fixedPart + serializedSize(*static_cast<DataTransceiverState*>(params.mState.get()));
}
|
|
|
|
// Request
|
|
// Reads a Request from the stream.
// The reads below consume the stream sequentially, so their order defines the expected layout and must not
// be changed. Every value is pulled into a named local before the Request is constructed.
// Serialization of Request with logitsPostProcessor is currently not supported, which is why std::nullopt is
// passed in its place. Dynamic logitsPostProcessor only supported with replicate=false or no tensor parallelism.
Request Serialization::deserializeRequest(std::istream& is)
{
    using OptSize = std::optional<SizeType32>;
    using OptTensor = std::optional<Tensor>;
    using OptWordList = std::optional<std::list<VecTokens>>;

    // Core request fields.
    auto inputTokenIds = su::deserialize<VecTokens>(is);
    auto maxNewTokens = su::deserialize<SizeType32>(is);
    auto streaming = su::deserialize<bool>(is);
    auto samplingConfig = su::deserialize<SamplingConfig>(is);
    auto outputConfig = su::deserialize<OutputConfig>(is);
    auto endId = su::deserialize<OptSize>(is);
    auto padId = su::deserialize<OptSize>(is);
    auto positionIds = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
    auto badWords = su::deserialize<OptWordList>(is);
    auto stopWords = su::deserialize<OptWordList>(is);
    auto embeddingBias = su::deserialize<OptTensor>(is);

    // Optional feature configurations.
    auto externalDraftTokensConfig = su::deserialize<std::optional<ExternalDraftTokensConfig>>(is);
    auto pTuningConfig = su::deserialize<std::optional<PromptTuningConfig>>(is);
    auto mRopeConfig = su::deserialize<std::optional<MropeConfig>>(is);
    auto loraConfig = su::deserialize<std::optional<LoraConfig>>(is);
    auto lookaheadConfig = su::deserialize<std::optional<LookaheadDecodingConfig>>(is);
    auto kvCacheRetentionConfig = su::deserialize<std::optional<KvCacheRetentionConfig>>(is);
    auto logitsPostProcessorName = su::deserialize<std::optional<std::string>>(is);

    // Scheduling and encoder related fields.
    auto encoderInputTokenIds = su::deserialize<std::optional<VecTokens>>(is);
    auto clientId = su::deserialize<std::optional<IdType>>(is);
    auto returnAllGeneratedTokens = su::deserialize<bool>(is);
    auto priority = su::deserialize<executor::PriorityType>(is);
    auto requestType = su::deserialize<executor::RequestType>(is);
    auto contextPhaseParams = su::deserialize<std::optional<ContextPhaseParams>>(is);
    auto encoderInputFeatures = su::deserialize<OptTensor>(is);
    auto encoderOutputLength = su::deserialize<OptSize>(is);
    auto crossAttentionMask = su::deserialize<OptTensor>(is);
    auto numReturnSequences = su::deserialize<SizeType32>(is);
    auto eagleConfig = su::deserialize<std::optional<EagleConfig>>(is);
    auto skipCrossAttnBlocks = su::deserialize<OptTensor>(is);
    auto guidedDecodingParams = su::deserialize<std::optional<GuidedDecodingParams>>(is);
    auto languageAdapterUid = su::deserialize<OptSize>(is);

    // The allotted time travels as a raw millisecond count and is turned back into a duration here.
    std::optional<std::chrono::milliseconds> allottedTimeMs;
    if (auto const allottedTimeCount = su::deserialize<std::optional<std::chrono::milliseconds::rep>>(is))
    {
        allottedTimeMs = std::chrono::milliseconds(*allottedTimeCount);
    }

    // 34 arguments: the 33 deserialized values plus std::nullopt for the logitsPostProcessor.
    return Request(std::move(inputTokenIds), maxNewTokens, streaming, samplingConfig, outputConfig, endId, padId,
        std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
        std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(mRopeConfig),
        std::move(loraConfig), lookaheadConfig, std::move(kvCacheRetentionConfig),
        std::move(logitsPostProcessorName), std::nullopt, std::move(encoderInputTokenIds), clientId,
        returnAllGeneratedTokens, priority, requestType, std::move(contextPhaseParams),
        std::move(encoderInputFeatures), encoderOutputLength, std::move(crossAttentionMask), numReturnSequences,
        std::move(eagleConfig), std::move(skipCrossAttnBlocks), std::move(guidedDecodingParams),
        languageAdapterUid, allottedTimeMs);
}
|
|
|
|
// Writes `request` to `os` by forwarding to the request's implementation object.
void Serialization::serialize(Request const& request, std::ostream& os)
{
    auto& impl = *request.mImpl;
    impl.serialize(os);
}
|
|
|
|
// Returns the serialized size of `request` as reported by its implementation object.
size_t Serialization::serializedSize(Request const& request)
{
    auto& impl = *request.mImpl;
    return impl.serializedSize();
}
|
|
|
|
// Tensor
|
|
// Rebuilds a Tensor from the binary layout produced by serialize(Tensor const&, std::ostream&):
//   data type | number of dims | dims | memory type | payload size in bytes | payload
//
// The payload is always read into host-accessible storage; a tensor whose recorded memory type is
// kGPU is staged through a pinned tensor and then copied to the device.
//
// Throws (through TLLM_CHECK / TLLM_CHECK_WITH_INFO / TLLM_THROW) when the number of dims is not
// below MAX_DIMS, when the recorded payload size does not fit the storage implied by data type and
// shape, or when the memory type is kUNKNOWN or otherwise unhandled.
Tensor Serialization::deserializeTensor(std::istream& is)
{
    // DataType
    DataType dataType{};
    is.read(reinterpret_cast<char*>(&dataType), sizeof(dataType));

    // Shape
    size_t shapeSize{0};
    is.read(reinterpret_cast<char*>(&shapeSize), sizeof(size_t));
    static constexpr int32_t MAX_DIMS{8};
    TLLM_CHECK(shapeSize < MAX_DIMS);

    Shape::DimType64 dims[MAX_DIMS];
    is.read(reinterpret_cast<char*>(&dims[0]), shapeSize * sizeof(Shape::DimType64));
    Shape shape(&dims[0], shapeSize);

    // Memory Type
    MemoryType memoryType{};
    is.read(reinterpret_cast<char*>(&memoryType), sizeof(memoryType));

    // Size in bytes
    size_t sizeInBytes{0};
    is.read(reinterpret_cast<char*>(&sizeInBytes), sizeof(size_t));

    // Allocate the storage the payload is read into.
    Tensor hostTensor;
    switch (memoryType)
    {
    case MemoryType::kCPU:
    {
        hostTensor = Tensor::cpu(dataType, shape);
        break;
    }
    case MemoryType::kCPU_PINNED:
    {
        hostTensor = Tensor::pinned(dataType, shape);
        break;
    }
    case MemoryType::kUVM:
    {
        hostTensor = Tensor::managed(dataType, shape);
        break;
    }
    case MemoryType::kGPU:
    {
        // TODO: Eventually we might want to support serialization/deserialization in GPU memory
        // Until then create a pinned tensor and move it to the GPU
        hostTensor = Tensor::pinned(dataType, shape);
        break;
    }
    case MemoryType::kUNKNOWN:
    {
        TLLM_THROW("Cannot deserialize tensor with UNKNOWN type.");
        break;
    }
    default:
    {
        TLLM_THROW("Memory type not supported in deserializeTensor.");
        break;
    }
    }

    // The payload size comes from the stream, the buffer size from data type and shape. Refuse to
    // read more bytes than the buffer can hold instead of writing past its end.
    TLLM_CHECK_WITH_INFO(sizeInBytes <= hostTensor.getSizeInBytes(),
        "Serialized tensor payload is larger than the storage implied by its data type and shape.");
    is.read(reinterpret_cast<char*>(hostTensor.getData()), static_cast<std::streamsize>(sizeInBytes));

    if (memoryType != MemoryType::kGPU)
    {
        return hostTensor;
    }

    // Wait for the copy to finish before the staging tensor goes out of scope.
    auto stream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
    Tensor deviceTensor = hostTensor.copyToGpu(stream);
    stream->synchronize();
    return deviceTensor;
}
|
|
|
|
// Writes `tensor` to `os` as: data type, number of dims, dims, memory type, payload size in bytes,
// followed by the payload bytes.
// Host-accessible tensors (kCPU, kCPU_PINNED, kUVM) are written directly; kGPU tensors are first
// copied to a pinned tensor. kUNKNOWN throws after the header has been written. Any other memory
// type writes the header only.
void Serialization::serialize(Tensor const& tensor, std::ostream& os)
{
    // Dumps `numBytes` raw bytes starting at `src` into the output stream.
    auto const writeRaw = [&os](void const* src, std::size_t numBytes)
    { os.write(reinterpret_cast<char const*>(src), static_cast<std::streamsize>(numBytes)); };

    // Header: data type
    auto dataType = tensor.getDataType();
    writeRaw(&dataType, sizeof(dataType));

    // Header: shape (rank first, then the dims themselves)
    auto shape = tensor.getShape();
    auto shapeSize = shape.size();
    writeRaw(&shapeSize, sizeof(shapeSize));
    writeRaw(&shape[0], shapeSize * sizeof(Shape::DimType64));

    // Header: memory type
    auto memoryType = tensor.getMemoryType();
    writeRaw(&memoryType, sizeof(memoryType));

    // Header: payload size
    std::size_t sizeInBytes = tensor.getSizeInBytes();
    writeRaw(&sizeInBytes, sizeof(sizeInBytes));

    // Payload
    switch (memoryType)
    {
    case MemoryType::kCPU:
    case MemoryType::kCPU_PINNED:
    case MemoryType::kUVM:
    {
        writeRaw(tensor.getData(), sizeInBytes);
        break;
    }
    case MemoryType::kGPU:
    {
        // Need special treatment for GPU type: stage through pinned memory and wait for the copy.
        auto stream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
        auto pinnedTensor = tensor.copyToPinned(stream);
        stream->synchronize();
        writeRaw(pinnedTensor.getData(), sizeInBytes);
        break;
    }
    case MemoryType::kUNKNOWN:
    {
        TLLM_THROW("Memory type unknown when serializing tensor");
        break;
    }
    default:
    {
        // No payload is written for any other memory type.
        break;
    }
    }
}
|
|
|
|
// Number of bytes serialize(Tensor const&, std::ostream&) needs for `tensor`: the fixed header
// fields, one entry per dim, and the payload. Checks that the tensor has at least one dim.
size_t Serialization::serializedSize(Tensor const& tensor)
{
    auto const dims = tensor.getShape();
    auto const numDims = dims.size();
    TLLM_CHECK(numDims > 0);

    auto const memoryType = tensor.getMemoryType();

    size_t const headerBytes = sizeof(tensor.getDataType()) // datatype
        + sizeof(decltype(numDims))                          // number of dims
        + numDims * sizeof(decltype(dims[0]))                // dims
        + sizeof(memoryType)                                 // memory type
        + sizeof(size_t);                                    // size in bytes

    return headerBytes + tensor.getSizeInBytes();
}
|
|
|
|
// SpeculativeDecodingFastLogitsInfo
|
|
// Reads a SpeculativeDecodingFastLogitsInfo: a uint64_t draft request id followed by an int32_t
// draft participant id.
SpeculativeDecodingFastLogitsInfo Serialization::deserializeSpecDecFastLogitsInfo(std::istream& is)
{
    auto const requestId = su::deserialize<uint64_t>(is);
    auto const participantId = su::deserialize<int32_t>(is);
    return SpeculativeDecodingFastLogitsInfo{requestId, participantId};
}
|
|
|
|
// Writes the draft request id, then the draft participant id.
void Serialization::serialize(SpeculativeDecodingFastLogitsInfo const& info, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(info.draftRequestId);
    emit(info.draftParticipantId);
}
|
|
|
|
// Sum of the serialized sizes of both fields of `info`.
size_t Serialization::serializedSize(SpeculativeDecodingFastLogitsInfo const& info)
{
    return su::serializedSize(info.draftRequestId) + su::serializedSize(info.draftParticipantId);
}
|
|
|
|
// Result
|
|
// Reads a Result field by field. The read order must stay in sync with
// serialize(Result const&, std::ostream&).
Result Serialization::deserializeResult(std::istream& is)
{
    Result decoded{};

    // Tokens and log probabilities
    decoded.isFinal = su::deserialize<bool>(is);
    decoded.outputTokenIds = su::deserialize<BeamTokens>(is);
    decoded.cumLogProbs = su::deserialize<std::optional<VecLogProbs>>(is);
    decoded.logProbs = su::deserialize<std::optional<std::vector<VecLogProbs>>>(is);

    // Optional tensors and speculative decoding info
    decoded.contextLogits = su::deserialize<std::optional<Tensor>>(is);
    decoded.generationLogits = su::deserialize<std::optional<Tensor>>(is);
    decoded.specDecFastLogitsInfo = su::deserialize<std::optional<SpeculativeDecodingFastLogitsInfo>>(is);
    decoded.encoderOutput = su::deserialize<std::optional<Tensor>>(is);

    // Bookkeeping
    decoded.finishReasons = su::deserialize<std::vector<FinishReason>>(is);
    decoded.contextPhaseParams = su::deserialize<std::optional<ContextPhaseParams>>(is);
    decoded.decodingIter = su::deserialize<SizeType32>(is);
    decoded.sequenceIndex = su::deserialize<SizeType32>(is);
    decoded.isSequenceFinal = su::deserialize<bool>(is);
    decoded.requestPerfMetrics = su::deserialize<std::optional<RequestPerfMetrics>>(is);
    decoded.additionalOutputs = su::deserialize<std::vector<AdditionalOutput>>(is);

    return decoded;
}
|
|
|
|
// Writes every field of `result` to `os`, in the order deserializeResult reads them.
void Serialization::serialize(Result const& result, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(result.isFinal);
    emit(result.outputTokenIds);
    emit(result.cumLogProbs);
    emit(result.logProbs);
    emit(result.contextLogits);
    emit(result.generationLogits);
    emit(result.specDecFastLogitsInfo);
    emit(result.encoderOutput);
    emit(result.finishReasons);
    emit(result.contextPhaseParams);
    emit(result.decodingIter);
    emit(result.sequenceIndex);
    emit(result.isSequenceFinal);
    emit(result.requestPerfMetrics);
    emit(result.additionalOutputs);
}
|
|
|
|
// Sum of the serialized sizes of all fields of `result`, listed in the order they are written
// by serialize(Result const&, std::ostream&).
size_t Serialization::serializedSize(Result const& result)
{
    size_t numBytes = 0;
    auto const add = [&numBytes](auto const& field) { numBytes += su::serializedSize(field); };

    add(result.isFinal);
    add(result.outputTokenIds);
    add(result.cumLogProbs);
    add(result.logProbs);
    add(result.contextLogits);
    add(result.generationLogits);
    add(result.specDecFastLogitsInfo);
    add(result.encoderOutput);
    add(result.finishReasons);
    add(result.contextPhaseParams);
    add(result.decodingIter);
    add(result.sequenceIndex);
    add(result.isSequenceFinal);
    add(result.requestPerfMetrics);
    add(result.additionalOutputs);

    return numBytes;
}
|
|
|
|
// Response
|
|
// Reads a request id followed by a variant holding either an error message or a Result, and
// builds the matching Response.
Response Serialization::deserializeResponse(std::istream& is)
{
    auto id = su::deserialize<IdType>(is);
    auto payload = su::deserialize<std::variant<std::string, Result>>(is);

    if (std::holds_alternative<std::string>(payload))
    {
        // Error response
        return Response{id, std::get<std::string>(payload)};
    }
    return Response{id, std::get<Result>(payload)};
}
|
|
|
|
// Writes the request id and then the error-or-result variant of `response`.
void Serialization::serialize(Response const& response, std::ostream& os)
{
    auto const& impl = *response.mImpl;
    su::serialize(impl.mRequestId, os);
    su::serialize(impl.mErrOrResult, os);
}
|
|
|
|
// Serialized size of the request id plus that of the error-or-result variant.
size_t Serialization::serializedSize(Response const& response)
{
    auto const& impl = *response.mImpl;
    return su::serializedSize(impl.mRequestId) + su::serializedSize(impl.mErrOrResult);
}
|
|
|
|
// Vector of responses
|
|
std::vector<Response> Serialization::deserializeResponses(std::vector<char>& buffer)
|
|
{
|
|
std::vector<Response> responses;
|
|
su::VectorWrapBuf<char> strbuf(buffer);
|
|
std::istream is(&strbuf);
|
|
auto numResponses = su::deserialize<std::size_t>(is);
|
|
for (std::size_t resp = 0; resp < numResponses; ++resp)
|
|
{
|
|
responses.emplace_back(Serialization::deserializeResponse(is));
|
|
}
|
|
return responses;
|
|
}
|
|
|
|
std::vector<char> Serialization::serialize(std::vector<Response> const& responses)
|
|
{
|
|
// Compute the size of serialized buffer
|
|
size_t totalSize = 0;
|
|
totalSize += sizeof(size_t);
|
|
for (auto const& response : responses)
|
|
{
|
|
totalSize += su::serializedSize(response);
|
|
}
|
|
|
|
std::vector<char> buffer(totalSize);
|
|
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
|
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
|
std::ostream os(&strbuf);
|
|
|
|
su::serialize(responses.size(), os);
|
|
for (auto const& response : responses)
|
|
{
|
|
su::serialize(response, os);
|
|
}
|
|
return buffer;
|
|
}
|
|
|
|
// Reads an AdditionalOutput: its name string first, then its tensor.
AdditionalOutput Serialization::deserializeAdditionalOutput(std::istream& is)
{
    auto name = Serialization::deserializeString(is);
    auto output = Serialization::deserializeTensor(is);
    return {std::move(name), std::move(output)};
}
|
|
|
|
// Writes the output's name followed by its tensor.
void Serialization::serialize(AdditionalOutput const& additionalOutput, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(additionalOutput.name);
    emit(additionalOutput.output);
}
|
|
|
|
// Serialized size of an AdditionalOutput: name plus tensor.
// NOTE(review): unlike the neighbouring overloads this is defined as a free function rather than
// as Serialization::serializedSize - confirm against the header that this is intended.
size_t serializedSize(AdditionalOutput const& additionalOutput)
{
    return su::serializedSize(additionalOutput.name) + su::serializedSize(additionalOutput.output);
}
|
|
|
|
// ExecutorConfig
|
|
// Reads an ExecutorConfig. Each field is decoded with the type of the corresponding getter, in
// the order serialize(ExecutorConfig const&, std::ostream&) writes them. The logits post
// processor config is not part of the stream and is passed to the constructor as std::nullopt.
ExecutorConfig Serialization::deserializeExecutorConfig(std::istream& is)
{
    // Batch / token limits
    auto const maxBeamWidth = su::deserializeWithGetterType<decltype(&ExecutorConfig::getMaxBeamWidth)>(is);
    auto const maxBatchSize = su::deserializeWithGetterType<decltype(&ExecutorConfig::getMaxBatchSize)>(is);
    auto const maxNumTokens = su::deserializeWithGetterType<decltype(&ExecutorConfig::getMaxNumTokens)>(is);

    // Scheduling and KV cache
    auto const schedulerConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getSchedulerConfig)>(is);
    auto const kvCacheConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getKvCacheConfig)>(is);
    auto const enableChunkedContext
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getEnableChunkedContext)>(is);
    auto const normalizeLogProbs = su::deserializeWithGetterType<decltype(&ExecutorConfig::getNormalizeLogProbs)>(is);

    // Statistics
    auto const iterStatsMaxIterations
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getIterStatsMaxIterations)>(is);
    auto const requestStatsMaxIterations
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getRequestStatsMaxIterations)>(is);

    // Batching, parallelism, caches, decoding
    auto const batchingType = su::deserializeWithGetterType<decltype(&ExecutorConfig::getBatchingType)>(is);
    auto const parallelConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getParallelConfig)>(is);
    auto const peftCacheConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getPeftCacheConfig)>(is);
    auto const decodingConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getDecodingConfig)>(is);
    auto const gpuWeightsPercent = su::deserializeWithGetterType<decltype(&ExecutorConfig::getGpuWeightsPercent)>(is);
    auto const maxQueueSize = su::deserializeWithGetterType<decltype(&ExecutorConfig::getMaxQueueSize)>(is);
    auto const extendedRuntimePerfKnobConfig
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getExtendedRuntimePerfKnobConfig)>(is);
    auto const debugConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getDebugConfig)>(is);
    auto const recvPollPeriodMs = su::deserializeWithGetterType<decltype(&ExecutorConfig::getRecvPollPeriodMs)>(is);
    auto const maxSeqIdleMicroseconds
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getMaxSeqIdleMicroseconds)>(is);
    auto const specDecConfig = su::deserializeWithGetterType<decltype(&ExecutorConfig::getSpecDecConfig)>(is);
    auto const guidedDecodingConfig
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getGuidedDecodingConfig)>(is);
    auto const additionalModelOutputs
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getAdditionalModelOutputs)>(is);
    auto const gatherGenerationLogits
        = su::deserializeWithGetterType<decltype(&ExecutorConfig::getGatherGenerationLogits)>(is);

    // Constructor argument order differs from the wire order.
    return ExecutorConfig{maxBeamWidth, schedulerConfig, kvCacheConfig, enableChunkedContext, normalizeLogProbs,
        iterStatsMaxIterations, requestStatsMaxIterations, batchingType, maxBatchSize, maxNumTokens, parallelConfig,
        peftCacheConfig, std::nullopt, decodingConfig, gpuWeightsPercent, maxQueueSize, extendedRuntimePerfKnobConfig,
        debugConfig, recvPollPeriodMs, maxSeqIdleMicroseconds, specDecConfig, guidedDecodingConfig,
        additionalModelOutputs, gatherGenerationLogits};
}
|
|
|
|
// Sum of the serialized sizes of all ExecutorConfig fields that are written by
// serialize(ExecutorConfig const&, std::ostream&). Rejects configs that carry a logits post
// processor config.
size_t Serialization::serializedSize(ExecutorConfig const& executorConfig)
{
    TLLM_CHECK_WITH_INFO(!executorConfig.getLogitsPostProcessorConfig().has_value(),
        "Serialization of executorConfig with logitsPostProcessor is currently not supported.");

    size_t numBytes = 0;
    auto const add = [&numBytes](auto const& field) { numBytes += su::serializedSize(field); };

    add(executorConfig.getMaxBeamWidth());
    add(executorConfig.getMaxBatchSize());
    add(executorConfig.getMaxNumTokens());
    add(executorConfig.getSchedulerConfig());
    add(executorConfig.getKvCacheConfig());
    add(executorConfig.getEnableChunkedContext());
    add(executorConfig.getNormalizeLogProbs());
    add(executorConfig.getIterStatsMaxIterations());
    add(executorConfig.getRequestStatsMaxIterations());
    add(executorConfig.getBatchingType());
    add(executorConfig.getParallelConfig());
    add(executorConfig.getPeftCacheConfig());
    add(executorConfig.getDecodingConfig());
    add(executorConfig.getGpuWeightsPercent());
    add(executorConfig.getMaxQueueSize());
    add(executorConfig.getExtendedRuntimePerfKnobConfig());
    add(executorConfig.getDebugConfig());
    add(executorConfig.getRecvPollPeriodMs());
    add(executorConfig.getMaxSeqIdleMicroseconds());
    add(executorConfig.getSpecDecConfig());
    add(executorConfig.getGuidedDecodingConfig());
    add(executorConfig.getAdditionalModelOutputs());
    add(executorConfig.getGatherGenerationLogits());

    return numBytes;
}
|
|
|
|
// Writes `executorConfig` to `os` in the order deserializeExecutorConfig reads it. Rejects
// configs that carry a logits post processor config, which is not serialized.
void Serialization::serialize(ExecutorConfig const& executorConfig, std::ostream& os)
{
    TLLM_CHECK_WITH_INFO(!executorConfig.getLogitsPostProcessorConfig().has_value(),
        "Serialization of executorConfig with logitsPostProcessor is currently not supported.");

    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(executorConfig.getMaxBeamWidth());
    emit(executorConfig.getMaxBatchSize());
    emit(executorConfig.getMaxNumTokens());
    emit(executorConfig.getSchedulerConfig());
    emit(executorConfig.getKvCacheConfig());
    emit(executorConfig.getEnableChunkedContext());
    emit(executorConfig.getNormalizeLogProbs());
    emit(executorConfig.getIterStatsMaxIterations());
    emit(executorConfig.getRequestStatsMaxIterations());
    emit(executorConfig.getBatchingType());
    emit(executorConfig.getParallelConfig());
    emit(executorConfig.getPeftCacheConfig());
    emit(executorConfig.getDecodingConfig());
    emit(executorConfig.getGpuWeightsPercent());
    emit(executorConfig.getMaxQueueSize());
    emit(executorConfig.getExtendedRuntimePerfKnobConfig());
    emit(executorConfig.getDebugConfig());
    emit(executorConfig.getRecvPollPeriodMs());
    emit(executorConfig.getMaxSeqIdleMicroseconds());
    emit(executorConfig.getSpecDecConfig());
    emit(executorConfig.getGuidedDecodingConfig());
    emit(executorConfig.getAdditionalModelOutputs());
    emit(executorConfig.getGatherGenerationLogits());
}
|
|
|
|
// KvCacheConfig
|
|
// Reads a KvCacheConfig; the read order matches serialize(KvCacheConfig const&, std::ostream&).
KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is)
{
    auto const enableBlockReuse = su::deserialize<bool>(is);
    auto const maxTokens = su::deserialize<std::optional<SizeType32>>(is);
    auto const maxAttentionWindowVec = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
    auto const sinkTokenLength = su::deserialize<std::optional<SizeType32>>(is);
    auto const freeGpuMemoryFraction = su::deserialize<std::optional<FloatType>>(is);
    auto const hostCacheSize = su::deserialize<std::optional<size_t>>(is);
    auto const onboardBlocks = su::deserialize<bool>(is);
    auto const crossKvCacheFraction = su::deserialize<std::optional<FloatType>>(is);
    auto const secondaryOffloadMinPriority = su::deserialize<std::optional<executor::RetentionPriority>>(is);
    auto const eventBufferMaxSize = su::deserialize<size_t>(is);

    return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction,
        hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize};
}
|
|
|
|
// Writes `kvCacheConfig` to `os` in the order deserializeKvCacheConfig reads it.
void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(kvCacheConfig.getEnableBlockReuse());
    emit(kvCacheConfig.getMaxTokens());
    emit(kvCacheConfig.getMaxAttentionWindowVec());
    emit(kvCacheConfig.getSinkTokenLength());
    emit(kvCacheConfig.getFreeGpuMemoryFraction());
    emit(kvCacheConfig.getHostCacheSize());
    emit(kvCacheConfig.getOnboardBlocks());
    emit(kvCacheConfig.getCrossKvCacheFraction());
    emit(kvCacheConfig.getSecondaryOffloadMinPriority());
    emit(kvCacheConfig.getEventBufferMaxSize());
}
|
|
|
|
// Sum of the serialized sizes of all KvCacheConfig fields written by
// serialize(KvCacheConfig const&, std::ostream&).
size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig)
{
    size_t numBytes = 0;
    auto const add = [&numBytes](auto const& field) { numBytes += su::serializedSize(field); };

    add(kvCacheConfig.getEnableBlockReuse());
    add(kvCacheConfig.getMaxTokens());
    add(kvCacheConfig.getMaxAttentionWindowVec());
    add(kvCacheConfig.getSinkTokenLength());
    add(kvCacheConfig.getFreeGpuMemoryFraction());
    add(kvCacheConfig.getHostCacheSize());
    add(kvCacheConfig.getOnboardBlocks());
    add(kvCacheConfig.getCrossKvCacheFraction());
    add(kvCacheConfig.getSecondaryOffloadMinPriority());
    add(kvCacheConfig.getEventBufferMaxSize());

    return numBytes;
}
|
|
|
|
// DynamicBatchConfig
|
|
// Reads a DynamicBatchConfig: two bool flags followed by the moving average window size.
DynamicBatchConfig Serialization::deserializeDynamicBatchConfig(std::istream& is)
{
    auto const batchSizeTuning = su::deserialize<bool>(is);
    auto const maxNumTokensTuning = su::deserialize<bool>(is);
    auto const movingAvgWindow = su::deserialize<SizeType32>(is);

    return DynamicBatchConfig{batchSizeTuning, maxNumTokensTuning, movingAvgWindow};
}
|
|
|
|
// Writes `dynamicBatchConfig` to `os` in the order deserializeDynamicBatchConfig reads it:
// batch size tuning flag, max-num-tokens tuning flag, moving average window.
// The parameter is named in camelCase so that it no longer shadows the DynamicBatchConfig type.
void Serialization::serialize(DynamicBatchConfig const& dynamicBatchConfig, std::ostream& os)
{
    su::serialize(dynamicBatchConfig.getEnableBatchSizeTuning(), os);
    su::serialize(dynamicBatchConfig.getEnableMaxNumTokensTuning(), os);
    su::serialize(dynamicBatchConfig.getDynamicBatchMovingAverageWindow(), os);
}
|
|
|
|
// Sum of the serialized sizes of the three DynamicBatchConfig fields.
// The parameter is named in camelCase so that it no longer shadows the DynamicBatchConfig type.
size_t Serialization::serializedSize(DynamicBatchConfig const& dynamicBatchConfig)
{
    size_t totalSize = 0;
    totalSize += su::serializedSize(dynamicBatchConfig.getEnableBatchSizeTuning());
    totalSize += su::serializedSize(dynamicBatchConfig.getEnableMaxNumTokensTuning());
    totalSize += su::serializedSize(dynamicBatchConfig.getDynamicBatchMovingAverageWindow());
    return totalSize;
}
|
|
|
|
// SchedulerConfig
|
|
// Reads a SchedulerConfig: capacity scheduler policy, optional context chunking policy and
// optional dynamic batch config.
SchedulerConfig Serialization::deserializeSchedulerConfig(std::istream& is)
{
    auto const capacityPolicy = su::deserialize<CapacitySchedulerPolicy>(is);
    auto const chunkingPolicy = su::deserialize<std::optional<ContextChunkingPolicy>>(is);
    auto const dynamicBatch = su::deserialize<std::optional<DynamicBatchConfig>>(is);

    return SchedulerConfig{capacityPolicy, chunkingPolicy, dynamicBatch};
}
|
|
|
|
// Writes `schedulerConfig` to `os` in the order deserializeSchedulerConfig reads it.
void Serialization::serialize(SchedulerConfig const& schedulerConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(schedulerConfig.getCapacitySchedulerPolicy());
    emit(schedulerConfig.getContextChunkingPolicy());
    emit(schedulerConfig.getDynamicBatchConfig());
}
|
|
|
|
// Sum of the serialized sizes of the three SchedulerConfig fields.
size_t Serialization::serializedSize(SchedulerConfig const& schedulerConfig)
{
    return su::serializedSize(schedulerConfig.getCapacitySchedulerPolicy())
        + su::serializedSize(schedulerConfig.getContextChunkingPolicy())
        + su::serializedSize(schedulerConfig.getDynamicBatchConfig());
}
|
|
|
|
// ExtendedRuntimePerfKnobConfig
|
|
// Reads an ExtendedRuntimePerfKnobConfig: three bool flags followed by the CUDA graph cache size.
ExtendedRuntimePerfKnobConfig Serialization::deserializeExtendedRuntimePerfKnobConfig(std::istream& is)
{
    auto const multiBlock = su::deserialize<bool>(is);
    auto const contextFmhaFp32Acc = su::deserialize<bool>(is);
    auto const cudaGraph = su::deserialize<bool>(is);
    auto const graphCacheSize = su::deserialize<SizeType32>(is);

    return ExtendedRuntimePerfKnobConfig{multiBlock, contextFmhaFp32Acc, cudaGraph, graphCacheSize};
}
|
|
|
|
// Writes `extendedRuntimePerfKnobConfig` to `os` in the order
// deserializeExtendedRuntimePerfKnobConfig reads it.
void Serialization::serialize(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(extendedRuntimePerfKnobConfig.getMultiBlockMode());
    emit(extendedRuntimePerfKnobConfig.getEnableContextFMHAFP32Acc());
    emit(extendedRuntimePerfKnobConfig.getCudaGraphMode());
    emit(extendedRuntimePerfKnobConfig.getCudaGraphCacheSize());
}
|
|
|
|
// Sum of the serialized sizes of the four ExtendedRuntimePerfKnobConfig fields.
size_t Serialization::serializedSize(ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig)
{
    return su::serializedSize(extendedRuntimePerfKnobConfig.getMultiBlockMode())
        + su::serializedSize(extendedRuntimePerfKnobConfig.getEnableContextFMHAFP32Acc())
        + su::serializedSize(extendedRuntimePerfKnobConfig.getCudaGraphMode())
        + su::serializedSize(extendedRuntimePerfKnobConfig.getCudaGraphCacheSize());
}
|
|
|
|
// ParallelConfig
|
|
// Reads a ParallelConfig: communication type and mode, optional device ids, optional participant
// ids and optional orchestrator config.
ParallelConfig Serialization::deserializeParallelConfig(std::istream& is)
{
    auto const communicationType = su::deserialize<CommunicationType>(is);
    auto const communicationMode = su::deserialize<CommunicationMode>(is);
    auto const deviceIds = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
    auto const participantIds = su::deserialize<std::optional<std::vector<SizeType32>>>(is);
    auto const orchestratorConfig = su::deserialize<std::optional<OrchestratorConfig>>(is);

    return ParallelConfig{communicationType, communicationMode, deviceIds, participantIds, orchestratorConfig};
}
|
|
|
|
// Writes `parallelConfig` to `os` in the order deserializeParallelConfig reads it.
void Serialization::serialize(ParallelConfig const& parallelConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(parallelConfig.getCommunicationType());
    emit(parallelConfig.getCommunicationMode());
    emit(parallelConfig.getDeviceIds());
    emit(parallelConfig.getParticipantIds());
    emit(parallelConfig.getOrchestratorConfig());
}
|
|
|
|
// Sum of the serialized sizes of the five ParallelConfig fields.
size_t Serialization::serializedSize(ParallelConfig const& parallelConfig)
{
    return su::serializedSize(parallelConfig.getCommunicationType())
        + su::serializedSize(parallelConfig.getCommunicationMode())
        + su::serializedSize(parallelConfig.getDeviceIds())
        + su::serializedSize(parallelConfig.getParticipantIds())
        + su::serializedSize(parallelConfig.getOrchestratorConfig());
}
|
|
|
|
// PeftCacheConfig
|
|
// Reads a PeftCacheConfig: nine SizeType32 values followed by the optional device cache percent
// and the optional host cache size.
PeftCacheConfig Serialization::deserializePeftCacheConfig(std::istream& is)
{
    auto const numHostModuleLayer = su::deserialize<SizeType32>(is);
    auto const numDeviceModuleLayer = su::deserialize<SizeType32>(is);
    auto const optimalAdapterSize = su::deserialize<SizeType32>(is);
    auto const maxAdapterSize = su::deserialize<SizeType32>(is);
    auto const numPutWorkers = su::deserialize<SizeType32>(is);
    auto const numEnsureWorkers = su::deserialize<SizeType32>(is);
    auto const numCopyStreams = su::deserialize<SizeType32>(is);
    auto const maxPagesPerBlockHost = su::deserialize<SizeType32>(is);
    auto const maxPagesPerBlockDevice = su::deserialize<SizeType32>(is);
    auto const deviceCachePercent = su::deserialize<std::optional<FloatType>>(is);
    auto const hostCacheSize = su::deserialize<std::optional<size_t>>(is);

    return PeftCacheConfig{numHostModuleLayer, numDeviceModuleLayer, optimalAdapterSize, maxAdapterSize, numPutWorkers,
        numEnsureWorkers, numCopyStreams, maxPagesPerBlockHost, maxPagesPerBlockDevice, deviceCachePercent,
        hostCacheSize};
}
|
|
|
|
// Writes `peftCacheConfig` to `os` in the order deserializePeftCacheConfig reads it.
void Serialization::serialize(PeftCacheConfig const& peftCacheConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(peftCacheConfig.getNumHostModuleLayer());
    emit(peftCacheConfig.getNumDeviceModuleLayer());
    emit(peftCacheConfig.getOptimalAdapterSize());
    emit(peftCacheConfig.getMaxAdapterSize());
    emit(peftCacheConfig.getNumPutWorkers());
    emit(peftCacheConfig.getNumEnsureWorkers());
    emit(peftCacheConfig.getNumCopyStreams());
    emit(peftCacheConfig.getMaxPagesPerBlockHost());
    emit(peftCacheConfig.getMaxPagesPerBlockDevice());
    emit(peftCacheConfig.getDeviceCachePercent());
    emit(peftCacheConfig.getHostCacheSize());
}
|
|
|
|
// Sum of the serialized sizes of all PeftCacheConfig fields written by
// serialize(PeftCacheConfig const&, std::ostream&).
size_t Serialization::serializedSize(PeftCacheConfig const& peftCacheConfig)
{
    size_t numBytes = 0;
    auto const add = [&numBytes](auto const& field) { numBytes += su::serializedSize(field); };

    add(peftCacheConfig.getNumHostModuleLayer());
    add(peftCacheConfig.getNumDeviceModuleLayer());
    add(peftCacheConfig.getOptimalAdapterSize());
    add(peftCacheConfig.getMaxAdapterSize());
    add(peftCacheConfig.getNumPutWorkers());
    add(peftCacheConfig.getNumEnsureWorkers());
    add(peftCacheConfig.getNumCopyStreams());
    add(peftCacheConfig.getMaxPagesPerBlockHost());
    add(peftCacheConfig.getMaxPagesPerBlockDevice());
    add(peftCacheConfig.getDeviceCachePercent());
    add(peftCacheConfig.getHostCacheSize());

    return numBytes;
}
|
|
|
|
// OrchestratorConfig
|
|
// Reads an OrchestratorConfig: orchestrator flag, worker executable path and spawn-processes flag.
OrchestratorConfig Serialization::deserializeOrchestratorConfig(std::istream& is)
{
    auto const orchestratorFlag = su::deserialize<bool>(is);
    auto const workerPath = su::deserialize<std::string>(is);
    auto const spawnFlag = su::deserialize<bool>(is);

    // Note we ignore mpiComm since we don't need to exchange it; nullptr is passed in its place.
    return OrchestratorConfig{orchestratorFlag, workerPath, nullptr, spawnFlag};
}
|
|
|
|
// Writes `orchestratorConfig` to `os` in the order deserializeOrchestratorConfig reads it.
void Serialization::serialize(OrchestratorConfig const& orchestratorConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(orchestratorConfig.getIsOrchestrator());
    emit(orchestratorConfig.getWorkerExecutablePath());
    emit(orchestratorConfig.getSpawnProcesses());
}
|
|
|
|
// Sum of the serialized sizes of the three serialized OrchestratorConfig fields.
size_t Serialization::serializedSize(OrchestratorConfig const& orchestratorConfig)
{
    return su::serializedSize(orchestratorConfig.getIsOrchestrator())
        + su::serializedSize(orchestratorConfig.getWorkerExecutablePath())
        + su::serializedSize(orchestratorConfig.getSpawnProcesses());
}
|
|
|
|
// DecodingMode
|
|
// Reads a DecodingMode from its underlying-type state value.
DecodingMode Serialization::deserializeDecodingMode(std::istream& is)
{
    auto const state = su::deserialize<DecodingMode::UnderlyingType>(is);
    return DecodingMode{state};
}
|
|
|
|
// Writes the value returned by DecodingMode::getState() to `os`.
void Serialization::serialize(DecodingMode const& decodingMode, std::ostream& os)
{
    auto const& state = decodingMode.getState();
    su::serialize(state, os);
}
|
|
|
|
// Serialized size of a DecodingMode, i.e. the size of its getState() value.
size_t Serialization::serializedSize(DecodingMode const& decodingMode)
{
    return su::serializedSize(decodingMode.getState());
}
|
|
|
|
// LookaheadDecodingConfig
|
|
// Rebuilds a LookaheadDecodingConfig from the stream.
LookaheadDecodingConfig Serialization::deserializeLookaheadDecodingConfig(std::istream& is)
{
    // Stream order is: n-gram size, window size, verification set size.
    auto const ngram = su::deserialize<SizeType32>(is);
    auto const window = su::deserialize<SizeType32>(is);
    auto const verificationSet = su::deserialize<SizeType32>(is);

    // The constructor call lists the window size before the n-gram size, which is
    // the reverse of the order in which the two values were read.
    return LookaheadDecodingConfig{window, ngram, verificationSet};
}
|
|
|
|
// Writes the n-gram size, window size and verification set size to `os`, in that order.
void Serialization::serialize(LookaheadDecodingConfig const& lookaheadDecodingConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(lookaheadDecodingConfig.getNgramSize());
    emit(lookaheadDecodingConfig.getWindowSize());
    emit(lookaheadDecodingConfig.getVerificationSetSize());
}
|
|
|
|
// Sums the serialized sizes of the n-gram size, window size and verification set size.
size_t Serialization::serializedSize(LookaheadDecodingConfig const& lookaheadDecodingConfig)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(lookaheadDecodingConfig.getNgramSize());
    tally(lookaheadDecodingConfig.getWindowSize());
    tally(lookaheadDecodingConfig.getVerificationSetSize());
    return byteCount;
}
|
|
|
|
// EagleConfig
|
|
// Rebuilds an EagleConfig from the stream. Fields are consumed in the order
// serialize(EagleConfig const&, std::ostream&) emits them.
EagleConfig Serialization::deserializeEagleConfig(std::istream& is)
{
    auto const choices = su::deserialize<std::optional<EagleChoices>>(is);
    auto const greedySampling = su::deserialize<bool>(is);
    auto const threshold = su::deserialize<std::optional<float>>(is);
    auto const dynamicTree = su::deserialize<bool>(is);
    auto const maxTopK = su::deserialize<std::optional<SizeType32>>(is);

    return EagleConfig{choices, greedySampling, threshold, dynamicTree, maxTopK};
}
|
|
|
|
// Writes the Eagle choices, greedy-sampling flag, posterior threshold, dynamic-tree
// flag and dynamic-tree max top-K to `os`, in that order.
void Serialization::serialize(EagleConfig const& eagleConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(eagleConfig.getEagleChoices());
    emit(eagleConfig.isGreedySampling());
    emit(eagleConfig.getPosteriorThreshold());
    emit(eagleConfig.useDynamicTree());
    emit(eagleConfig.getDynamicTreeMaxTopK());
}
|
|
|
|
// Sums the serialized sizes of the five fields written by
// serialize(EagleConfig const&, std::ostream&).
size_t Serialization::serializedSize(EagleConfig const& eagleConfig)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(eagleConfig.getEagleChoices());
    tally(eagleConfig.isGreedySampling());
    tally(eagleConfig.getPosteriorThreshold());
    tally(eagleConfig.useDynamicTree());
    tally(eagleConfig.getDynamicTreeMaxTopK());
    return byteCount;
}
|
|
|
|
// SpeculativeDecodingConfig
|
|
// Rebuilds a SpeculativeDecodingConfig from its single serialized member, fastLogits.
SpeculativeDecodingConfig Serialization::deserializeSpeculativeDecodingConfig(std::istream& is)
{
    // Read exactly the declared type of the member so reader and writer stay in sync.
    using FastLogitsType = decltype(SpeculativeDecodingConfig::fastLogits);

    auto const fastLogitsValue = su::deserialize<FastLogitsType>(is);
    return SpeculativeDecodingConfig(fastLogitsValue);
}
|
|
|
|
// Writes the fastLogits member to `os`; it is the only serialized field.
void Serialization::serialize(SpeculativeDecodingConfig const& specDecConfig, std::ostream& os)
{
    auto const& fastLogitsValue = specDecConfig.fastLogits;
    su::serialize(fastLogitsValue, os);
}
|
|
|
|
// Serialized size of a SpeculativeDecodingConfig, i.e. the size of its fastLogits member.
size_t Serialization::serializedSize(SpeculativeDecodingConfig const& specDecConfig)
{
    return su::serializedSize(specDecConfig.fastLogits);
}
|
|
|
|
// GuidedDecodingConfig
|
|
// Rebuilds a GuidedDecodingConfig from the stream. Each value is read with the type
// derived from the matching getter, in the order serialize() writes them.
GuidedDecodingConfig Serialization::deserializeGuidedDecodingConfig(std::istream& is)
{
    auto const backendValue = su::deserializeWithGetterType<decltype(&GuidedDecodingConfig::getBackend)>(is);
    auto const vocab = su::deserializeWithGetterType<decltype(&GuidedDecodingConfig::getEncodedVocab)>(is);
    auto const tokenizer = su::deserializeWithGetterType<decltype(&GuidedDecodingConfig::getTokenizerStr)>(is);
    auto const stopIds = su::deserializeWithGetterType<decltype(&GuidedDecodingConfig::getStopTokenIds)>(is);

    return GuidedDecodingConfig(backendValue, vocab, tokenizer, stopIds);
}
|
|
|
|
// Writes the backend, encoded vocabulary, tokenizer string and stop token ids to `os`,
// in that order.
void Serialization::serialize(GuidedDecodingConfig const& guidedDecodingConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(guidedDecodingConfig.getBackend());
    emit(guidedDecodingConfig.getEncodedVocab());
    emit(guidedDecodingConfig.getTokenizerStr());
    emit(guidedDecodingConfig.getStopTokenIds());
}
|
|
|
|
// Sums the serialized sizes of the four fields written by
// serialize(GuidedDecodingConfig const&, std::ostream&).
size_t Serialization::serializedSize(GuidedDecodingConfig const& guidedDecodingConfig)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(guidedDecodingConfig.getBackend());
    tally(guidedDecodingConfig.getEncodedVocab());
    tally(guidedDecodingConfig.getTokenizerStr());
    tally(guidedDecodingConfig.getStopTokenIds());
    return byteCount;
}
|
|
|
|
// GuidedDecodingParams
|
|
// Rebuilds a GuidedDecodingParams from the stream: first the guide type, then the
// guide, each read with the type derived from the matching getter.
GuidedDecodingParams Serialization::deserializeGuidedDecodingParams(std::istream& is)
{
    auto const type = su::deserializeWithGetterType<decltype(&GuidedDecodingParams::getGuideType)>(is);
    auto const guideValue = su::deserializeWithGetterType<decltype(&GuidedDecodingParams::getGuide)>(is);

    return GuidedDecodingParams(type, guideValue);
}
|
|
|
|
// Writes the guide type followed by the guide to `os`.
void Serialization::serialize(GuidedDecodingParams const& guidedDecodingParams, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(guidedDecodingParams.getGuideType());
    emit(guidedDecodingParams.getGuide());
}
|
|
|
|
// Sums the serialized sizes of the guide type and the guide.
size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingParams)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(guidedDecodingParams.getGuideType());
    tally(guidedDecodingParams.getGuide());
    return byteCount;
}
|
|
|
|
// KV Cache Retention Config
|
|
// Converts an optional millisecond duration to its tick count as size_t.
// An empty optional stays empty.
std::optional<size_t> durationToInt(std::optional<std::chrono::milliseconds> const& duration)
{
    if (!duration.has_value())
    {
        return std::nullopt;
    }
    return static_cast<size_t>(duration->count());
}
|
|
|
|
// Converts an optional integer tick count to a millisecond duration.
// An empty optional stays empty.
std::optional<std::chrono::milliseconds> intToDuration(std::optional<size_t> const& duration)
{
    if (!duration.has_value())
    {
        return std::nullopt;
    }
    return std::chrono::milliseconds(*duration);
}
|
|
|
|
// Rebuilds a KvCacheRetentionConfig from the stream: the token range retention
// configs, then the decode retention priority, then the optional decode duration.
KvCacheRetentionConfig Serialization::deserializeKvCacheRetentionConfig(std::istream& is)
{
    using TokenRangeConfigs = std::vector<executor::KvCacheRetentionConfig::TokenRangeRetentionConfig>;

    auto const rangeConfigs = su::deserialize<TokenRangeConfigs>(is);
    auto const decodeRetentionPriority = su::deserialize<executor::RetentionPriority>(is);
    // The duration is stored on the wire as an optional size_t and converted back here.
    auto const rawDecodeDuration = su::deserialize<std::optional<size_t>>(is);
    auto const decodeDuration = intToDuration(rawDecodeDuration);

    return KvCacheRetentionConfig{rangeConfigs, decodeRetentionPriority, decodeDuration};
}
|
|
|
|
// Writes the token range retention configs, the decode retention priority and the
// decode duration (converted to an optional size_t via durationToInt) to `os`.
void Serialization::serialize(KvCacheRetentionConfig const& kvCacheRetentionConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(kvCacheRetentionConfig.getTokenRangeRetentionConfigs());
    emit(kvCacheRetentionConfig.getDecodeRetentionPriority());
    emit(durationToInt(kvCacheRetentionConfig.getDecodeDurationMs()));
}
|
|
|
|
// Sums the serialized sizes of the three fields written by
// serialize(KvCacheRetentionConfig const&, std::ostream&). The decode duration is
// measured in its converted (durationToInt) form, matching what is written.
size_t Serialization::serializedSize(KvCacheRetentionConfig const& config)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(config.getTokenRangeRetentionConfigs());
    tally(config.getDecodeRetentionPriority());
    tally(durationToInt(config.getDecodeDurationMs()));
    return byteCount;
}
|
|
|
|
// TokenRangeRetentionConfig
|
|
// Rebuilds a TokenRangeRetentionConfig from the stream: token start, optional token
// end, retention priority, then the optional duration.
KvCacheRetentionConfig::TokenRangeRetentionConfig Serialization::deserializeTokenRangeRetentionConfig(std::istream& is)
{
    auto const rangeStart = su::deserialize<SizeType32>(is);
    auto const rangeEnd = su::deserialize<std::optional<SizeType32>>(is);
    RetentionPriority const retentionPriority = su::deserialize<RetentionPriority>(is);
    // The duration is stored on the wire as an optional size_t and converted back here.
    auto const rawDuration = su::deserialize<std::optional<size_t>>(is);
    auto const duration = intToDuration(rawDuration);

    return KvCacheRetentionConfig::TokenRangeRetentionConfig(rangeStart, rangeEnd, retentionPriority, duration);
}
|
|
|
|
// Writes tokenStart, tokenEnd, priority and the duration (converted to an optional
// size_t via durationToInt) to `os`, in that order.
void Serialization::serialize(
    KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(tokenRangeRetentionConfig.tokenStart);
    emit(tokenRangeRetentionConfig.tokenEnd);
    emit(tokenRangeRetentionConfig.priority);
    emit(durationToInt(tokenRangeRetentionConfig.durationMs));
}
|
|
|
|
// Sums the serialized sizes of tokenStart, tokenEnd, priority (viewed as a
// RetentionPriority) and the converted duration.
size_t Serialization::serializedSize(KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(tokenRangeRetentionConfig.tokenStart);
    tally(tokenRangeRetentionConfig.tokenEnd);
    tally(static_cast<RetentionPriority>(tokenRangeRetentionConfig.priority));
    tally(durationToInt(tokenRangeRetentionConfig.durationMs));
    return byteCount;
}
|
|
|
|
// DecodingConfig
|
|
// Rebuilds a DecodingConfig from four optional members, read in the order
// serialize(DecodingConfig const&, std::ostream&) writes them.
DecodingConfig Serialization::deserializeDecodingConfig(std::istream& is)
{
    auto const mode = su::deserialize<std::optional<DecodingMode>>(is);
    auto const lookahead = su::deserialize<std::optional<LookaheadDecodingConfig>>(is);
    auto const medusa = su::deserialize<std::optional<MedusaChoices>>(is);
    auto const eagle = su::deserialize<std::optional<EagleConfig>>(is);

    return DecodingConfig{mode, lookahead, medusa, eagle};
}
|
|
|
|
// Writes the decoding mode, lookahead decoding config, Medusa choices and Eagle
// config to `os`, in that order.
void Serialization::serialize(DecodingConfig const& decodingConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(decodingConfig.getDecodingMode());
    emit(decodingConfig.getLookaheadDecodingConfig());
    emit(decodingConfig.getMedusaChoices());
    emit(decodingConfig.getEagleConfig());
}
|
|
|
|
// Sums the serialized sizes of the four members written by
// serialize(DecodingConfig const&, std::ostream&).
size_t Serialization::serializedSize(DecodingConfig const& decodingConfig)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(decodingConfig.getDecodingMode());
    tally(decodingConfig.getLookaheadDecodingConfig());
    tally(decodingConfig.getMedusaChoices());
    tally(decodingConfig.getEagleConfig());
    return byteCount;
}
|
|
|
|
// DebugConfig
|
|
// Rebuilds a DebugConfig from the stream. Fields are consumed in the order
// serialize(DebugConfig const&, std::ostream&) emits them.
DebugConfig Serialization::deserializeDebugConfig(std::istream& is)
{
    // Result type of getDebugTensorNames() with any reference and cv-qualifiers removed,
    // so that a plain value of that type is read from the stream.
    using TensorNamesType = std::remove_cv_t<
        std::remove_reference_t<std::invoke_result_t<decltype(&DebugConfig::getDebugTensorNames), DebugConfig>>>;

    auto const inputTensors = su::deserializeWithGetterType<decltype(&DebugConfig::getDebugInputTensors)>(is);
    auto const outputTensors = su::deserializeWithGetterType<decltype(&DebugConfig::getDebugOutputTensors)>(is);
    auto const tensorNames = su::deserialize<TensorNamesType>(is);
    auto const maxIterations
        = su::deserializeWithGetterType<decltype(&DebugConfig::getDebugTensorsMaxIterations)>(is);

    return DebugConfig{inputTensors, outputTensors, tensorNames, maxIterations};
}
|
|
|
|
// Writes the debug-input-tensors value, debug-output-tensors value, debug tensor names
// and debug-tensors max iterations to `os`, in that order.
void Serialization::serialize(DebugConfig const& debugConfig, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(debugConfig.getDebugInputTensors());
    emit(debugConfig.getDebugOutputTensors());
    emit(debugConfig.getDebugTensorNames());
    emit(debugConfig.getDebugTensorsMaxIterations());
}
|
|
|
|
// Sums the serialized sizes of the four fields written by
// serialize(DebugConfig const&, std::ostream&).
size_t Serialization::serializedSize(DebugConfig const& debugConfig)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(debugConfig.getDebugInputTensors());
    tally(debugConfig.getDebugOutputTensors());
    tally(debugConfig.getDebugTensorNames());
    tally(debugConfig.getDebugTensorsMaxIterations());
    return byteCount;
}
|
|
|
|
// KvCacheStats
|
|
// Rebuilds a KvCacheStats from the stream: eight SizeType32 counters followed by the
// float cache hit rate, in the order serialize() writes them.
KvCacheStats Serialization::deserializeKvCacheStats(std::istream& is)
{
    auto const nextCount = [&is]() { return su::deserialize<SizeType32>(is); };

    auto const maxBlocks = nextCount();
    auto const freeBlocks = nextCount();
    auto const usedBlocks = nextCount();
    auto const blockTokens = nextCount();
    auto const totalAllocated = nextCount();
    auto const newlyAllocated = nextCount();
    auto const reused = nextCount();
    auto const missed = nextCount();
    auto const hitRate = su::deserialize<float>(is);

    return KvCacheStats{maxBlocks, freeBlocks, usedBlocks, blockTokens, totalAllocated, newlyAllocated, reused,
        missed, hitRate};
}
|
|
|
|
// Writes every KvCacheStats member to `os`; deserializeKvCacheStats reads them back
// in this exact order.
void Serialization::serialize(KvCacheStats const& kvCacheStats, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(kvCacheStats.maxNumBlocks);
    emit(kvCacheStats.freeNumBlocks);
    emit(kvCacheStats.usedNumBlocks);
    emit(kvCacheStats.tokensPerBlock);
    emit(kvCacheStats.allocTotalBlocks);
    emit(kvCacheStats.allocNewBlocks);
    emit(kvCacheStats.reusedBlocks);
    emit(kvCacheStats.missedBlocks);
    emit(kvCacheStats.cacheHitRate);
}
|
|
|
|
// Sums the serialized sizes of the nine members written by
// serialize(KvCacheStats const&, std::ostream&).
size_t Serialization::serializedSize(KvCacheStats const& kvCacheStats)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(kvCacheStats.maxNumBlocks);
    tally(kvCacheStats.freeNumBlocks);
    tally(kvCacheStats.usedNumBlocks);
    tally(kvCacheStats.tokensPerBlock);
    tally(kvCacheStats.allocTotalBlocks);
    tally(kvCacheStats.allocNewBlocks);
    tally(kvCacheStats.reusedBlocks);
    tally(kvCacheStats.missedBlocks);
    tally(kvCacheStats.cacheHitRate);
    return byteCount;
}
|
|
|
|
// StaticBatchingStats
|
|
// Rebuilds a StaticBatchingStats from five consecutive SizeType32 values, in the
// order serialize() writes them.
StaticBatchingStats Serialization::deserializeStaticBatchingStats(std::istream& is)
{
    auto const nextCount = [&is]() { return su::deserialize<SizeType32>(is); };

    auto const scheduledRequests = nextCount();
    auto const contextRequests = nextCount();
    auto const ctxTokens = nextCount();
    auto const genTokens = nextCount();
    auto const emptySlots = nextCount();

    return StaticBatchingStats{scheduledRequests, contextRequests, ctxTokens, genTokens, emptySlots};
}
|
|
|
|
// Writes every StaticBatchingStats member to `os`; deserializeStaticBatchingStats
// reads them back in this exact order.
void Serialization::serialize(StaticBatchingStats const& staticBatchingStats, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(staticBatchingStats.numScheduledRequests);
    emit(staticBatchingStats.numContextRequests);
    emit(staticBatchingStats.numCtxTokens);
    emit(staticBatchingStats.numGenTokens);
    emit(staticBatchingStats.emptyGenSlots);
}
|
|
|
|
// Sums the serialized sizes of the five members written by
// serialize(StaticBatchingStats const&, std::ostream&).
size_t Serialization::serializedSize(StaticBatchingStats const& staticBatchingStats)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(staticBatchingStats.numScheduledRequests);
    tally(staticBatchingStats.numContextRequests);
    tally(staticBatchingStats.numCtxTokens);
    tally(staticBatchingStats.numGenTokens);
    tally(staticBatchingStats.emptyGenSlots);
    return byteCount;
}
|
|
|
|
// InflightBatchingStats
|
|
// Rebuilds an InflightBatchingStats from six SizeType32 values followed by one float,
// in the order serialize() writes them.
InflightBatchingStats Serialization::deserializeInflightBatchingStats(std::istream& is)
{
    auto const nextCount = [&is]() { return su::deserialize<SizeType32>(is); };

    auto const scheduledRequests = nextCount();
    auto const contextRequests = nextCount();
    auto const genRequests = nextCount();
    auto const pausedRequests = nextCount();
    auto const ctxTokens = nextCount();
    auto const microBatch = nextCount();
    auto const avgDecodedTokensPerIter = su::deserialize<float>(is);

    return InflightBatchingStats{scheduledRequests, contextRequests, genRequests, pausedRequests, ctxTokens,
        microBatch, avgDecodedTokensPerIter};
}
|
|
|
|
// Writes every InflightBatchingStats member to `os`; deserializeInflightBatchingStats
// reads them back in this exact order.
void Serialization::serialize(InflightBatchingStats const& inflightBatchingStats, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(inflightBatchingStats.numScheduledRequests);
    emit(inflightBatchingStats.numContextRequests);
    emit(inflightBatchingStats.numGenRequests);
    emit(inflightBatchingStats.numPausedRequests);
    emit(inflightBatchingStats.numCtxTokens);
    emit(inflightBatchingStats.microBatchId);
    emit(inflightBatchingStats.avgNumDecodedTokensPerIter);
}
|
|
|
|
// Sums the serialized sizes of the seven members written by
// serialize(InflightBatchingStats const&, std::ostream&).
size_t Serialization::serializedSize(InflightBatchingStats const& inflightBatchingStats)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(inflightBatchingStats.numScheduledRequests);
    tally(inflightBatchingStats.numContextRequests);
    tally(inflightBatchingStats.numGenRequests);
    tally(inflightBatchingStats.numPausedRequests);
    tally(inflightBatchingStats.numCtxTokens);
    tally(inflightBatchingStats.microBatchId);
    tally(inflightBatchingStats.avgNumDecodedTokensPerIter);
    return byteCount;
}
|
|
|
|
// IterationStats
|
|
|
|
// Rebuilds an IterationStats from the stream. Fields are consumed in exactly the order
// serialize(IterationStats const&, std::ostream&) emits them, so the two functions
// must be kept in lockstep whenever a field is added or reordered.
IterationStats Serialization::deserializeIterationStats(std::istream& is)
{
    // Timing and iteration identity.
    auto timestamp = su::deserialize<std::string>(is);
    auto iter = su::deserialize<IterationType>(is);
    auto iterLatencyMS = su::deserialize<double>(is);
    auto newActiveRequestsQueueLatencyMS = su::deserialize<double>(is);

    // Request counters.
    auto numNewActiveRequests = su::deserialize<SizeType32>(is);
    auto numActiveRequests = su::deserialize<SizeType32>(is);
    auto numQueuedRequests = su::deserialize<SizeType32>(is);
    auto numCompletedRequests = su::deserialize<SizeType32>(is);
    auto maxNumActiveRequests = su::deserialize<SizeType32>(is);

    // Batch size and token limits.
    auto maxBatchSizeStatic = su::deserialize<SizeType32>(is);
    auto maxBatchSizeTunerRecommended = su::deserialize<SizeType32>(is);
    auto maxBatchSizeRuntime = su::deserialize<SizeType32>(is);
    auto maxNumTokensStatic = su::deserialize<SizeType32>(is);
    auto maxNumTokensTunerRecommended = su::deserialize<SizeType32>(is);
    auto maxNumTokensRuntime = su::deserialize<SizeType32>(is);

    // Memory usage.
    auto gpuMemUsage = su::deserialize<size_t>(is);
    auto cpuMemUsage = su::deserialize<size_t>(is);
    auto pinnedMemUsage = su::deserialize<size_t>(is);

    // Optional nested statistics.
    auto kvCacheStats = su::deserialize<std::optional<KvCacheStats>>(is);
    auto crossKvCacheStats = su::deserialize<std::optional<KvCacheStats>>(is);
    auto staticBatchingStats = su::deserialize<std::optional<StaticBatchingStats>>(is);
    auto inflightBatchingStats = su::deserialize<std::optional<InflightBatchingStats>>(is);

    return IterationStats{timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS, numNewActiveRequests,
        numActiveRequests, numQueuedRequests, numCompletedRequests, maxNumActiveRequests, maxBatchSizeStatic,
        maxBatchSizeTunerRecommended, maxBatchSizeRuntime, maxNumTokensStatic, maxNumTokensTunerRecommended,
        maxNumTokensRuntime, gpuMemUsage, cpuMemUsage, pinnedMemUsage, kvCacheStats, crossKvCacheStats,
        staticBatchingStats, inflightBatchingStats};
}
|
|
|
|
// Convenience overload: reads one IterationStats from a byte buffer by wrapping it in
// an input stream and delegating to the std::istream overload.
IterationStats Serialization::deserializeIterationStats(std::vector<char>& buffer)
{
    su::VectorWrapBuf<char> wrappedBuffer(buffer);
    std::istream input(&wrappedBuffer);
    return Serialization::deserializeIterationStats(input);
}
|
|
|
|
// Computes the size of the buffer needed to hold a serialized IterationStats by
// summing the serialized size of every member written by serialize().
size_t Serialization::serializedSize(IterationStats const& iterStats)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(iterStats.timestamp);
    tally(iterStats.iter);
    tally(iterStats.iterLatencyMS);
    tally(iterStats.newActiveRequestsQueueLatencyMS);
    tally(iterStats.numNewActiveRequests);
    tally(iterStats.numActiveRequests);
    tally(iterStats.numQueuedRequests);
    tally(iterStats.numCompletedRequests);
    tally(iterStats.maxNumActiveRequests);
    tally(iterStats.maxBatchSizeStatic);
    tally(iterStats.maxBatchSizeTunerRecommended);
    tally(iterStats.maxBatchSizeRuntime);
    tally(iterStats.maxNumTokensStatic);
    tally(iterStats.maxNumTokensTunerRecommended);
    tally(iterStats.maxNumTokensRuntime);
    tally(iterStats.gpuMemUsage);
    tally(iterStats.cpuMemUsage);
    tally(iterStats.pinnedMemUsage);
    tally(iterStats.kvCacheStats);
    tally(iterStats.crossKvCacheStats);
    tally(iterStats.staticBatchingStats);
    tally(iterStats.inflightBatchingStats);

    return byteCount;
}
|
|
|
|
// Writes every IterationStats member to `os`. deserializeIterationStats(std::istream&)
// reads the members back in this exact order.
void Serialization::serialize(IterationStats const& iterStats, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(iterStats.timestamp);
    emit(iterStats.iter);
    emit(iterStats.iterLatencyMS);
    emit(iterStats.newActiveRequestsQueueLatencyMS);
    emit(iterStats.numNewActiveRequests);
    emit(iterStats.numActiveRequests);
    emit(iterStats.numQueuedRequests);
    emit(iterStats.numCompletedRequests);
    emit(iterStats.maxNumActiveRequests);
    emit(iterStats.maxBatchSizeStatic);
    emit(iterStats.maxBatchSizeTunerRecommended);
    emit(iterStats.maxBatchSizeRuntime);
    emit(iterStats.maxNumTokensStatic);
    emit(iterStats.maxNumTokensTunerRecommended);
    emit(iterStats.maxNumTokensRuntime);
    emit(iterStats.gpuMemUsage);
    emit(iterStats.cpuMemUsage);
    emit(iterStats.pinnedMemUsage);
    emit(iterStats.kvCacheStats);
    emit(iterStats.crossKvCacheStats);
    emit(iterStats.staticBatchingStats);
    emit(iterStats.inflightBatchingStats);
}
|
|
|
|
std::vector<char> Serialization::serialize(IterationStats const& iterStats)
|
|
{
|
|
auto totalSize = Serialization::serializedSize(iterStats);
|
|
std::vector<char> buffer(totalSize);
|
|
|
|
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
|
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
|
std::ostream os(&strbuf);
|
|
|
|
Serialization::serialize(iterStats, os);
|
|
|
|
return buffer;
|
|
}
|
|
|
|
std::vector<IterationStats> Serialization::deserializeIterationStatsVec(std::vector<char>& buffer)
|
|
{
|
|
std::vector<IterationStats> iterStatsVec;
|
|
su::VectorWrapBuf<char> strbuf(buffer);
|
|
std::istream is(&strbuf);
|
|
auto numIterStats = su::deserialize<std::size_t>(is);
|
|
for (std::size_t iterStats = 0; iterStats < numIterStats; ++iterStats)
|
|
{
|
|
iterStatsVec.emplace_back(Serialization::deserializeIterationStats(is));
|
|
}
|
|
return iterStatsVec;
|
|
}
|
|
|
|
std::vector<char> Serialization::serialize(std::vector<IterationStats> const& iterStatsVec)
|
|
{
|
|
size_t totalSize = 0;
|
|
totalSize += sizeof(size_t);
|
|
for (auto const& iterStats : iterStatsVec)
|
|
{
|
|
totalSize += su::serializedSize(iterStats);
|
|
}
|
|
std::vector<char> buffer(totalSize);
|
|
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
|
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
|
std::ostream os(&strbuf);
|
|
su::serialize(iterStatsVec.size(), os);
|
|
for (auto const& iterStats : iterStatsVec)
|
|
{
|
|
su::serialize(iterStats, os);
|
|
}
|
|
return buffer;
|
|
}
|
|
|
|
// DisServingRequestStats
|
|
|
|
// Rebuilds a DisServingRequestStats from a double (KV cache transfer time) followed by
// a size_t (KV cache size).
DisServingRequestStats Serialization::deserializeDisServingRequestStats(std::istream& is)
{
    auto const transferMs = su::deserialize<double>(is);
    auto const cacheSize = su::deserialize<size_t>(is);

    return DisServingRequestStats{transferMs, cacheSize};
}
|
|
|
|
// Writes kvCacheTransferMS followed by kvCacheSize to `os`.
void Serialization::serialize(DisServingRequestStats const& stats, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(stats.kvCacheTransferMS);
    emit(stats.kvCacheSize);
}
|
|
|
|
// Sums the serialized sizes of kvCacheTransferMS and kvCacheSize.
size_t Serialization::serializedSize(DisServingRequestStats const& disServingRequestStats)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(disServingRequestStats.kvCacheTransferMS);
    tally(disServingRequestStats.kvCacheSize);
    return byteCount;
}
|
|
|
|
// RequestStage
|
|
|
|
// Reads a RequestStage that was written as a plain int.
RequestStage Serialization::deserializeRequestStage(std::istream& is)
{
    int const stageValue = su::deserialize<int>(is);
    return RequestStage{stageValue};
}
|
|
|
|
// Writes the RequestStage to `os` as a plain int.
void Serialization::serialize(RequestStage const& requestStage, std::ostream& os)
{
    int const stageValue = static_cast<int>(requestStage);
    su::serialize(stageValue, os);
}
|
|
|
|
// Serialized size of a RequestStage, measured on its int representation.
size_t Serialization::serializedSize(RequestStage const& requestStage)
{
    return su::serializedSize(static_cast<int>(requestStage));
}
|
|
|
|
// RequestStats
|
|
// Rebuilds a RequestStats from the stream. Fields are consumed in the order
// serialize(RequestStats const&, std::ostream&) emits them.
RequestStats Serialization::deserializeRequestStats(std::istream& is)
{
    auto const nextCount = [&is]() { return su::deserialize<SizeType32>(is); };

    auto const requestId = su::deserialize<IdType>(is);
    auto const requestStage = su::deserialize<RequestStage>(is);
    auto const prefillPosition = nextCount();
    auto const generatedTokens = nextCount();
    auto const avgDecodedTokensPerIter = su::deserialize<float>(is);
    auto const isScheduled = su::deserialize<bool>(is);
    auto const isPaused = su::deserialize<bool>(is);
    auto const disServing = su::deserialize<std::optional<DisServingRequestStats>>(is);

    // Per-request block statistics and hit rate.
    auto const totalAllocated = nextCount();
    auto const newlyAllocated = nextCount();
    auto const reused = nextCount();
    auto const missed = nextCount();
    auto const hitRate = su::deserialize<FloatType>(is);

    return RequestStats{requestId, requestStage, prefillPosition, generatedTokens, avgDecodedTokensPerIter,
        isScheduled, isPaused, disServing, totalAllocated, newlyAllocated, reused, missed, hitRate};
}
|
|
|
|
// Writes every RequestStats member to `os`; deserializeRequestStats reads them back in
// this exact order.
void Serialization::serialize(RequestStats const& state, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(state.id);
    emit(state.stage);
    emit(state.contextPrefillPosition);
    emit(state.numGeneratedTokens);
    emit(state.avgNumDecodedTokensPerIter);
    emit(state.scheduled);
    emit(state.paused);
    emit(state.disServingStats);
    emit(state.allocTotalBlocksPerRequest);
    emit(state.allocNewBlocksPerRequest);
    emit(state.reusedBlocksPerRequest);
    emit(state.missedBlocksPerRequest);
    emit(state.kvCacheHitRatePerRequest);
}
|
|
|
|
// Sums the serialized sizes of the thirteen members written by
// serialize(RequestStats const&, std::ostream&).
size_t Serialization::serializedSize(RequestStats const& state)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(state.id);
    tally(state.stage);
    tally(state.contextPrefillPosition);
    tally(state.numGeneratedTokens);
    tally(state.avgNumDecodedTokensPerIter);
    tally(state.scheduled);
    tally(state.paused);
    tally(state.disServingStats);
    tally(state.allocTotalBlocksPerRequest);
    tally(state.allocNewBlocksPerRequest);
    tally(state.reusedBlocksPerRequest);
    tally(state.missedBlocksPerRequest);
    tally(state.kvCacheHitRatePerRequest);
    return byteCount;
}
|
|
|
|
// RequestStatsPerIteration
|
|
// Rebuilds a RequestStatsPerIteration from the iteration id followed by the vector of
// per-request statistics.
RequestStatsPerIteration Serialization::deserializeRequestStatsPerIteration(std::istream& is)
{
    auto const iteration = su::deserialize<IterationType>(is);
    auto const perRequestStats = su::deserialize<std::vector<RequestStats>>(is);

    return RequestStatsPerIteration{iteration, perRequestStats};
}
|
|
|
|
// Convenience overload: reads one RequestStatsPerIteration from a byte buffer by
// wrapping it in an input stream and delegating to the std::istream overload.
RequestStatsPerIteration Serialization::deserializeRequestStatsPerIteration(std::vector<char>& buffer)
{
    su::VectorWrapBuf<char> wrappedBuffer(buffer);
    std::istream input(&wrappedBuffer);
    return Serialization::deserializeRequestStatsPerIteration(input);
}
|
|
|
|
// Writes the iteration id followed by the per-request statistics to `os`.
void Serialization::serialize(RequestStatsPerIteration const& state, std::ostream& os)
{
    auto const emit = [&os](auto const& field) { su::serialize(field, os); };

    emit(state.iter);
    emit(state.requestStats);
}
|
|
|
|
std::vector<char> Serialization::serialize(RequestStatsPerIteration const& state)
|
|
{
|
|
auto totalSize = Serialization::serializedSize(state);
|
|
std::vector<char> buffer(totalSize);
|
|
|
|
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
|
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
|
std::ostream os(&strbuf);
|
|
|
|
Serialization::serialize(state, os);
|
|
|
|
return buffer;
|
|
}
|
|
|
|
// Sums the serialized sizes of the iteration id and the per-request statistics.
size_t Serialization::serializedSize(RequestStatsPerIteration const& state)
{
    size_t byteCount = 0;
    auto const tally = [&byteCount](auto const& field) { byteCount += su::serializedSize(field); };

    tally(state.iter);
    tally(state.requestStats);
    return byteCount;
}
|
|
|
|
std::vector<char> Serialization::serialize(std::vector<RequestStatsPerIteration> const& requestStatsVec)
|
|
{
|
|
size_t totalSize = 0;
|
|
totalSize += sizeof(size_t);
|
|
for (auto const& requestStat : requestStatsVec)
|
|
{
|
|
totalSize += su::serializedSize(requestStat);
|
|
}
|
|
std::vector<char> buffer(totalSize);
|
|
std::stringbuf strbuf(std::ios_base::out | std::ios_base::in);
|
|
strbuf.pubsetbuf(buffer.data(), buffer.size());
|
|
std::ostream os(&strbuf);
|
|
su::serialize(requestStatsVec.size(), os);
|
|
for (auto const& requestStat : requestStatsVec)
|
|
{
|
|
su::serialize(requestStat, os);
|
|
}
|
|
return buffer;
|
|
}
|
|
|
|
std::vector<RequestStatsPerIteration> Serialization::deserializeRequestStatsPerIterationVec(std::vector<char>& buffer)
|
|
{
|
|
|
|
std::vector<RequestStatsPerIteration> iterRequestStatsVec;
|
|
su::VectorWrapBuf<char> strbuf(buffer);
|
|
std::istream is(&strbuf);
|
|
auto numIterStats = su::deserialize<std::size_t>(is);
|
|
for (std::size_t iterStats = 0; iterStats < numIterStats; ++iterStats)
|
|
{
|
|
iterRequestStatsVec.emplace_back(Serialization::deserializeRequestStatsPerIteration(is));
|
|
}
|
|
return iterRequestStatsVec;
|
|
}
|
|
|
|
// String
|
|
std::string Serialization::deserializeString(std::istream& is)
|
|
{
|
|
return su::deserialize<std::string>(is);
|
|
}
|
|
|
|
// Bool
|
|
bool Serialization::deserializeBool(std::istream& is)
|
|
{
|
|
return su::deserialize<bool>(is);
|
|
}
|
|
|
|
// ModelType
|
|
// Reads one ModelType value from the stream.
ModelType Serialization::deserializeModelType(std::istream& is)
{
    auto const modelType = su::deserialize<ModelType>(is);
    return modelType;
}
|
|
|
|
} // namespace tensorrt_llm::executor
|