refactor: DisaggExecutorTest (#4398)

* chore: Improve formatting of DisaggExecutorTest

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Typed InstanceRole param in DisaggExecutorTest

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* refactor: Skip DisaggExecutorTest based on device count

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
Robin Kobus 2025-05-21 12:01:45 +02:00 committed by GitHub
parent 4018806742
commit cd0c826417
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 336 additions and 194 deletions

View File

@ -16,30 +16,79 @@
#include "tensorrt_llm/runtime/utils/numpyUtils.h"
#include "tests/utils/common.h"
#include <cstddef>
#include <unordered_set>
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::testing;
using DisaggParamsType = std::tuple<int, std::vector<std::string>, std::vector<std::vector<int>>,
std::vector<std::vector<int>>, std::vector<int>, int>;
using CondDisaggParamsType = std::tuple<std::string>;
namespace
{
auto constexpr LLAMA_INPUT_FILE = "input_tokens_llama.npy";
auto constexpr LLAMA_VOCAB_SIZE_PADDED = 128256;
auto constexpr LLAMA_END_ID = 128001;
auto constexpr LLAMA_PAD_ID = 128001;
} // namespace
using CondDisaggParamsType = std::tuple<std::string>; // modelName
enum InstanceRole : int
enum class InstanceRole : int
{
CONTEXT = 1,
GENERATION = 0,
MIXED = 2
kCONTEXT = 1,
kGENERATION = 0,
kMIXED = 2
};
using DisaggParamsType = std::tuple< //
int, // processNum
std::vector<std::string>, // modelNames
std::vector<std::vector<int>>, // participantIdsEachInstance
std::vector<std::vector<int>>, // participantDeviceIdsEachInstance
std::vector<InstanceRole>, // instanceRoles
int // controllerRank
>;
std::string convertToString(std::vector<std::vector<int>> const& vec)
{
std::ostringstream oss;
oss << "XX";
for (size_t i = 0; i < vec.size(); ++i)
{
for (size_t j = 0; j < vec[i].size(); ++j)
{
oss << vec[i][j];
if (j < vec[i].size() - 1)
{
oss << "_";
}
}
if (i < vec.size() - 1)
{
oss << "X_X";
}
}
oss << "XX";
return oss.str();
};
std::string convertToString(std::vector<InstanceRole> const& vec)
{
std::ostringstream oss;
oss << "XX";
for (size_t j = 0; j < vec.size(); ++j)
{
oss << static_cast<int>(vec[j]);
if (j < vec.size() - 1)
{
oss << "_";
}
}
oss << "XX";
return oss.str();
};
std::string generateTestNameDisaggParams(testing::TestParamInfo<DisaggParamsType> const& info)
@ -51,30 +100,6 @@ std::string generateTestNameDisaggParams(testing::TestParamInfo<DisaggParamsType
auto const instanceRoles = std::get<4>(info.param); // std::vector<int> ; //1 is context , 0 is generation
auto const controllerRank = std::get<5>(info.param);
auto convertToString = [](std::vector<std::vector<int>> const& vec)
{
std::ostringstream oss;
oss << "XX";
for (size_t i = 0; i < vec.size(); ++i)
{
for (size_t j = 0; j < vec[i].size(); ++j)
{
oss << vec[i][j];
if (j < vec[i].size() - 1)
{
oss << "_";
}
}
if (i < vec.size() - 1)
{
oss << "X_X";
}
}
oss << "XX";
return oss.str();
};
std::string name = "DisaggExecutorTest_";
name.append("ProcessNum_" + std::to_string(processNum));
@ -89,7 +114,7 @@ std::string generateTestNameDisaggParams(testing::TestParamInfo<DisaggParamsType
name.append("_ranks_").append(convertToString(participantIdsEachInstance));
name.append("_devices_").append(convertToString(participantDeviceIdsEachInstance));
name.append("_roles_").append(convertToString({instanceRoles}));
name.append("_roles_").append(convertToString(instanceRoles));
name.append("_controllerRank_" + std::to_string(controllerRank));
return name;
@ -113,8 +138,6 @@ class ConditionalDisaggParamsTest : public GptExecutorTest, public ::testing::Wi
{
};
namespace
{
void verifyGenerateDistStats(std::deque<RequestStatsPerIteration> const& iterationStats)
{
for (auto const& iteration : iterationStats)
@ -466,6 +489,21 @@ TEST_P(DisaggParamsTest, DisaggTokenComparison)
SizeType32 instanceNum = participantIdsEachInstance.size();
ASSERT_EQ(instanceNum, instanceRoles.size());
ASSERT_EQ(instanceNum, modelNames.size());
std::unordered_set<int> deviceIdsSet;
for (auto const& ids : participantDeviceIdsEachInstance)
{
for (auto const& id : ids)
{
deviceIdsSet.insert(id);
}
}
if (mDeviceCount < deviceIdsSet.size())
{
GTEST_SKIP() << " need " << deviceIdsSet.size() << " devices but got " << mDeviceCount
<< " devices, skip test.";
}
ASSERT_GE(controllerRank, 0);
ASSERT_LT(controllerRank, commSize);
int ranksNum = 0;
@ -500,8 +538,9 @@ TEST_P(DisaggParamsTest, DisaggTokenComparison)
{
participatntIds = ranksThisInstance;
deviceIds = devicesThisInstance;
isContext = instanceRoles[i] == InstanceRole::CONTEXT || instanceRoles[i] == InstanceRole::MIXED;
isGeneration = instanceRoles[i] == InstanceRole::GENERATION || instanceRoles[i] == InstanceRole::MIXED;
isContext = instanceRoles[i] == InstanceRole::kCONTEXT || instanceRoles[i] == InstanceRole::kMIXED;
isGeneration
= instanceRoles[i] == InstanceRole::kGENERATION || instanceRoles[i] == InstanceRole::kMIXED;
// modelName = isContext ? contextModel : genModel;
modelName = modelNames[i];
}
@ -606,19 +645,9 @@ TEST_P(DisaggParamsTest, DisaggTokenComparison)
if (modelName == "llama_tp4_pp1_cp1" || modelName == "llama_tp1_pp4_cp1" || modelName == "llama_tp2_pp2_cp1"
|| modelName == "llama_tp1_pp2_cp1" || modelName == "llama_tp2_pp1_cp1")
{
// For llama model, only run for multiple GPUs
// This is detected by setting an env variable when running the test
char const* val = getenv("RUN_LLAMA_MULTI_GPU");
if (val == NULL)
if (outConfig.returnLogProbs || outConfig.returnContextLogits || outConfig.returnGenerationLogits)
{
GTEST_SKIP() << "Skipping Llama test";
}
else
{
if (outConfig.returnLogProbs || outConfig.returnContextLogits || outConfig.returnGenerationLogits)
{
GTEST_SKIP() << "Skipping logits and log probs tests for mpi runs";
}
GTEST_SKIP() << "Skipping logits and log probs tests for mpi runs";
}
}
@ -680,13 +709,12 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison)
{
GTEST_SKIP() << " need " << processNum << " processes but got " << commSize << " mpi processes, skip test.";
}
bool spawnProcess = false;
if (commSize == 1)
{
spawnProcess = true;
int deviceCount = -1;
TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
if (deviceCount < 4)
if (mDeviceCount < 4)
{
GTEST_SKIP() << "DisaggExecutorTest requires at least 4 GPUs";
}
@ -697,6 +725,21 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison)
SizeType32 instanceNum = participantIdsEachInstance.size();
ASSERT_EQ(instanceNum, instanceRoles.size());
ASSERT_EQ(instanceNum, modelNames.size());
std::unordered_set<int> deviceIdsSet;
for (auto const& ids : participantDeviceIdsEachInstance)
{
for (auto const& id : ids)
{
deviceIdsSet.insert(id);
}
}
if (mDeviceCount < deviceIdsSet.size())
{
GTEST_SKIP() << " need " << deviceIdsSet.size() << " devices but got " << mDeviceCount
<< " devices, skip test.";
}
ASSERT_GE(controllerRank, 0);
ASSERT_LT(controllerRank, commSize);
std::string modelName = modelNames[0];
@ -735,7 +778,7 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison)
};
for (SizeType32 i = 0; i < instanceNum; i++)
{
if (instanceRoles[i] == 1)
if (instanceRoles[i] == InstanceRole::kCONTEXT)
{
contextModels.push_back(getModelPath(modelNames[i]));
}
@ -814,19 +857,9 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison)
if (modelName == "llama_tp4_pp1" || modelName == "llama_tp1_pp4" || modelName == "llama_tp2_pp2"
|| modelName == "llama_tp1_pp2" || modelName == "llama_tp2_pp1")
{
// For llama model, only run for multiple GPUs
// This is detected by setting an env variable when running the test
char const* val = getenv("RUN_LLAMA_MULTI_GPU");
if (val == NULL)
if (outConfig.returnLogProbs || outConfig.returnContextLogits || outConfig.returnGenerationLogits)
{
GTEST_SKIP() << "Skipping Llama test";
}
else
{
if (outConfig.returnLogProbs || outConfig.returnContextLogits || outConfig.returnGenerationLogits)
{
GTEST_SKIP() << "Skipping logits and log probs tests for mpi runs";
}
GTEST_SKIP() << "Skipping logits and log probs tests for mpi runs";
}
}
@ -891,6 +924,7 @@ TEST_P(ConditionalDisaggParamsTest, DisaggTokenComparison)
setenv("UCX_TLS", "^cuda_ipc", 1); // disable cuda_ipc for testing for mpi
}
auto constexpr processNum = 2;
auto constexpr deviceNum = 2;
auto const& modelName = std::get<0>(GetParam());
auto constexpr controllerRank = 0;
@ -902,6 +936,10 @@ TEST_P(ConditionalDisaggParamsTest, DisaggTokenComparison)
{
GTEST_SKIP() << " need " << processNum << " processes but got " << commSize << " mpi processes, skip test.";
}
if (mDeviceCount < deviceNum)
{
GTEST_SKIP() << " need " << deviceNum << " devices but got " << mDeviceCount << " devices, skip test.";
}
bool isContext = commRank == 0;
bool isGeneration = commRank == 1;
@ -1082,38 +1120,47 @@ TEST_P(ConditionalDisaggParamsTest, DisaggTokenComparison)
}
INSTANTIATE_TEST_SUITE_P(GptDisaggSymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(GptDisaggSymmetricExecutorTest2, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(1)),
testing::Combine( //
testing::Values(2), // processNum
testing::Values(std::vector<std::string>{"gpt", "gpt"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0, 1) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(GptDisaggSymmetricExecutorMixedTest, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), testing::Values(std::vector<int>{2, 2}),
testing::Values(1)),
testing::Combine( //
testing::Values(2), // processNum
testing::Values(std::vector<std::string>{"gpt", "gpt"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kMIXED, InstanceRole::kMIXED}), // instanceRoles
testing::Values(1) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(GptSingleDeviceDisaggSymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
testing::Values(std::vector<std::vector<int>>{{0}, {0}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(2), // processNum
testing::Values(std::vector<std::string>{"gpt", "gpt"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(GptSingleDeviceDisaggSymmetricExecutorMixedTest, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"gpt", "gpt"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}}),
testing::Values(std::vector<std::vector<int>>{{0}, {0}}), testing::Values(std::vector<int>{2, 2}),
testing::Values(1)),
testing::Combine( //
testing::Values(2), // processNum
testing::Values(std::vector<std::string>{"gpt", "gpt"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kMIXED, InstanceRole::kMIXED}), // instanceRoles
testing::Values(1) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(GptConditionalDisaggSymmetricExecutorTest, ConditionalDisaggParamsTest,
@ -1123,169 +1170,257 @@ INSTANTIATE_TEST_SUITE_P(LlamaConditionalDisaggSymmetricExecutorTest, Conditiona
testing::Combine(testing::Values("llama_tp1_pp1_cp1")), generateTestNameCondDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaTP2DisaggSymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(4),
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp2_pp1_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(4), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp2_pp1_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaPP2DisaggSymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(4),
testing::Values(std::vector<std::string>{"llama_tp1_pp2_cp1", "llama_tp1_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}),
testing::Values(std::vector<std::vector<int>>{{1, 0}, {3, 2}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(4), // processNum
testing::Values(std::vector<std::string>{"llama_tp1_pp2_cp1", "llama_tp1_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{1, 0}, {3, 2}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaTP2DisaggSymmetricExecutorMixedTest, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}}), testing::Values(std::vector<std::vector<int>>{{0, 1}}),
testing::Values(std::vector<int>{2}), testing::Values(0)),
testing::Combine( //
testing::Values(2), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kMIXED}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaPP2DisaggSymmetricExecutorMixedTest, DisaggParamsTest,
testing::Combine(testing::Values(2), testing::Values(std::vector<std::string>{"llama_tp1_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}}), testing::Values(std::vector<std::vector<int>>{{0, 1}}),
testing::Values(std::vector<int>{2}), testing::Values(0)),
testing::Combine( //
testing::Values(2), // processNum
testing::Values(std::vector<std::string>{"llama_tp1_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kMIXED}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaTP2PP2DisaggSymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(8),
testing::Values(std::vector<std::string>{"llama_tp2_pp2_cp1", "llama_tp2_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1, 2, 3}, {4, 5, 6, 7}}),
testing::Values(std::vector<std::vector<int>>{{2, 3, 0, 1}, {2, 3, 0, 1}}),
testing::Values(std::vector<int>{1, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(8), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp2_cp1", "llama_tp2_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1, 2, 3}, {4, 5, 6, 7}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{2, 3, 0, 1}, {2, 3, 0, 1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaConPP2GenTP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(4),
testing::Values(std::vector<std::string>{"llama_tp1_pp2_cp1", "llama_tp2_pp1_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // (1,0) (2,3)
testing::Values(std::vector<std::vector<int>>{{1, 0}, {2, 3}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(4), // processNum
testing::Values(std::vector<std::string>{"llama_tp1_pp2_cp1", "llama_tp2_pp1_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // (1,0) (2,3) // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{1, 0}, {2, 3}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaConTP2GenPP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(4),
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp1_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // (0,1), (3,2)
testing::Values(std::vector<std::vector<int>>{{0, 1}, {3, 2}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(4), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp1_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}}), // (0,1), (3,2)// participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}, {3, 2}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaConTP2PP2GenPP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(6),
testing::Values(std::vector<std::string>{"llama_tp2_pp2_cp1", "llama_tp1_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1, 2, 3}, {4, 5}}), // (2,3,0,1) , (5,4)
testing::Values(std::vector<std::vector<int>>{{2, 3, 0, 1}, {1, 0}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(6), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp2_cp1", "llama_tp1_pp2_cp1"}), // modelNames
testing::Values(
std::vector<std::vector<int>>{{0, 1, 2, 3}, {4, 5}}), // (2,3,0,1) , (5,4)// participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{2, 3, 0, 1}, {1, 0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaConTP2PP2GenTP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(6),
testing::Values(std::vector<std::string>{"llama_tp2_pp2_cp1", "llama_tp2_pp1_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1, 2, 3}, {4, 5}}), // (2,3,0,1), (4,5)
testing::Values(std::vector<std::vector<int>>{{2, 3, 0, 1}, {0, 1}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(6), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp2_cp1", "llama_tp2_pp1_cp1"}), // modelNames
testing::Values(
std::vector<std::vector<int>>{{0, 1, 2, 3}, {4, 5}}), // (2,3,0,1), (4,5)// participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{2, 3, 0, 1}, {0, 1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaConTP2PP1GenTP2PP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(6),
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp2_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3, 4, 5}}), // (0,1) , (4,5,2,3)%4
testing::Values(std::vector<std::vector<int>>{{0, 1}, {0, 1, 2, 3}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(6), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp2_pp2_cp1"}), // modelNames
testing::Values(
std::vector<std::vector<int>>{{0, 1}, {2, 3, 4, 5}}), // (0,1) , (4,5,2,3)%4// participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}, {0, 1, 2, 3}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaConTP2GenPP4DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(6),
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp1_pp4_cp1"}),
testing::Values(std::vector<std::vector<int>>{{4, 5}, {0, 1, 2, 3}}), // (4,5) ,(3,2,1,0)
testing::Values(std::vector<std::vector<int>>{{0, 1}, {3, 2, 1, 0}}), testing::Values(std::vector<int>{1, 0}),
testing::Values(0)),
testing::Combine( //
testing::Values(6), // processNum
testing::Values(std::vector<std::string>{"llama_tp2_pp1_cp1", "llama_tp1_pp4_cp1"}), // modelNames
testing::Values(
std::vector<std::vector<int>>{{4, 5}, {0, 1, 2, 3}}), // (4,5) ,(3,2,1,0)// participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}, {3, 2, 1, 0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon4TP1Gen1TP4DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(8),
testing::Values(std::vector<std::string>{
"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp4_pp1_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {4, 5, 6, 7}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {0, 1, 2, 3}}),
testing::Values(std::vector<int>{1, 1, 1, 1, 0}), testing::Values(4)),
testing::Combine( //
testing::Values(8), // processNum
testing::Values(std::vector<std::string>{"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1",
"llama_tp1_pp1_cp1", "llama_tp4_pp1_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {4, 5, 6, 7}}), // participantIdsEachInstance
testing::Values(
std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {0, 1, 2, 3}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kCONTEXT, InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(4) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2TP2AndPP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(6),
testing::Combine( //
testing::Values(6), // processNum
testing::Values(std::vector<std::string>{
"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp2_pp1_cp1", "llama_tp1_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {4, 5}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {1, 0}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp2_pp1_cp1", "llama_tp1_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {4, 5}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {1, 0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(6),
testing::Combine( //
testing::Values(6), // processNum
testing::Values(std::vector<std::string>{
"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp2_cp1", "llama_tp1_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {4, 5}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {3, 2}, {1, 0}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp2_cp1", "llama_tp1_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {4, 5}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {3, 2}, {1, 0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon4TP1Gen1TP2PP2DisaggAsymmetricExecutorTest, DisaggParamsTest,
testing::Combine(testing::Values(8),
testing::Values(std::vector<std::string>{
"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp2_pp2_cp1"}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {4, 5, 6, 7}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {2, 3, 0, 1}}),
testing::Values(std::vector<int>{1, 1, 1, 1, 0}), testing::Values(4)),
testing::Combine( //
testing::Values(8), // processNum
testing::Values(std::vector<std::string>{"llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1", "llama_tp1_pp1_cp1",
"llama_tp1_pp1_cp1", "llama_tp2_pp2_cp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {4, 5, 6, 7}}), // participantIdsEachInstance
testing::Values(
std::vector<std::vector<int>>{{0}, {1}, {2}, {3}, {2, 3, 0, 1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kCONTEXT, InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(4) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2TP2DisaaggOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine(testing::Values(7),
testing::Values(std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp2_pp1", "llama_tp2_pp1"}),
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4}, {5, 6}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {0, 1}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(7), // processNum
testing::Values(
std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp2_pp1", "llama_tp2_pp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4}, {5, 6}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {2, 3}, {0, 1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
// for disaggOrchestrator 1->0, 2->1, 3->2, 4->3, 5->0, 6->1
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaaggOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine(testing::Values(7),
testing::Values(std::vector<std::string>{"llama_tp2_pp1", "llama_tp2_pp1", "llama_tp1_pp1", "llama_tp1_pp1"}),
testing::Values(std::vector<std::vector<int>>{{1, 2}, {3, 4}, {5}, {6}}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}, {0}, {1}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(7), // processNum
testing::Values(
std::vector<std::string>{"llama_tp2_pp1", "llama_tp2_pp1", "llama_tp1_pp1", "llama_tp1_pp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{1, 2}, {3, 4}, {5}, {6}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}, {0}, {1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaaggOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine(testing::Values(7),
testing::Values(std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp1_pp2", "llama_tp1_pp2"}),
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4}, {5, 6}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {3, 2}, {1, 0}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(7), // processNum
testing::Values(
std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp1_pp2", "llama_tp1_pp2"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4}, {5, 6}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {3, 2}, {1, 0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen1TP2PP2DisaaggOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine(testing::Values(7),
testing::Values(std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp2_pp2"}),
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4, 5, 6}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {0, 1, 2, 3}}),
testing::Values(std::vector<int>{1, 1, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(7), // processNum
testing::Values(std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp2_pp2"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4, 5, 6}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {0, 1, 2, 3}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{
InstanceRole::kCONTEXT, InstanceRole::kCONTEXT, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP2Gen2TP1DisaaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine(testing::Values(1),
testing::Values(std::vector<std::string>{"llama_tp2_pp1", "llama_tp2_pp1", "llama_tp1_pp1", "llama_tp1_pp1"}),
testing::Values(std::vector<std::vector<int>>{{1, 2}, {3, 4}, {5}, {6}}),
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}, {0}, {1}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(1), // processNum
testing::Values(
std::vector<std::string>{"llama_tp2_pp1", "llama_tp2_pp1", "llama_tp1_pp1", "llama_tp1_pp1"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{1, 2}, {3, 4}, {5}, {6}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0, 1}, {2, 3}, {0}, {1}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);
INSTANTIATE_TEST_SUITE_P(LlamaCon2TP1Gen2PP2DisaaggSpawnOrchestrator, DisaggOrchestratorParamsTest,
testing::Combine(testing::Values(1),
testing::Values(std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp1_pp2", "llama_tp1_pp2"}),
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4}, {5, 6}}),
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {3, 2}, {1, 0}}),
testing::Values(std::vector<int>{1, 1, 0, 0}), testing::Values(0)),
testing::Combine( //
testing::Values(1), // processNum
testing::Values(
std::vector<std::string>{"llama_tp1_pp1", "llama_tp1_pp1", "llama_tp1_pp2", "llama_tp1_pp2"}), // modelNames
testing::Values(std::vector<std::vector<int>>{{1}, {2}, {3, 4}, {5, 6}}), // participantIdsEachInstance
testing::Values(std::vector<std::vector<int>>{{0}, {1}, {3, 2}, {1, 0}}), // participantDeviceIdsEachInstance
testing::Values(std::vector<InstanceRole>{InstanceRole::kCONTEXT, InstanceRole::kCONTEXT,
InstanceRole::kGENERATION, InstanceRole::kGENERATION}), // instanceRoles
testing::Values(0) // controllerRank
),
generateTestNameDisaggParams);

View File

@ -32,12 +32,19 @@ public:
protected:
void SetUp() override
{
mDeviceCount = tensorrt_llm::common::getDeviceCount();
if (mDeviceCount == 0)
{
GTEST_SKIP() << "No GPUs found";
}
mLogger = std::make_shared<tensorrt_llm::runtime::TllmLogger>();
initTrtLlmPlugins(mLogger.get());
}
void TearDown() override {}
int mDeviceCount{};
std::shared_ptr<nvinfer1::ILogger> mLogger{};
SizeType32 mMaxWaitMs = 300000;
SizeType32 mTrigWarnMs = 10000;