// cpp/tests/unit_tests/executor/serializeUtilsTest.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/executor/dataTransceiverState.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <gtest/gtest.h>
#include <array>
#include <chrono>
#include <deque>
#include <filesystem>
#include <list>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <variant>
#include <vector>
namespace su = tensorrt_llm::executor::serialize_utils;
namespace texec = tensorrt_llm::executor;
using VecTokens = texec::VecTokens;
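// Field-by-field comparison helpers. Each uses a separate EXPECT_EQ per member so a
// mismatch after a serialize/deserialize round trip reports the exact field that differs.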
void compareRequestPerfMetrics(texec::RequestPerfMetrics const& lh, texec::RequestPerfMetrics const& rh)
{
EXPECT_EQ(lh.timingMetrics.arrivalTime, rh.timingMetrics.arrivalTime);
EXPECT_EQ(lh.timingMetrics.firstScheduledTime, rh.timingMetrics.firstScheduledTime);
EXPECT_EQ(lh.timingMetrics.firstTokenTime, rh.timingMetrics.firstTokenTime);
EXPECT_EQ(lh.timingMetrics.lastTokenTime, rh.timingMetrics.lastTokenTime);
EXPECT_EQ(lh.timingMetrics.kvCacheTransferStart, rh.timingMetrics.kvCacheTransferStart);
EXPECT_EQ(lh.timingMetrics.kvCacheTransferEnd, rh.timingMetrics.kvCacheTransferEnd);
EXPECT_EQ(lh.kvCacheMetrics.numTotalAllocatedBlocks, rh.kvCacheMetrics.numTotalAllocatedBlocks);
EXPECT_EQ(lh.kvCacheMetrics.numNewAllocatedBlocks, rh.kvCacheMetrics.numNewAllocatedBlocks);
EXPECT_EQ(lh.kvCacheMetrics.numReusedBlocks, rh.kvCacheMetrics.numReusedBlocks);
EXPECT_EQ(lh.kvCacheMetrics.numMissedBlocks, rh.kvCacheMetrics.numMissedBlocks);
EXPECT_EQ(lh.kvCacheMetrics.kvCacheHitRate, rh.kvCacheMetrics.kvCacheHitRate);
EXPECT_EQ(lh.speculativeDecoding.acceptanceRate, rh.speculativeDecoding.acceptanceRate);
EXPECT_EQ(lh.speculativeDecoding.totalAcceptedDraftTokens, rh.speculativeDecoding.totalAcceptedDraftTokens);
EXPECT_EQ(lh.speculativeDecoding.totalDraftTokens, rh.speculativeDecoding.totalDraftTokens);
EXPECT_EQ(lh.firstIter, rh.firstIter);
EXPECT_EQ(lh.lastIter, rh.lastIter);
EXPECT_EQ(lh.iter, rh.iter);
}
void compareKvCacheStats(texec::KvCacheStats const& lh, texec::KvCacheStats const& rh)
{
EXPECT_TRUE(std::make_tuple(lh.maxNumBlocks, lh.freeNumBlocks, lh.usedNumBlocks, lh.tokensPerBlock)
== std::make_tuple(rh.maxNumBlocks, rh.freeNumBlocks, rh.usedNumBlocks, rh.tokensPerBlock));
}
void compareStaticBatchingStats(texec::StaticBatchingStats const& lh, texec::StaticBatchingStats const& rh)
{
EXPECT_TRUE(std::make_tuple(
lh.numScheduledRequests, lh.numContextRequests, lh.numCtxTokens, lh.numGenTokens, lh.emptyGenSlots)
== std::make_tuple(
rh.numScheduledRequests, rh.numContextRequests, rh.numCtxTokens, rh.numGenTokens, rh.emptyGenSlots));
}
void compareInflightBatchingStats(texec::InflightBatchingStats const& lh, texec::InflightBatchingStats const& rh)
{
EXPECT_TRUE(std::make_tuple(lh.numScheduledRequests, lh.numContextRequests, lh.numGenRequests, lh.numPausedRequests,
lh.numCtxTokens, lh.microBatchId)
== std::make_tuple(rh.numScheduledRequests, rh.numContextRequests, rh.numGenRequests, rh.numPausedRequests,
rh.numCtxTokens, rh.microBatchId));
}
void compareIterationStats(texec::IterationStats const& lh, texec::IterationStats const& rh)
{
EXPECT_EQ(lh.timestamp, rh.timestamp);
EXPECT_EQ(lh.iter, rh.iter);
EXPECT_EQ(lh.numActiveRequests, rh.numActiveRequests);
EXPECT_EQ(lh.maxNumActiveRequests, rh.maxNumActiveRequests);
EXPECT_EQ(lh.gpuMemUsage, rh.gpuMemUsage);
EXPECT_EQ(lh.cpuMemUsage, rh.cpuMemUsage);
EXPECT_EQ(lh.pinnedMemUsage, rh.pinnedMemUsage);
EXPECT_EQ(lh.kvCacheStats.has_value(), rh.kvCacheStats.has_value());
if (lh.kvCacheStats.has_value())
{
compareKvCacheStats(lh.kvCacheStats.value(), rh.kvCacheStats.value());
}
EXPECT_EQ(lh.staticBatchingStats.has_value(), rh.staticBatchingStats.has_value());
if (lh.staticBatchingStats.has_value())
{
compareStaticBatchingStats(lh.staticBatchingStats.value(), rh.staticBatchingStats.value());
}
EXPECT_EQ(lh.inflightBatchingStats.has_value(), rh.inflightBatchingStats.has_value());
if (lh.inflightBatchingStats.has_value())
{
compareInflightBatchingStats(lh.inflightBatchingStats.value(), rh.inflightBatchingStats.value());
}
}
void compareDisServingRequestStats(texec::DisServingRequestStats const& lh, texec::DisServingRequestStats const& rh)
{
EXPECT_EQ(lh.kvCacheTransferMS, rh.kvCacheTransferMS);
EXPECT_EQ(lh.kvCacheSize, rh.kvCacheSize);
}
void compareRequestStats(texec::RequestStats const& lh, texec::RequestStats const& rh)
{
EXPECT_EQ(lh.id, rh.id);
EXPECT_EQ(lh.stage, rh.stage);
EXPECT_EQ(lh.contextPrefillPosition, rh.contextPrefillPosition);
EXPECT_EQ(lh.numGeneratedTokens, rh.numGeneratedTokens);
EXPECT_EQ(lh.avgNumDecodedTokensPerIter, rh.avgNumDecodedTokensPerIter);
EXPECT_EQ(lh.scheduled, rh.scheduled);
EXPECT_EQ(lh.paused, rh.paused);
EXPECT_EQ(lh.disServingStats.has_value(), rh.disServingStats.has_value());
if (lh.disServingStats.has_value())
{
compareDisServingRequestStats(lh.disServingStats.value(), rh.disServingStats.value());
}
EXPECT_EQ(lh.allocTotalBlocksPerRequest, rh.allocTotalBlocksPerRequest);
EXPECT_EQ(lh.allocNewBlocksPerRequest, rh.allocNewBlocksPerRequest);
EXPECT_EQ(lh.reusedBlocksPerRequest, rh.reusedBlocksPerRequest);
EXPECT_EQ(lh.missedBlocksPerRequest, rh.missedBlocksPerRequest);
EXPECT_EQ(lh.kvCacheHitRatePerRequest, rh.kvCacheHitRatePerRequest);
}
void compareRequestStatsPerIteration(
texec::RequestStatsPerIteration const& lh, texec::RequestStatsPerIteration const& rh)
{
EXPECT_EQ(lh.iter, rh.iter);
EXPECT_EQ(lh.requestStats.size(), rh.requestStats.size());
for (size_t i = 0; i < lh.requestStats.size(); i++)
{
compareRequestStats(lh.requestStats.at(i), rh.requestStats.at(i));
}
}
void compareResult(texec::Result const& res, texec::Result const& res2)
{
EXPECT_EQ(res.isFinal, res2.isFinal);
EXPECT_EQ(res.outputTokenIds, res2.outputTokenIds);
EXPECT_EQ(res.cumLogProbs, res2.cumLogProbs);
EXPECT_EQ(res.logProbs, res2.logProbs);
EXPECT_EQ(res.finishReasons, res2.finishReasons);
EXPECT_EQ(res.decodingIter, res2.decodingIter);
EXPECT_EQ(res.sequenceIndex, res2.sequenceIndex);
EXPECT_EQ(res.isSequenceFinal, res2.isSequenceFinal);
}
void compareResponse(texec::Response const& res, texec::Response const& res2)
{
EXPECT_EQ(res.hasError(), res2.hasError());
EXPECT_EQ(res.getRequestId(), res2.getRequestId());
if (res.hasError())
{
EXPECT_EQ(res.getErrorMsg(), res2.getErrorMsg());
}
else
{
compareResult(res.getResult(), res2.getResult());
}
EXPECT_EQ(res.getClientId(), res2.getClientId());
}
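// Round-trips a value through the serialize_utils free functions: serializes into an
// ostringstream, checks that the written byte count matches su::serializedSize, then
// deserializes the same bytes and returns the reconstructed value.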
template <typename T>
T serializeDeserialize(T val)
{
auto size = su::serializedSize(val);
std::ostringstream oss;
su::serialize(val, oss);
EXPECT_EQ(oss.str().size(), size);
std::istringstream iss(oss.str());
return su::deserialize<T>(iss);
}
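// Round-trips a value and verifies equality, dispatching to the field-by-field comparator
// for executor structs and falling back to EXPECT_EQ for everything else.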
template <typename T>
void testSerializeDeserialize(T val)
{
auto val2 = serializeDeserialize(val);
if constexpr (std::is_same<T, texec::Result>::value)
{
compareResult(val, val2);
}
else if constexpr (std::is_same<T, texec::Response>::value)
{
compareResponse(val, val2);
}
else if constexpr (std::is_same<T, texec::KvCacheStats>::value)
{
compareKvCacheStats(val, val2);
}
else if constexpr (std::is_same<T, texec::StaticBatchingStats>::value)
{
compareStaticBatchingStats(val, val2);
}
else if constexpr (std::is_same<T, texec::InflightBatchingStats>::value)
{
compareInflightBatchingStats(val, val2);
}
else if constexpr (std::is_same<T, texec::IterationStats>::value)
{
compareIterationStats(val, val2);
}
else if constexpr (std::is_same<T, texec::RequestStatsPerIteration>::value)
{
compareRequestStatsPerIteration(val, val2);
}
else if constexpr (std::is_same<T, texec::RequestPerfMetrics>::value)
{
compareRequestPerfMetrics(val, val2);
}
else
{
EXPECT_EQ(val2, val) << typeid(T).name();
}
}
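// Round-trips a variant and checks that the active alternative T2 and its value survive.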
template <typename T, typename T2>
void testSerializeDeserializeVariant(T val)
{
auto val2 = serializeDeserialize(val);
EXPECT_TRUE(std::holds_alternative<T2>(val2));
if constexpr (std::is_same<T2, texec::Result>::value)
{
compareResult(std::get<T2>(val), std::get<T2>(val2));
}
else
{
EXPECT_EQ(std::get<T2>(val), std::get<T2>(val2));
}
}
TEST(SerializeUtilsTest, FundamentalTypes)
{
testSerializeDeserialize(int32_t(99));
testSerializeDeserialize(int64_t(99));
testSerializeDeserialize(uint32_t(99));
testSerializeDeserialize(uint64_t(99));
testSerializeDeserialize(float(99.f));
testSerializeDeserialize(double(99.));
testSerializeDeserialize(char('c'));
}
TEST(SerializeUtilsTest, Vector)
{
{
std::vector<int32_t> vec{1, 2, 3, 4};
testSerializeDeserialize(vec);
}
{
std::vector<char> vec{'a', 'b', 'c', 'd'};
testSerializeDeserialize(vec);
}
{
std::vector<float> vec{1.f, 2.f, 3.f, 4.f};
testSerializeDeserialize(vec);
}
}
TEST(SerializeUtilsTest, List)
{
{
std::list<int32_t> list{1, 2, 3, 4};
testSerializeDeserialize(list);
}
{
std::list<char> list{'a', 'b', 'c', 'd'};
testSerializeDeserialize(list);
}
{
std::list<float> list{9.0f, 3.333f};
testSerializeDeserialize(list);
}
}
TEST(SerializeUtilsTest, String)
{
{
std::string str{"abcdefg"};
testSerializeDeserialize(str);
}
}
TEST(SerializeUtilsTest, Optional)
{
{
std::optional<int32_t> opt = std::nullopt;
testSerializeDeserialize(opt);
}
{
std::optional<int32_t> opt = 1;
testSerializeDeserialize(opt);
}
{
std::optional<char> opt = std::nullopt;
testSerializeDeserialize(opt);
}
{
std::optional<char> opt = 'c';
testSerializeDeserialize(opt);
}
{
std::optional<float> opt = std::nullopt;
testSerializeDeserialize(opt);
}
{
std::optional<float> opt = 1.f;
testSerializeDeserialize(opt);
}
}
TEST(SerializeUtilsTest, Variant)
{
{
std::variant<bool, int32_t> val = int32_t(10);
testSerializeDeserializeVariant<std::variant<bool, int32_t>, int32_t>(val);
}
{
std::variant<bool, texec::Result> val = texec::Result{false, {{1, 2, 3}}};
testSerializeDeserializeVariant<std::variant<bool, texec::Result>, texec::Result>(val);
}
{
std::variant<bool, texec::Result> val = true;
testSerializeDeserializeVariant<std::variant<bool, texec::Result>, bool>(val);
}
}
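// Builds a steady_clock time point at a fixed microsecond offset so timing metrics are
// deterministic and can be compared exactly after a round trip.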
static texec::RequestPerfMetrics::TimePoint generateTimePoint(uint64_t microseconds)
{
std::chrono::microseconds duration(microseconds);
return std::chrono::steady_clock::time_point(duration);
}
TEST(SerializeUtilsTest, RequestPerfMetrics)
{
uint64_t x = 42;
texec::RequestPerfMetrics::TimingMetrics timingMetrics{
generateTimePoint(x++),
generateTimePoint(x++),
generateTimePoint(x++),
generateTimePoint(x++),
generateTimePoint(x++),
generateTimePoint(x++),
};
texec::RequestPerfMetrics::KvCacheMetrics kvCacheMetrics{1, 2, 3, 4, 5};
texec::RequestPerfMetrics val{timingMetrics, kvCacheMetrics, 1000, 1001, 1002, 0.6f};
testSerializeDeserialize(val);
}
TEST(SerializeUtilsTest, SamplingConfig)
{
{
texec::SamplingConfig val(2);
testSerializeDeserialize(val);
}
{
texec::SamplingConfig val(4);
val.setNumReturnSequences(3);
testSerializeDeserialize(val);
}
{
texec::SamplingConfig val(1);
val.setNumReturnSequences(3);
testSerializeDeserialize(val);
}
}
TEST(SerializeUtilsTest, Nested)
{
{
std::optional<std::vector<int32_t>> val = std::nullopt;
testSerializeDeserialize(val);
}
{
std::optional<std::vector<int32_t>> val = std::vector<int32_t>{1, 2, 3, 5};
testSerializeDeserialize(val);
}
{
std::list<std::vector<int32_t>> val = {{2, 3}, {5, 6, 7}};
testSerializeDeserialize(val);
}
{
std::list<std::vector<std::optional<float>>> val = {{2.f, 3.f}, {5.f, 6.f, 7.f}, {std::nullopt, 3.f}};
testSerializeDeserialize(val);
}
// Unsupported types (e.g. std::map) won't build:
//{
// std::map<int, int> val;
// val[1] = 1;
// testSerializeDeserialize(val);
//}
{
auto const val = std::make_optional(std::vector<texec::SamplingConfig>{
texec::SamplingConfig{1, 1, 0.05, 0.2}, texec::SamplingConfig{2, std::nullopt}});
testSerializeDeserialize(val);
}
{
auto const val = std::make_optional(texec::ExternalDraftTokensConfig({1, 1}));
auto const size = su::serializedSize(val);
std::ostringstream oss;
su::serialize(val, oss);
EXPECT_EQ(oss.str().size(), size);
std::istringstream iss(oss.str());
auto const val2 = su::deserialize<std::optional<texec::ExternalDraftTokensConfig>>(iss);
EXPECT_EQ(val2.value().getTokens(), val.value().getTokens());
}
}
TEST(SerializeUtilsTest, ResultResponse)
{
texec::Result res = texec::Result{false, {{1, 2, 3}}, texec::VecLogProbs{1.0, 2.0},
std::vector<texec::VecLogProbs>{{1.1, 2.2}, {3.3, 4.4}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
std::vector<texec::FinishReason>{texec::FinishReason::kLENGTH},
texec::ContextPhaseParams({9, 37}, 0, VecTokens{1, 2}), 3, 2, true};
{
testSerializeDeserialize(res);
}
{
auto val = texec::Response(1, res);
testSerializeDeserialize(val);
}
{
auto val = texec::Response(1, "my error msg");
testSerializeDeserialize(val);
}
{
auto val = texec::Response(1, "my error msg", 2);
testSerializeDeserialize(val);
}
}
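// Mixes result, error, and error-with-client-id responses, then round-trips the whole
// vector through Serialization::serialize / deserializeResponses.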
TEST(SerializeUtilsTest, VectorResponses)
{
int numResponses = 15;
std::vector<texec::Response> responsesIn;
for (int i = 0; i < numResponses; ++i)
{
if (i < 5)
{
texec::Result res = texec::Result{false, {{i + 1, i + 2, i + 3}}, texec::VecLogProbs{1.0, 2.0},
std::vector<texec::VecLogProbs>{{1.1, 2.2}, {3.3, 4.4}}, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::vector<texec::FinishReason>{texec::FinishReason::kEND_ID}};
responsesIn.emplace_back(i, res);
}
else if (i < 10)
{
std::string errMsg = "my_err_msg" + std::to_string(i);
responsesIn.emplace_back(i, errMsg);
}
else
{
std::string errMsg = "my_err_msg" + std::to_string(i);
responsesIn.emplace_back(i, errMsg, i + 1);
}
}
auto buffer = texec::Serialization::serialize(responsesIn);
auto responsesOut = texec::Serialization::deserializeResponses(buffer);
EXPECT_EQ(responsesIn.size(), responsesOut.size());
for (int i = 0; i < numResponses; ++i)
{
compareResponse(responsesIn.at(i), responsesOut.at(i));
}
}
TEST(SerializeUtilsTest, KvCacheConfig)
{
texec::KvCacheConfig kvCacheConfig(
true, 10, std::vector(1, 100), 2, 0.1, 10000, false, 0.5, 50, 1024, false, false, true, 77);
auto kvCacheConfig2 = serializeDeserialize(kvCacheConfig);
EXPECT_EQ(kvCacheConfig.getEnableBlockReuse(), kvCacheConfig2.getEnableBlockReuse());
EXPECT_EQ(kvCacheConfig.getEnablePartialReuse(), kvCacheConfig2.getEnablePartialReuse());
EXPECT_EQ(kvCacheConfig.getCopyOnPartialReuse(), kvCacheConfig2.getCopyOnPartialReuse());
EXPECT_EQ(kvCacheConfig.getMaxTokens(), kvCacheConfig2.getMaxTokens());
EXPECT_EQ(kvCacheConfig.getMaxAttentionWindowVec(), kvCacheConfig2.getMaxAttentionWindowVec());
EXPECT_EQ(kvCacheConfig.getSinkTokenLength(), kvCacheConfig2.getSinkTokenLength());
EXPECT_EQ(kvCacheConfig.getFreeGpuMemoryFraction(), kvCacheConfig2.getFreeGpuMemoryFraction());
EXPECT_EQ(kvCacheConfig.getHostCacheSize(), kvCacheConfig2.getHostCacheSize());
EXPECT_EQ(kvCacheConfig.getOnboardBlocks(), kvCacheConfig2.getOnboardBlocks());
EXPECT_EQ(kvCacheConfig.getCrossKvCacheFraction(), kvCacheConfig2.getCrossKvCacheFraction());
EXPECT_EQ(kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig2.getSecondaryOffloadMinPriority());
EXPECT_EQ(kvCacheConfig.getEventBufferMaxSize(), kvCacheConfig2.getEventBufferMaxSize());
EXPECT_EQ(kvCacheConfig.getUseUvm(), kvCacheConfig2.getUseUvm());
EXPECT_EQ(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), kvCacheConfig2.getAttentionDpEventsGatherPeriodMs());
}
TEST(SerializeUtilsTest, SchedulerConfig)
{
texec::SchedulerConfig schedulerConfig(
texec::CapacitySchedulerPolicy::kMAX_UTILIZATION, texec::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED);
auto schedulerConfig2 = serializeDeserialize(schedulerConfig);
EXPECT_EQ(schedulerConfig.getCapacitySchedulerPolicy(), schedulerConfig2.getCapacitySchedulerPolicy());
EXPECT_EQ(schedulerConfig.getContextChunkingPolicy(), schedulerConfig2.getContextChunkingPolicy());
}
TEST(SerializeUtilsTest, ParallelConfig)
{
texec::ParallelConfig parallelConfig(texec::CommunicationType::kMPI, texec::CommunicationMode::kLEADER,
std::vector<texec::SizeType32>{1, 2, 7}, std::vector<texec::SizeType32>{0, 1, 4});
auto parallelConfig2 = serializeDeserialize(parallelConfig);
EXPECT_EQ(parallelConfig.getCommunicationType(), parallelConfig2.getCommunicationType());
EXPECT_EQ(parallelConfig.getCommunicationMode(), parallelConfig2.getCommunicationMode());
EXPECT_EQ(parallelConfig.getDeviceIds(), parallelConfig2.getDeviceIds());
EXPECT_EQ(parallelConfig.getParticipantIds(), parallelConfig2.getParticipantIds());
}
TEST(SerializeUtilsTest, PeftCacheConfig)
{
auto peftCacheConfig = texec::PeftCacheConfig(10, 9, 8, 7, 6, 5, 4, 3, 2, 0.9, 1000);
testSerializeDeserialize(peftCacheConfig);
}
TEST(SerializeUtilsTest, LookaheadDecodingConfig)
{
auto lookaheadDecodingConfig = texec::LookaheadDecodingConfig(3, 5, 7);
auto lookaheadDecodingConfig2 = serializeDeserialize(lookaheadDecodingConfig);
EXPECT_EQ(lookaheadDecodingConfig.getNgramSize(), lookaheadDecodingConfig2.getNgramSize());
EXPECT_EQ(lookaheadDecodingConfig.getWindowSize(), lookaheadDecodingConfig2.getWindowSize());
EXPECT_EQ(lookaheadDecodingConfig.getVerificationSetSize(), lookaheadDecodingConfig2.getVerificationSetSize());
}
TEST(SerializeUtilsTest, EagleConfig)
{
texec::EagleChoices eagleChoices{{{0, 1, 2}}};
auto eagleConfig = texec::EagleConfig(eagleChoices);
auto eagleConfig2 = serializeDeserialize(eagleConfig);
EXPECT_EQ(eagleConfig.getEagleChoices(), eagleConfig2.getEagleChoices());
}
TEST(SerializeUtilsTest, KvCacheRetentionConfig)
{
using namespace std::chrono_literals;
auto kvCacheRetentionConfig = texec::KvCacheRetentionConfig();
auto kvCacheRetentionConfig2 = serializeDeserialize(kvCacheRetentionConfig);
EXPECT_EQ(kvCacheRetentionConfig.getTokenRangeRetentionConfigs(),
kvCacheRetentionConfig2.getTokenRangeRetentionConfigs());
EXPECT_EQ(
kvCacheRetentionConfig.getDecodeRetentionPriority(), kvCacheRetentionConfig2.getDecodeRetentionPriority());
EXPECT_EQ(kvCacheRetentionConfig.getDecodeDurationMs(), kvCacheRetentionConfig2.getDecodeDurationMs());
kvCacheRetentionConfig
= texec::KvCacheRetentionConfig({texec::KvCacheRetentionConfig::TokenRangeRetentionConfig(0, 1, 80),
texec::KvCacheRetentionConfig::TokenRangeRetentionConfig(1, 2, 50),
texec::KvCacheRetentionConfig::TokenRangeRetentionConfig(2, 3, 30)},
5, 30s);
kvCacheRetentionConfig2 = serializeDeserialize(kvCacheRetentionConfig);
EXPECT_EQ(kvCacheRetentionConfig.getTokenRangeRetentionConfigs(),
kvCacheRetentionConfig2.getTokenRangeRetentionConfigs());
EXPECT_EQ(
kvCacheRetentionConfig.getDecodeRetentionPriority(), kvCacheRetentionConfig2.getDecodeRetentionPriority());
EXPECT_EQ(kvCacheRetentionConfig.getDecodeDurationMs(), kvCacheRetentionConfig2.getDecodeDurationMs());
}
TEST(SerializeUtilsTest, DecodingConfig)
{
{
texec::DecodingMode decodingMode{texec::DecodingMode::Lookahead()};
texec::LookaheadDecodingConfig laConfig{3, 5, 7};
auto specDecodingConfig = texec::DecodingConfig(decodingMode, laConfig);
auto specDecodingConfig2 = serializeDeserialize(specDecodingConfig);
EXPECT_EQ(specDecodingConfig.getDecodingMode(), specDecodingConfig2.getDecodingMode());
EXPECT_EQ(specDecodingConfig.getLookaheadDecodingConfig(), specDecodingConfig2.getLookaheadDecodingConfig());
}
{
texec::DecodingMode decodingMode{texec::DecodingMode::Medusa()};
texec::MedusaChoices medusaChoices{{{0, 1, 2}}};
auto specDecodingConfig = texec::DecodingConfig(decodingMode, std::nullopt, medusaChoices);
auto specDecodingConfig2 = serializeDeserialize(specDecodingConfig);
EXPECT_EQ(specDecodingConfig.getDecodingMode(), specDecodingConfig2.getDecodingMode());
EXPECT_EQ(specDecodingConfig.getMedusaChoices(), specDecodingConfig2.getMedusaChoices());
}
{
texec::DecodingMode decodingMode{texec::DecodingMode::Eagle()};
texec::EagleChoices eagleChoices{{{0, 1, 2}}};
texec::EagleConfig eagleConfig{eagleChoices};
auto specDecodingConfig = texec::DecodingConfig(decodingMode, std::nullopt, std::nullopt, eagleConfig);
auto specDecodingConfig2 = serializeDeserialize(specDecodingConfig);
EXPECT_EQ(specDecodingConfig.getDecodingMode(), specDecodingConfig2.getDecodingMode());
EXPECT_EQ(specDecodingConfig.getEagleConfig()->getEagleChoices(),
specDecodingConfig2.getEagleConfig()->getEagleChoices());
}
}
TEST(SerializeUtilsTest, DebugConfig)
{
texec::DebugConfig debugConfig(true, true, {"test"}, 3);
auto debugConfig2 = serializeDeserialize(debugConfig);
EXPECT_EQ(debugConfig.getDebugInputTensors(), debugConfig2.getDebugInputTensors());
EXPECT_EQ(debugConfig.getDebugOutputTensors(), debugConfig2.getDebugOutputTensors());
EXPECT_EQ(debugConfig.getDebugTensorNames(), debugConfig2.getDebugTensorNames());
EXPECT_EQ(debugConfig.getDebugTensorsMaxIterations(), debugConfig2.getDebugTensorsMaxIterations());
}
TEST(SerializeUtilsTest, OrchestratorConfig)
{
auto orchConfig = texec::OrchestratorConfig(false, std::filesystem::current_path().string(), nullptr, false);
auto orchConfig2 = serializeDeserialize(orchConfig);
EXPECT_EQ(orchConfig.getIsOrchestrator(), orchConfig2.getIsOrchestrator());
EXPECT_EQ(orchConfig.getWorkerExecutablePath(), orchConfig2.getWorkerExecutablePath());
EXPECT_EQ(orchConfig.getSpawnProcesses(), orchConfig2.getSpawnProcesses());
}
TEST(SerializeUtilsTest, KvCacheStats)
{
auto stats = texec::KvCacheStats{10, 20, 30, 40, 50, 60, 70};
auto stats2 = serializeDeserialize(stats);
compareKvCacheStats(stats, stats2);
}
TEST(SerializeUtilsTest, StaticBatchingStats)
{
auto stats = texec::StaticBatchingStats{10, 20, 30, 40, 50};
auto stats2 = serializeDeserialize(stats);
compareStaticBatchingStats(stats, stats2);
}
TEST(SerializeUtilsTest, InflightBatchingStats)
{
auto stats = texec::InflightBatchingStats{10, 20, 30, 40, 50, 60};
auto stats2 = serializeDeserialize(stats);
compareInflightBatchingStats(stats, stats2);
}
TEST(SerializeUtilsTest, IterationStats)
{
auto timestamp = std::string{"05:01:00"};
auto iter = texec::IterationType{10};
auto iterLatencyMS = double{100};
auto newActiveRequestsQueueLatencyMS = double{1000};
auto numNewActiveRequests = texec::SizeType32{10};
auto numActiveRequests = texec::SizeType32{20};
auto numQueuedRequests = texec::SizeType32{30};
auto numCompletedRequests = texec::SizeType32{10};
auto maxNumActiveRequests = texec::SizeType32{30};
auto maxBatchSizeStatic = texec::SizeType32{100};
auto maxBatchSizeTunerRecommended = texec::SizeType32{50};
auto maxBatchSizeRuntime = texec::SizeType32{50};
auto maxNumTokensStatic = texec::SizeType32{100};
auto maxNumTokensTunerRecommended = texec::SizeType32{50};
auto maxNumTokensRuntime = texec::SizeType32{50};
auto gpuMemUsage = size_t{1024};
auto cpuMemUsage = size_t{2048};
auto pinnedMemUsage = size_t{4096};
auto kvCacheStats = texec::KvCacheStats{10, 20, 30, 40, 50, 60, 70};
auto staticBatchingStats = texec::StaticBatchingStats{10, 20, 30, 40, 50};
auto ifbBatchingStats = texec::InflightBatchingStats{10, 20, 30, 40, 50, 60};
{
{
auto stats = texec::IterationStats{timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS,
numNewActiveRequests, numActiveRequests, numQueuedRequests, numCompletedRequests, maxNumActiveRequests,
maxBatchSizeStatic, maxBatchSizeTunerRecommended, maxBatchSizeRuntime, maxNumTokensStatic,
maxNumTokensTunerRecommended, maxNumTokensRuntime, gpuMemUsage, cpuMemUsage, pinnedMemUsage,
kvCacheStats, kvCacheStats, staticBatchingStats, ifbBatchingStats};
// serialize and deserialize using std::vector<char>
{
auto buffer = texec::Serialization::serialize(stats);
auto stats2 = texec::Serialization::deserializeIterationStats(buffer);
compareIterationStats(stats, stats2);
}
// serialize deserialize using is, os
{
auto stats2 = serializeDeserialize(stats);
compareIterationStats(stats, stats2);
}
}
}
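// Exercise every combination of present/absent optional sub-stats, via both the
// byte-buffer and stream serialization paths.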
for (auto kvStats : std::vector<std::optional<texec::KvCacheStats>>{std::nullopt, kvCacheStats})
{
for (auto staticBatchStats :
std::vector<std::optional<texec::StaticBatchingStats>>{std::nullopt, staticBatchingStats})
{
for (auto ifbBatchStats :
std::vector<std::optional<texec::InflightBatchingStats>>{std::nullopt, ifbBatchingStats})
{
auto stats = texec::IterationStats{timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS,
numNewActiveRequests, numActiveRequests, numQueuedRequests, numCompletedRequests,
maxNumActiveRequests, maxBatchSizeStatic, maxBatchSizeTunerRecommended, maxBatchSizeRuntime,
maxNumTokensStatic, maxNumTokensTunerRecommended, maxNumTokensRuntime, gpuMemUsage, cpuMemUsage,
pinnedMemUsage, kvStats, kvStats, staticBatchStats, ifbBatchStats};
{
auto buffer = texec::Serialization::serialize(stats);
auto stats2 = texec::Serialization::deserializeIterationStats(buffer);
compareIterationStats(stats, stats2);
}
{
auto stats2 = serializeDeserialize(stats);
compareIterationStats(stats, stats2);
}
}
}
}
}
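// ContextPhaseParams round trips, including params carrying a DataTransceiverState with
// comm/cache state, plus a direct deserializeDataTransceiverState of the serialized state.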
TEST(SerializeUtilsTest, ContextPhaseParams)
{
{
auto stats = texec::ContextPhaseParams({1}, 0, std::nullopt);
auto stats2 = serializeDeserialize(stats);
EXPECT_EQ(stats, stats2);
}
{
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{{10, 20}});
auto stats = texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 1, state.release(), VecTokens{10, 20});
auto stats2 = serializeDeserialize(stats);
EXPECT_EQ(stats, stats2);
}
{
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, {4}, nvinfer1::DataType::kFLOAT});
auto stats = texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 0, state.release(), VecTokens{10, 20});
auto stats2 = serializeDeserialize(stats);
EXPECT_EQ(stats, stats2);
}
{
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{{10, 20}});
auto state2 = *state;
auto contextPhaseParams
= texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 1, state.release(), VecTokens{10, 20});
auto serializedState = contextPhaseParams.getSerializedState();
auto stateCopy = texec::Serialization::deserializeDataTransceiverState(serializedState);
EXPECT_EQ(state2, stateCopy);
}
}
TEST(SerializeUtilsTest, SpeculativeDecodingFastLogitsInfo)
{
auto logitsInfo = texec::SpeculativeDecodingFastLogitsInfo{10, 20};
auto logitsInfo2 = serializeDeserialize(logitsInfo);
EXPECT_EQ(logitsInfo.draftRequestId, logitsInfo2.draftRequestId);
EXPECT_EQ(logitsInfo.draftParticipantId, logitsInfo2.draftParticipantId);
}
TEST(SerializeUtilsTest, GuidedDecodingConfig)
{
std::vector<std::string> encodedVocab{"eos", "a", "b", "c", "d"};
std::vector<texec::TokenIdType> stopTokenIds{0};
texec::GuidedDecodingConfig guidedDecodingConfig(
texec::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, encodedVocab, std::nullopt, stopTokenIds);
auto guidedDecodingConfig2 = serializeDeserialize(guidedDecodingConfig);
EXPECT_EQ(guidedDecodingConfig.getBackend(), guidedDecodingConfig2.getBackend());
EXPECT_EQ(guidedDecodingConfig.getEncodedVocab(), guidedDecodingConfig2.getEncodedVocab());
EXPECT_EQ(guidedDecodingConfig.getTokenizerStr(), guidedDecodingConfig2.getTokenizerStr());
EXPECT_EQ(guidedDecodingConfig.getStopTokenIds(), guidedDecodingConfig2.getStopTokenIds());
}
TEST(SerializeUtilsTest, GuidedDecodingParams)
{
texec::GuidedDecodingParams guidedDecodingParams(texec::GuidedDecodingParams::GuideType::kREGEX, R"(\d+)");
auto guidedDecodingParams2 = serializeDeserialize(guidedDecodingParams);
EXPECT_EQ(guidedDecodingParams.getGuideType(), guidedDecodingParams2.getGuideType());
EXPECT_EQ(guidedDecodingParams.getGuide(), guidedDecodingParams2.getGuide());
}
TEST(SerializeUtilsTest, ExecutorConfig)
{
texec::ExecutorConfig executorConfig(2, texec::SchedulerConfig(texec::CapacitySchedulerPolicy::kMAX_UTILIZATION),
texec::KvCacheConfig(true), true, false, 500, 200, texec::BatchingType::kSTATIC, 128, 64,
texec::ParallelConfig(texec::CommunicationType::kMPI, texec::CommunicationMode::kORCHESTRATOR),
texec::PeftCacheConfig(10), std::nullopt,
texec::DecodingConfig(texec::DecodingMode::Lookahead(), texec::LookaheadDecodingConfig(3, 5, 7)), false, 0.5f,
8, texec::ExtendedRuntimePerfKnobConfig(true), texec::DebugConfig(true), 60000000, 180000000,
texec::SpeculativeDecodingConfig(true),
texec::GuidedDecodingConfig(
texec::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, std::initializer_list<std::string>{"eos"}),
std::vector{tensorrt_llm::executor::AdditionalModelOutput{"output_name"}},
texec::CacheTransceiverConfig(std::nullopt, 1024, 100, 1000), true, true, true);
auto executorConfig2 = serializeDeserialize(executorConfig);
EXPECT_EQ(executorConfig.getMaxBeamWidth(), executorConfig2.getMaxBeamWidth());
EXPECT_EQ(executorConfig.getSchedulerConfig(), executorConfig2.getSchedulerConfig());
EXPECT_EQ(executorConfig.getKvCacheConfig().getEnableBlockReuse(),
executorConfig2.getKvCacheConfig().getEnableBlockReuse());
EXPECT_EQ(executorConfig.getEnableChunkedContext(), executorConfig2.getEnableChunkedContext());
EXPECT_EQ(executorConfig.getNormalizeLogProbs(), executorConfig2.getNormalizeLogProbs());
EXPECT_EQ(executorConfig.getIterStatsMaxIterations(), executorConfig2.getIterStatsMaxIterations());
EXPECT_EQ(executorConfig.getRequestStatsMaxIterations(), executorConfig2.getRequestStatsMaxIterations());
EXPECT_EQ(executorConfig.getBatchingType(), executorConfig2.getBatchingType());
EXPECT_EQ(executorConfig.getMaxBatchSize(), executorConfig2.getMaxBatchSize());
EXPECT_EQ(executorConfig.getMaxNumTokens(), executorConfig2.getMaxNumTokens());
EXPECT_EQ(executorConfig.getParallelConfig().value().getCommunicationMode(),
executorConfig2.getParallelConfig().value().getCommunicationMode());
EXPECT_EQ(executorConfig.getPeftCacheConfig(), executorConfig2.getPeftCacheConfig());
EXPECT_EQ(executorConfig.getDecodingConfig(), executorConfig2.getDecodingConfig());
EXPECT_EQ(executorConfig.getUseGpuDirectStorage(), executorConfig2.getUseGpuDirectStorage());
EXPECT_EQ(executorConfig.getGpuWeightsPercent(), executorConfig2.getGpuWeightsPercent());
EXPECT_EQ(executorConfig.getMaxQueueSize(), executorConfig2.getMaxQueueSize());
EXPECT_EQ(executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig2.getExtendedRuntimePerfKnobConfig());
EXPECT_EQ(executorConfig.getDebugConfig(), executorConfig2.getDebugConfig());
EXPECT_EQ(executorConfig.getRecvPollPeriodMs(), executorConfig2.getRecvPollPeriodMs());
EXPECT_EQ(executorConfig.getMaxSeqIdleMicroseconds(), executorConfig2.getMaxSeqIdleMicroseconds());
EXPECT_EQ(executorConfig.getSpecDecConfig(), executorConfig2.getSpecDecConfig());
EXPECT_EQ(executorConfig.getGuidedDecodingConfig(), executorConfig2.getGuidedDecodingConfig());
EXPECT_EQ(executorConfig.getAdditionalModelOutputs(), executorConfig2.getAdditionalModelOutputs());
EXPECT_EQ(executorConfig.getCacheTransceiverConfig(), executorConfig2.getCacheTransceiverConfig());
EXPECT_EQ(executorConfig.getGatherGenerationLogits(), executorConfig2.getGatherGenerationLogits());
EXPECT_EQ(executorConfig.getPromptTableOffloading(), executorConfig2.getPromptTableOffloading());
EXPECT_EQ(executorConfig.getEnableTrtOverlap(), executorConfig2.getEnableTrtOverlap());
}
TEST(SerializeUtilsTest, RequestStats)
{
tensorrt_llm::executor::DisServingRequestStats disServingRequestStats{0.56222, 78643200};
texec::RequestStats requestStats{123, tensorrt_llm::executor::RequestStage::kQUEUED, 56, 25, 135, true, false,
disServingRequestStats, 33, 22, 6, 1, 8};
auto requestStats2 = serializeDeserialize(requestStats);
compareRequestStats(requestStats, requestStats2);
requestStats.disServingStats = std::nullopt;
requestStats2 = serializeDeserialize(requestStats);
compareRequestStats(requestStats, requestStats2);
}
TEST(SerializeUtilsTest, RequestStatsPerIteration)
{
tensorrt_llm::executor::DisServingRequestStats disServingRequestStats{0.56222, 78643200};
texec::RequestStats requestStats1{123, tensorrt_llm::executor::RequestStage::kQUEUED, 56, 25, 135, true, false,
disServingRequestStats, 33, 22, 6, 1, 8};
texec::RequestStats requestStats2{
899, texec::RequestStage::kGENERATION_IN_PROGRESS, 98, 78, 77, false, true, std::nullopt, 7, 14, 65, 61, 78};
texec::RequestStatsPerIteration requestStatsPerIteration{0, {requestStats1, requestStats2}};
auto requestStatsPerIteration2 = serializeDeserialize(requestStatsPerIteration);
compareRequestStatsPerIteration(requestStatsPerIteration, requestStatsPerIteration2);
}
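// Compares the common KVCacheEvent fields, then the variant payload according to which
// event data type is held (created, removed, stored, or updated).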
void compareKvCacheEvents(texec::KVCacheEvent const& kvCacheEvent, texec::KVCacheEvent const& kvCacheEvent2)
{
EXPECT_EQ(kvCacheEvent.eventId, kvCacheEvent2.eventId);
EXPECT_EQ(kvCacheEvent.windowSize, kvCacheEvent2.windowSize);
EXPECT_EQ(kvCacheEvent.attentionDpRank, kvCacheEvent2.attentionDpRank);
if (std::holds_alternative<texec::KVCacheCreatedData>(kvCacheEvent.data))
{
EXPECT_TRUE(std::holds_alternative<texec::KVCacheCreatedData>(kvCacheEvent2.data));
auto data = std::get<texec::KVCacheCreatedData>(kvCacheEvent.data);
auto data2 = std::get<texec::KVCacheCreatedData>(kvCacheEvent2.data);
EXPECT_EQ(data.numBlocksPerCacheLevel, data2.numBlocksPerCacheLevel);
}
else if (std::holds_alternative<texec::KVCacheRemovedData>(kvCacheEvent.data))
{
EXPECT_TRUE(std::holds_alternative<texec::KVCacheRemovedData>(kvCacheEvent2.data));
auto data = std::get<texec::KVCacheRemovedData>(kvCacheEvent.data);
auto data2 = std::get<texec::KVCacheRemovedData>(kvCacheEvent2.data);
EXPECT_EQ(data.blockHashes, data2.blockHashes);
}
else if (std::holds_alternative<texec::KVCacheStoredData>(kvCacheEvent.data))
{
EXPECT_TRUE(std::holds_alternative<texec::KVCacheStoredData>(kvCacheEvent2.data));
auto data = std::get<texec::KVCacheStoredData>(kvCacheEvent.data);
auto data2 = std::get<texec::KVCacheStoredData>(kvCacheEvent2.data);
EXPECT_EQ(data.parentHash, data2.parentHash);
EXPECT_EQ(data.blocks.size(), data2.blocks.size());
for (size_t i = 0; i < data.blocks.size(); ++i)
{
auto blockData = data.blocks[i];
auto blockData2 = data2.blocks[i];
EXPECT_EQ(blockData.blockHash, blockData2.blockHash);
EXPECT_EQ(blockData.loraId, blockData2.loraId);
EXPECT_EQ(blockData.cacheLevel, blockData2.cacheLevel);
EXPECT_EQ(blockData.priority, blockData2.priority);
EXPECT_EQ(blockData.tokens.size(), blockData2.tokens.size());
for (size_t j = 0; j < blockData.tokens.size(); ++j)
{
EXPECT_EQ(blockData.tokens[j].tokenId, blockData2.tokens[j].tokenId);
EXPECT_EQ(blockData.tokens[j].tokenExtraId, blockData2.tokens[j].tokenExtraId);
}
}
}
else if (std::holds_alternative<texec::KVCacheUpdatedData>(kvCacheEvent.data))
{
EXPECT_TRUE(std::holds_alternative<texec::KVCacheUpdatedData>(kvCacheEvent2.data));
auto data = std::get<texec::KVCacheUpdatedData>(kvCacheEvent.data);
auto data2 = std::get<texec::KVCacheUpdatedData>(kvCacheEvent2.data);
EXPECT_EQ(data.blockHash, data2.blockHash);
if (data.cacheLevel)
{
EXPECT_TRUE(data2.cacheLevel);
EXPECT_EQ(data.cacheLevel.value().oldValue, data2.cacheLevel.value().oldValue);
EXPECT_EQ(data.cacheLevel.value().newValue, data2.cacheLevel.value().newValue);
}
if (data.priority)
{
EXPECT_TRUE(data2.priority);
EXPECT_EQ(data.priority.value().oldValue, data2.priority.value().oldValue);
EXPECT_EQ(data.priority.value().newValue, data2.priority.value().newValue);
}
}
else
{
FAIL() << "Unknown KVCacheEvent data type";
}
}
TEST(SerializeUtilsTest, KvCacheEventsDeque)
{
// Created event
texec::KVCacheCreatedData kvCacheCreatedData{{1, 2}};
texec::KVCacheEvent kvCacheCreatedEvent(1, kvCacheCreatedData, 32);
// Removed event
texec::KVCacheEvent kvCacheRemovedEvent(1, texec::KVCacheRemovedData{{3, 4}}, 32);
// Stored event
auto storedBlockData1 = texec::KVCacheStoredBlockData(77, {{1, 2}, {3, 4}, {5, 6}}, 88, 0, 99);
auto storedBlockData2 = texec::KVCacheStoredBlockData(99, {{11, 12}, {3, 4}, {15, 6}}, 77, 1, 101);
texec::KVCacheStoredData kvCacheStoredData{177, {storedBlockData1, storedBlockData2}};
texec::KVCacheEvent kvCacheStoredEvent(1, kvCacheStoredData, 32);
// Updated event
texec::KVCacheEventDiff<texec::SizeType32> diff{0, 1};
texec::KVCacheEventDiff<texec::SizeType32> diff2{90, 99};
texec::KVCacheUpdatedData kvCacheUpdatedData(999, diff, diff2);
texec::KVCacheEvent kvCacheEvent(1, kvCacheUpdatedData, 32);
std::deque<texec::KVCacheEvent> kvCacheEvents{
kvCacheCreatedEvent, kvCacheRemovedEvent, kvCacheStoredEvent, kvCacheEvent};
auto serializedEvents = texec::Serialization::serialize(kvCacheEvents);
auto kvCacheEvents2 = texec::Serialization::deserializeKVCacheEvents(serializedEvents);
EXPECT_EQ(kvCacheEvents.size(), kvCacheEvents2.size());
for (size_t i = 0; i < kvCacheEvents.size(); ++i)
{
compareKvCacheEvents(kvCacheEvents[i], kvCacheEvents2[i]);
}
}
// Test for KVCacheEvent with KVCacheCreatedData
TEST(SerializeUtilsTest, KVCacheCreatedEvent)
{
texec::KVCacheCreatedData kvCacheCreatedData{{1, 2}};
texec::KVCacheEvent kvCacheEvent(1, kvCacheCreatedData, 32);
auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent);
compareKvCacheEvents(kvCacheEvent, kvCacheEvent2);
}
// Test for KVCacheEvent with KVCacheRemovedData
TEST(SerializeUtilsTest, KVCacheRemovedEvents)
{
texec::KVCacheEvent kvCacheEvent(1, texec::KVCacheRemovedData{{3, 4}}, 32);
auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent);
compareKvCacheEvents(kvCacheEvent, kvCacheEvent2);
}
// Test for KVCacheEvent with KVCacheStoredData
TEST(SerializeUtilsTest, KVCacheStoredEvent)
{
auto storedBlockData1 = texec::KVCacheStoredBlockData(77, {{1, 2}, {3, 4}, {5, 6}}, 88, 0, 99);
auto storedBlockData2 = texec::KVCacheStoredBlockData(99, {{11, 12}, {3, 4}, {15, 6}}, 77, 1, 101);
texec::KVCacheStoredData kvCacheStoredData{177, {storedBlockData1, storedBlockData2}};
texec::KVCacheEvent kvCacheEvent(1, kvCacheStoredData, 32);
auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent);
compareKvCacheEvents(kvCacheEvent, kvCacheEvent2);
}
// Test for KVCacheEvent with KVCacheUpdatedData
TEST(SerializeUtilsTest, KVCacheUpdatedEvent)
{
texec::KVCacheEventDiff<texec::SizeType32> diff{0, 1};
texec::KVCacheEventDiff<texec::SizeType32> diff2{90, 99};
texec::KVCacheUpdatedData kvCacheUpdatedData(999, diff, diff2);
texec::KVCacheEvent kvCacheEvent(1, kvCacheUpdatedData, 32);
auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent);
compareKvCacheEvents(kvCacheEvent, kvCacheEvent2);
}
TEST(SerializeUtilsTest, UniqueToken)
{
tensorrt_llm::runtime::UniqueToken token{1, 2};
auto token2 = serializeDeserialize(token);
EXPECT_EQ(token.tokenId, token2.tokenId);
EXPECT_EQ(token.tokenExtraId, token2.tokenExtraId);
}
TEST(SerializeUtilsTest, UniqueTokenVector)
{
std::vector<tensorrt_llm::runtime::UniqueToken> tokens{{1, 2}, {3, 4}, {5, 6}};
auto tokens2 = serializeDeserialize(tokens);
EXPECT_EQ(tokens.size(), tokens2.size());
for (size_t i = 0; i < tokens.size(); ++i)
{
EXPECT_EQ(tokens[i].tokenId, tokens2[i].tokenId);
EXPECT_EQ(tokens[i].tokenExtraId, tokens2[i].tokenExtraId);
}
}
TEST(SerializeUtilsTest, MethodReturnType)
{
struct S
{
void foo() const;
[[nodiscard]] int bar() const;
[[nodiscard]] float const& baz() const;
};
static_assert(std::is_same_v<void, su::method_return_type_t<decltype(&S::foo)>>);
static_assert(std::is_same_v<int, su::method_return_type_t<decltype(&S::bar)>>);
static_assert(std::is_same_v<float const&, su::method_return_type_t<decltype(&S::baz)>>);
}
TEST(SerializeUtilsTest, CacheTransceiverConfig)
{
texec::CacheTransceiverConfig cacheTransceiverConfig(
tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, 1024, 100, 1000);
auto cacheTransceiverConfig2 = serializeDeserialize(cacheTransceiverConfig);
EXPECT_EQ(cacheTransceiverConfig.getBackendType(), cacheTransceiverConfig2.getBackendType());
EXPECT_EQ(cacheTransceiverConfig.getMaxTokensInBuffer(), cacheTransceiverConfig2.getMaxTokensInBuffer());
EXPECT_EQ(cacheTransceiverConfig.getKvTransferTimeoutMs(), cacheTransceiverConfig2.getKvTransferTimeoutMs());
EXPECT_EQ(cacheTransceiverConfig.getKvTransferSenderFutureTimeoutMs(),
cacheTransceiverConfig2.getKvTransferSenderFutureTimeoutMs());
}
TEST(SerializeUtilsTest, BlockKeyBasic)
{
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
VecUniqueTokens uniqueTokens{UniqueToken{1, 0}, UniqueToken{2, 0}, UniqueToken{3, 0}};
BlockKey key(false, std::nullopt, uniqueTokens, {});
testSerializeDeserialize(key);
}
TEST(SerializeUtilsTest, BlockKeyWithExtras)
{
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
// Prepare multimodal extra keys
std::array<uint8_t, 32> h1{};
std::array<uint8_t, 32> h2{};
for (size_t i = 0; i < h1.size(); ++i)
{
h1[i] = static_cast<uint8_t>(i);
h2[i] = static_cast<uint8_t>(255 - i);
}
std::vector<MmKey> extraKeys{{h1, SizeType32{0}}, {h2, SizeType32{5}}};
VecUniqueTokens uniqueTokens{UniqueToken{10, 100}, UniqueToken{20, 200}};
std::optional<LoraTaskIdType> loraTaskId = LoraTaskIdType{42};
// Note: cacheSaltID is intentionally not set since it is not serialized
BlockKey key(true, loraTaskId, uniqueTokens, extraKeys);
testSerializeDeserialize(key);
}
// Connection notification tests
namespace kv_cache = tensorrt_llm::executor::kv_cache;
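// Round-trips a notification object through its own static serializedSize/serialize/
// deserialize members, rather than the su:: free functions used elsewhere in this file.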
template <typename T>
T serializeDeserializeNotification(T const& val)
{
auto size = T::serializedSize(val);
std::ostringstream oss;
T::serialize(val, oss);
EXPECT_EQ(oss.str().size(), size);
std::istringstream iss(oss.str());
return T::deserialize(iss);
}
TEST(SerializeUtilsTest, RequestAndBufferInfo)
{
// Test with all fields populated
{
kv_cache::RequestAndBufferInfo original{"testAgent", "127.0.0.1:8080",
tensorrt_llm::batch_manager::RequestInfo{},
std::vector<kv_cache::MemoryDesc>{kv_cache::MemoryDesc{nullptr, 1024, 0}},
std::make_optional<std::string>("metadata"), 1};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mAddress, deserialized.mAddress);
EXPECT_EQ(original.mRequestInfo.getRequestId(), deserialized.mRequestInfo.getRequestId());
ASSERT_EQ(original.mBufferDescs.size(), deserialized.mBufferDescs.size());
EXPECT_EQ(original.mBufferDescs[0].getAddr(), deserialized.mBufferDescs[0].getAddr());
EXPECT_EQ(original.mBufferDescs[0].getLen(), deserialized.mBufferDescs[0].getLen());
EXPECT_EQ(original.mBufferDescs[0].getDeviceId(), deserialized.mBufferDescs[0].getDeviceId());
EXPECT_EQ(original.mMetadata, deserialized.mMetadata);
EXPECT_EQ(original.mValidConnectionIdx, deserialized.mValidConnectionIdx);
}
// Test with nullopt metadata
{
kv_cache::RequestAndBufferInfo original{"testAgent2", "192.168.1.1:9090",
tensorrt_llm::batch_manager::RequestInfo{},
std::vector<kv_cache::MemoryDesc>{kv_cache::MemoryDesc{nullptr, 512, 0}}, std::nullopt, 2};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mAddress, deserialized.mAddress);
EXPECT_EQ(original.mRequestInfo.getRequestId(), deserialized.mRequestInfo.getRequestId());
ASSERT_EQ(original.mBufferDescs.size(), deserialized.mBufferDescs.size());
EXPECT_EQ(original.mBufferDescs[0].getAddr(), deserialized.mBufferDescs[0].getAddr());
EXPECT_EQ(original.mBufferDescs[0].getLen(), deserialized.mBufferDescs[0].getLen());
EXPECT_EQ(original.mBufferDescs[0].getDeviceId(), deserialized.mBufferDescs[0].getDeviceId());
EXPECT_EQ(original.mMetadata, deserialized.mMetadata);
EXPECT_EQ(original.mValidConnectionIdx, deserialized.mValidConnectionIdx);
}
}
TEST(SerializeUtilsTest, ReadySignalInfo)
{
// Test with isReady = true
{
kv_cache::ReadySignalInfo original{"agent1", kv_cache::DataContext{12345}, true};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mContext.getTag(), deserialized.mContext.getTag());
EXPECT_EQ(original.mIsReady, deserialized.mIsReady);
}
// Test with isReady = false
{
kv_cache::ReadySignalInfo original{"agent2", kv_cache::DataContext{67890}, false};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mContext.getTag(), deserialized.mContext.getTag());
EXPECT_EQ(original.mIsReady, deserialized.mIsReady);
}
// Test with different context tags
{
kv_cache::ReadySignalInfo original{"agent3", kv_cache::DataContext{0}, true};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mContext.getTag(), deserialized.mContext.getTag());
EXPECT_EQ(original.mIsReady, deserialized.mIsReady);
}
}
TEST(SerializeUtilsTest, NotificationSyncInfo)
{
// Test basic functionality
{
kv_cache::NotificationSyncInfo original{"syncAgent", kv_cache::DataContext{54321}};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mContext.getTag(), deserialized.mContext.getTag());
}
// Test with different agent names and context tags
{
kv_cache::NotificationSyncInfo original{"anotherAgent", kv_cache::DataContext{98765}};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mContext.getTag(), deserialized.mContext.getTag());
}
// Test with empty agent name
{
kv_cache::NotificationSyncInfo original{"", kv_cache::DataContext{11111}};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_EQ(original.mAgentName, deserialized.mAgentName);
EXPECT_EQ(original.mContext.getTag(), deserialized.mContext.getTag());
}
}
TEST(SerializeUtilsTest, NotificationInfo)
{
// Test with RequestAndBufferInfo variant
{
kv_cache::RequestAndBufferInfo requestInfo{"testAgent", "127.0.0.1:8080",
tensorrt_llm::batch_manager::RequestInfo{},
std::vector<kv_cache::MemoryDesc>{kv_cache::MemoryDesc{nullptr, 1024, 0}},
std::make_optional<std::string>("test_metadata"), 1};
kv_cache::NotificationInfo original{requestInfo};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_TRUE(std::holds_alternative<kv_cache::RequestAndBufferInfo>(deserialized.mInfo));
auto deserializedRequestInfo = std::get<kv_cache::RequestAndBufferInfo>(deserialized.mInfo);
EXPECT_EQ(requestInfo.mAgentName, deserializedRequestInfo.mAgentName);
EXPECT_EQ(requestInfo.mAddress, deserializedRequestInfo.mAddress);
EXPECT_EQ(requestInfo.mRequestInfo.getRequestId(), deserializedRequestInfo.mRequestInfo.getRequestId());
EXPECT_EQ(requestInfo.mMetadata, deserializedRequestInfo.mMetadata);
EXPECT_EQ(requestInfo.mValidConnectionIdx, deserializedRequestInfo.mValidConnectionIdx);
}
// Test with NotificationSyncInfo variant
{
kv_cache::NotificationSyncInfo syncInfo{"syncAgent", kv_cache::DataContext{12345}};
kv_cache::NotificationInfo original{syncInfo};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_TRUE(std::holds_alternative<kv_cache::NotificationSyncInfo>(deserialized.mInfo));
auto deserializedSyncInfo = std::get<kv_cache::NotificationSyncInfo>(deserialized.mInfo);
EXPECT_EQ(syncInfo.mAgentName, deserializedSyncInfo.mAgentName);
EXPECT_EQ(syncInfo.mContext.getTag(), deserializedSyncInfo.mContext.getTag());
}
// Test with ReadySignalInfo variant
{
kv_cache::ReadySignalInfo readyInfo{"readyAgent", kv_cache::DataContext{67890}, true};
kv_cache::NotificationInfo original{readyInfo};
auto deserialized = serializeDeserializeNotification(original);
EXPECT_TRUE(std::holds_alternative<kv_cache::ReadySignalInfo>(deserialized.mInfo));
auto deserializedReadyInfo = std::get<kv_cache::ReadySignalInfo>(deserialized.mInfo);
EXPECT_EQ(readyInfo.mAgentName, deserializedReadyInfo.mAgentName);
EXPECT_EQ(readyInfo.mContext.getTag(), deserializedReadyInfo.mContext.getTag());
EXPECT_EQ(readyInfo.mIsReady, deserializedReadyInfo.mIsReady);
}
}
TEST(SerializeUtilsTest, CacheStateIndexerKCache)
{
using texec::kv_cache::CacheState;
// Minimal realistic model/parallel config
std::vector<texec::SizeType32> nbKvHeadsPerLayer{8, 8};
texec::SizeType32 sizePerHead = 128;
texec::SizeType32 tokensPerBlock = 16;
texec::SizeType32 tp = 1;
texec::SizeType32 pp = 1;
texec::SizeType32 cp = 1;
std::vector<texec::SizeType32> attentionLayerNumPerPP{static_cast<texec::SizeType32>(nbKvHeadsPerLayer.size())};
auto dataType = nvinfer1::DataType::kFLOAT;
auto attentionType = CacheState::AttentionType::kDEFAULT;
int kvFactor = 2;
bool enableAttentionDP = false;
int dpRank = 0;
int dpSize = 1;
bool enableBlockReuse = true;
bool hasIndexerKCache = true;
texec::SizeType32 indexerDimPerHead = 96;
texec::SizeType32 indexerKCacheQuantBlockSize = 128;
CacheState state{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tp, pp, cp, attentionLayerNumPerPP, dataType,
attentionType, kvFactor, enableAttentionDP, dpRank, dpSize, enableBlockReuse, hasIndexerKCache,
indexerDimPerHead, indexerKCacheQuantBlockSize};
std::ostringstream oss;
texec::Serialization::serialize(state, oss);
std::istringstream iss(oss.str());
auto state2 = texec::Serialization::deserializeCacheState(iss);
EXPECT_EQ(state.getModelConfig().mNbKvHeadsPerLayer, state2.getModelConfig().mNbKvHeadsPerLayer);
EXPECT_EQ(state.getModelConfig().mSizePerHead, state2.getModelConfig().mSizePerHead);
EXPECT_EQ(state.getModelConfig().mTokensPerBlock, state2.getModelConfig().mTokensPerBlock);
EXPECT_EQ(state.getParallelConfig().mTensorParallelism, state2.getParallelConfig().mTensorParallelism);
EXPECT_EQ(state.getParallelConfig().mPipelineParallelism, state2.getParallelConfig().mPipelineParallelism);
EXPECT_EQ(state.getParallelConfig().mContextParallelism, state2.getParallelConfig().mContextParallelism);
EXPECT_EQ(state.getParallelConfig().mAttentionLayerNumPerPP, state2.getParallelConfig().mAttentionLayerNumPerPP);
EXPECT_EQ(state.getDataType(), state2.getDataType());
EXPECT_EQ(state.getAttentionConfig().mAttentionType, state2.getAttentionConfig().mAttentionType);
EXPECT_EQ(state.getAttentionConfig().mKvFactor, state2.getAttentionConfig().mKvFactor);
EXPECT_EQ(state.getEnableBlockReuse(), state2.getEnableBlockReuse());
EXPECT_EQ(state.getHasIndexerKCache(), state2.getHasIndexerKCache());
EXPECT_EQ(state.getIndexerDimPerHead(), state2.getIndexerDimPerHead());
EXPECT_EQ(state.getIndexerKCacheQuantBlockSize(), state2.getIndexerKCacheQuantBlockSize());
}