/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"
#include "tensorrt_llm/executor/types.h"
#include
#include
#include
#include
#include
#include
#include
#include

namespace tensorrt_llm::executor::serialize_utils
{

// Read-only stream buffer over an existing std::vector.
// The constructor sets the get area to [vec.data(), vec.data() + vec.size()), so reads come
// straight from the vector's storage without copying. Only pointers into `vec` are stored,
// so `vec` must outlive this buffer and must not be resized (which may reallocate) while
// the buffer is in use.
// NOTE(review): the constructor calls the qualified `std::streambuf::setg`, which names the
// `basic_streambuf<char>` member. If this class is ever instantiated with a base over a
// character type other than char, that name is not a base class and the call will not
// compile; `this->setg(...)` is the generic spelling. Confirm which char types are intended.
template >
class VectorWrapBuf : public std::basic_streambuf
{
public:
    explicit VectorWrapBuf(std::vector& vec)
    {
        std::streambuf::setg(vec.data(), vec.data(), vec.data() + vec.size());
    }
};

// Trait naming the element type that the container/optional branches of serializedSize,
// serialize and deserialize recurse into. The primary template yields `void`. One partial
// specialization exposes `T::value_type`; another exposes `T` itself (presumably the
// std::optional<T> case — confirm).
template
struct ValueType
{
    using type = void;
};

template
struct ValueType>
{
    using type = typename T::value_type;
};

template
struct ValueType, void>
{
    using type = T;
};

// Type trait used to select the std::variant branches below (they call `.index()` and
// std::visit on the value). False by default; the partial specialization is true
// (presumably for std::variant<...>).
template
struct is_variant : std::false_type
{
};

template
struct is_variant> : std::true_type
{
};

template
constexpr bool is_variant_v = is_variant::value;

// SerializedSize

// Detection idiom: true when `Serialization::serializedSize` is callable for T.
// The second overload's parameter type is the result type of that call. When the call is
// well-formed, the `size_t()` argument matches it exactly and beats the variadic fallback.
// Otherwise SFINAE drops that overload and the fallback returns false.
template
bool constexpr hasSerializedSize(...)
{
    return false;
}

template
bool constexpr hasSerializedSize(decltype(Serialization::serializedSize(std::declval())))
{
    return true;
}

// Compile-time checks that the listed types route through Serialization::serializedSize,
// and that the negated ones do not (they use the generic branches of serializedSize instead).
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(hasSerializedSize(size_t()));
static_assert(!hasSerializedSize(size_t()));
static_assert(!hasSerializedSize>(size_t()));
static_assert(hasSerializedSize(size_t()));

// Returns the number of bytes that `serialize(data, os)` writes for `data`.
// Each branch mirrors the corresponding branch of `serialize`:
//   fundamental       -> sizeof(T)
//   Serialization type -> Serialization::serializedSize(data)
//   enum              -> size of the underlying-integer value
//   vector/list/string -> size_t length prefix + sum of element sizes
//   optional          -> bool presence flag + value size if engaged
//   variant           -> size_t index + size of the active alternative
// Any other type fails the final static_assert at compile time.
template
size_t serializedSize(T const& data)
{
    // Fundamental types
    if constexpr (std::is_fundamental_v)
    {
        return sizeof(T);
    }
    // Types with a dedicated Serialization::serializedSize overload
    else if constexpr (hasSerializedSize(size_t()))
    {
        return Serialization::serializedSize(data);
    }
    // Enum class (any enum type): sized as its underlying integer value
    else if constexpr (std::is_enum_v)
    {
        using UnderlyingType = std::underlying_type_t;
        auto value = static_cast(data);
        return serializedSize(value);
    }
    // Vectors, lists and strings
    else if constexpr (std::is_same_v::type>> || std::is_same_v::type>>
        || std::is_same_v)
    {
        // Length prefix first, then every element's own encoded size.
        size_t size = sizeof(size_t);
        for (auto const& elem : data)
        {
            size += serializedSize(elem);
        }
        return size;
    }
    // Optional
    else if constexpr (std::is_same_v::type>>)
    {
        // One bool presence flag, plus the contained value when engaged.
        return sizeof(bool) + (data.has_value() ? serializedSize(data.value()) : 0);
    }
    // std::variant
    else if constexpr (is_variant_v)
    {
        // size_t alternative index, then whichever alternative is currently active.
        size_t index = data.index();
        size_t size = sizeof(index);
        std::visit([&size](auto const& value) { size += serializedSize(value); }, data);
        return size;
    }
    else
    {
        static_assert(std::is_same_v, "Unsupported type for serialization");
    }
}

// Serialize

// Detection idiom: true when `Serialization::serialize` is callable with the two declval'd
// arguments (presumably a T and an output stream — confirm). The second overload takes a
// pointer to the call's result type. When the call is well-formed, `nullptr` converts to
// that pointer, which is preferred over the variadic fallback.
template
bool constexpr hasSerialize(...)
{
    return false;
}

template
bool constexpr hasSerialize(
    decltype(Serialization::serialize(std::declval(), std::declval()))*)
{
    return true;
}

// Compile-time checks that the listed types route through Serialization::serialize.
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
static_assert(hasSerialize(nullptr));
// Remaining hasSerialize checks: the negated types must NOT have a Serialization::serialize
// overload (they use the generic branches of serialize instead).
static_assert(!hasSerialize(nullptr));
static_assert(!hasSerialize>(nullptr));
static_assert(hasSerialize(nullptr));

// Writes `data` to `os` in this module's binary format:
//   fundamental        -> the object's raw bytes (sizeof(T), host representation)
//   Serialization type -> delegated to Serialization::serialize
//   enum               -> raw bytes of the underlying-integer value
//   vector/list/string -> size_t element count, then each element recursively
//   optional           -> bool presence flag, then the value if present
//   variant            -> size_t alternative index, then the active alternative
// The byte count produced matches serializedSize(data). Stream state is not checked here;
// callers must inspect `os` themselves. Any other type fails the static_assert.
template
void serialize(T const& data, std::ostream& os)
{
    // Fundamental types
    if constexpr (std::is_fundamental_v)
    {
        os.write(reinterpret_cast(&data), sizeof(data));
    }
    // Types with a dedicated Serialization::serialize overload
    else if constexpr (hasSerialize(nullptr))
    {
        return Serialization::serialize(data, os);
    }
    // Enum class (any enum type): written as its underlying integer value
    else if constexpr (std::is_enum_v)
    {
        using UnderlyingType = std::underlying_type_t;
        auto value = static_cast(data);
        os.write(reinterpret_cast(&value), sizeof(value));
    }
    // Vectors, lists and strings
    else if constexpr (std::is_same_v::type>> || std::is_same_v::type>>
        || std::is_same_v)
    {
        // Element count first so deserialize knows how many elements to read back.
        size_t size = data.size();
        os.write(reinterpret_cast(&size), sizeof(size));
        for (auto const& element : data)
        {
            serialize(element, os);
        }
    }
    // Optional
    else if constexpr (std::is_same_v::type>>)
    {
        // Serialize a boolean indicating whether optional has a value
        bool hasValue = data.has_value();
        os.write(reinterpret_cast(&hasValue), sizeof(hasValue));
        // Serialize the value if it exists
        if (hasValue)
        {
            serialize(data.value(), os);
        }
    }
    // std::variant
    else if constexpr (is_variant_v)
    {
        // Store the index of the active variant
        size_t index = data.index();
        os.write(reinterpret_cast(&index), sizeof(index));
        // Serialize the held value based on the index
        std::visit([&os](auto const& value) { serialize(value, os); }, data);
    }
    else
    {
        static_assert(std::is_same_v, "Unsupported type for serialization");
    }
}

// Local alias for `std::variant_alternative<...>::type`.
template
using variant_alternative_t = typename std::variant_alternative::type;

// Returns a copy of one alternative of `variant` via std::get. std::get throws
// std::bad_variant_access if that alternative is not the active one.
// NOTE(review): not referenced elsewhere in this header; check whether it is still used.
template
struct get_variant_alternative_type
{
    static variant_alternative_t get(T const& variant)
    {
        return std::get(variant);
    }
};

// Deserialize

// Reads a T from `is`; the inverse of `serialize`.
// Branch order: fundamentals, enums, then types with a dedicated
// Serialization::deserializeXxx function, then optionals, containers and variants.
// NOTE(review): none of the `is.read` calls are checked. On a short or failed read, the
// fundamental and enum branches return locals that were never initialized (`T data;`,
// `UnderlyingType value;`). A corrupted length prefix also makes the container branch loop
// `size` times. Consider validating `is` after each read when the input is not trusted.
template
T deserialize(std::istream& is)
{
    // Fundamental types
    if constexpr (std::is_fundamental_v)
    {
        T data;
        is.read(reinterpret_cast(&data), sizeof(data));
        return data;
    }
    // Enum class (any enum type): read back as its underlying integer value
    else if constexpr (std::is_enum_v)
    {
        using UnderlyingType = std::underlying_type_t;
        UnderlyingType value;
        is.read(reinterpret_cast(&value), sizeof(value));
        return static_cast(value);
    }
    // deserialize from serialization class
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeTimePoint(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeRequestPerfMetrics(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeSamplingConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeOutputConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeAdditionalModelOutput(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeExternalDraftTokensConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializePromptTuningConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeMropeConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeLoraConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeCommState(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeSocketState(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeAgentState(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeCacheState(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeDataTransceiverState(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeContextPhaseParams(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeRequest(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeTensor(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeSpecDecFastLogitsInfo(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeResult(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeResponse(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeKvCacheConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeDynamicBatchConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeSchedulerConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeExtendedRuntimePerfKnobConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeParallelConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializePeftCacheConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeOrchestratorConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeDecodingMode(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeLookaheadDecodingConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeEagleConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeSpeculativeDecodingConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeGuidedDecodingConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeGuidedDecodingParams(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeKvCacheRetentionConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeTokenRangeRetentionConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeDecodingConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeDebugConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeKvCacheStats(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeStaticBatchingStats(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeInflightBatchingStats(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeSpecDecodingStats(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeIterationStats(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeExecutorConfig(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeDisServingRequestStats(is);
    }
    // NOTE(review): if the type routed here is an enum, the is_enum_v branch above already
    // handles it and this branch is unreachable. Confirm both encodings agree.
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeRequestStage(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeRequestStats(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeRequestStatsPerIteration(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeAdditionalOutput(is);
    }
    else if constexpr (std::is_same_v)
    {
        return Serialization::deserializeCacheTransceiverConfig(is);
    }
    // Optional
    else if constexpr (std::is_same_v::type>>)
    {
        // bool presence flag, then the value only if it was written.
        bool hasValue = false;
        is.read(reinterpret_cast(&hasValue), sizeof(hasValue));
        if (hasValue)
        {
            auto value = deserialize::type>(is);
            return std::optional::type>(std::move(value));
        }
        return std::nullopt;
    }
    // Vectors, lists and strings
    else if constexpr (std::is_same_v::type>> || std::is_same_v::type>>
        || std::is_same_v)
    {
        // size_t element count, then that many elements appended in order.
        size_t size = 0;
        is.read(reinterpret_cast(&size), sizeof(size));
        T container;
        for (size_t i = 0; i < size; ++i)
        {
            auto element = deserialize::type>(is);
            container.push_back(std::move(element));
        }
        return container;
    }
    // std::variant
    else if constexpr (is_variant_v)
    {
        // Get the index of the active type
        std::size_t index = 0;
        is.read(reinterpret_cast(&index), sizeof(index));
        // TODO: Is there a better way to implement this?
        // NOTE(review): only alternatives 0 and 1 can be read back. Any index >= 2, including
        // a corrupted one, throws, and the message says "Serialization" even though this is
        // the deserialization path. `variant_alternative_t<1, T>` also means a single-
        // alternative variant will not compile here, and `T data;` requires the first
        // alternative to be default-constructible.
        T data;
        if (index == 0)
        {
            using U = std::variant_alternative_t<0, T>;
            data = deserialize(is);
        }
        else if (index == 1)
        {
            using U = std::variant_alternative_t<1, T>;
            data = deserialize(is);
        }
        else
        {
            TLLM_THROW("Serialization of variant of size > 2 is not supported.");
        }
        return data;
    }
    else
    {
        static_assert(std::is_same_v, "Unsupported type for deserialization");
        return T();
    }
}

// Extracts `ReturnT` from the type matched by the partial specialization (presumably a
// pointer-to-member-function type, per the linked answer). Only the specialization is
// defined; the primary template is declared but left incomplete.
// https://stackoverflow.com/a/75741832
template
struct method_return_type;

template
struct method_return_type
{
    using type = ReturnT;
};

template
using method_return_type_t = typename method_return_type::type;

// Deserializes a value whose type is derived from a getter's return type through
// method_return_type_t, so callers can name the getter instead of spelling out the field
// type (presumably — confirm against call sites).
template
auto deserializeWithGetterType(std::istream& is)
{
    return deserialize>(is);
}

} // namespace tensorrt_llm::executor::serialize_utils