TensorRT-LLMs/cpp/tests/common/mpiUtilsTest.cpp
Kaiyu Xie 728cc0044b
Update TensorRT-LLM (#1233)
* Update TensorRT-LLM

---------

Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-03-05 18:32:53 +08:00

232 lines
6.4 KiB
C++

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#if ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/common/plugin.h"
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <algorithm>
namespace mpi = tensorrt_llm::mpi;
namespace tr = tensorrt_llm::runtime;
// Sanity-checks the world communicator: its size is positive and this process's
// rank is a valid zero-based index into [0, size).
TEST(MPIUtils, RankAndSize)
{
    auto& comm = mpi::MpiComm::world();
    auto const rank = comm.getRank();
    auto const size = comm.getSize();
    // Every communicator contains at least the calling process.
    EXPECT_LE(1, size);
    EXPECT_LE(0, rank);
    // Ranks are zero-based, so the largest valid rank is size - 1; rank == size
    // must be rejected, hence a strict comparison.
    EXPECT_LT(rank, size);
}
// Broadcasts a single value of type T from rank 0 and checks that every rank
// ends up holding 42. Non-root ranks start from a value-initialized T, so the
// check only passes if the broadcast actually delivered the root's value.
template <typename T>
void testBroadcast()
{
    constexpr auto kRoot = 0;
    constexpr auto kExpected = static_cast<T>(42);
    auto& world = mpi::MpiComm::world();
    T received{};
    if (world.getRank() == kRoot)
    {
        received = kExpected;
    }
    world.bcastValue(received, kRoot);
    EXPECT_EQ(received, kExpected);
}
TEST(MPIUtils, Broadcast)
{
testBroadcast<std::byte>();
testBroadcast<float>();
testBroadcast<double>();
testBroadcast<bool>();
testBroadcast<std::int8_t>();
testBroadcast<std::uint8_t>();
testBroadcast<std::int32_t>();
testBroadcast<std::uint32_t>();
testBroadcast<std::int64_t>();
testBroadcast<std::uint64_t>();
}
#if ENABLE_MULTI_DEVICE
// Generates an NCCL unique id on rank 0, broadcasts it to all ranks and checks
// that every rank receives an id with at least one non-zero byte.
TEST(MPIUtils, BroadcastNcclId)
{
    auto& comm = mpi::MpiComm::world();
    auto const rank = comm.getRank();
    auto constexpr root = 0;
    // Value-initialize so non-root ranks start from an all-zero id, and so the
    // root never broadcasts indeterminate bytes if id generation fails.
    ncclUniqueId id{};
    if (rank == root)
    {
        // EXPECT (not ASSERT) so a failure on the root does not return early and
        // skip the collective broadcast that the other ranks participate in.
        EXPECT_EQ(ncclGetUniqueId(&id), ncclSuccess);
    }
    comm.bcastValue(id, root);
    EXPECT_TRUE(std::any_of(
        id.internal, id.internal + sizeof(id.internal) / sizeof(id.internal[0]), [](auto x) { return x != 0; }));
}
// The plugin layer's session handle must point at the COMM_SESSION object.
TEST(MPIUtils, GlobalSessionHandle)
{
    auto const pluginHandle = tensorrt_llm::plugins::getCommSessionHandle();
    EXPECT_EQ(pluginHandle, &COMM_SESSION);
}
#endif // ENABLE_MULTI_DEVICE
// Allocates a 1024-element CPU buffer of T on every rank. Rank 0 fills it with 42
// and the other ranks fill it with T{}. The buffer is then broadcast from rank 0,
// and the test checks that every element on every rank equals 42.
template <typename T>
void testBroadcastBuffer()
{
    constexpr auto kRoot = 0;
    constexpr auto kNumElements = 1024;
    constexpr auto kExpected = static_cast<T>(42);
    auto& world = mpi::MpiComm::world();
    auto const fillValue = (world.getRank() == kRoot) ? kExpected : T{};
    auto hostBuffer = tr::BufferManager::cpu(kNumElements, tr::TRTDataType<T>::value);
    auto* const first = tr::bufferCast<T>(*hostBuffer);
    auto* const last = first + kNumElements;
    std::fill(first, last, fillValue);
    world.bcast(*hostBuffer, kRoot);
    auto const matchesExpected = [&](auto element) { return element == kExpected; };
    EXPECT_TRUE(std::all_of(first, last, matchesExpected));
}
TEST(MPIUtils, BroadcastBuffer)
{
testBroadcastBuffer<float>();
testBroadcastBuffer<bool>();
testBroadcastBuffer<std::int8_t>();
testBroadcastBuffer<std::uint8_t>();
testBroadcastBuffer<std::int32_t>();
testBroadcastBuffer<std::uint32_t>();
testBroadcastBuffer<std::int64_t>();
testBroadcastBuffer<std::uint64_t>();
}
// Point-to-point check: rank 0 sends a single T (value 42) to rank 1 with tag 0,
// and rank 1 receives it and verifies the value. All other ranks do nothing.
template <typename T>
void testSendRecv()
{
    constexpr auto kTag = 0;
    constexpr auto kSender = 0;
    constexpr auto kReceiver = 1;
    constexpr auto kPayload = static_cast<T>(42);
    auto& world = mpi::MpiComm::world();
    switch (world.getRank())
    {
    case kSender:
        world.send(kPayload, kReceiver, kTag);
        break;
    case kReceiver:
    {
        T received{};
        world.recv(received, kSender, kTag);
        EXPECT_EQ(received, kPayload);
        break;
    }
    default:
        break;
    }
}
TEST(MPIUtils, SendRecv)
{
auto& comm = mpi::MpiComm::world();
if (comm.getSize() < 2)
{
GTEST_SKIP() << "Test requires at least 2 processes";
}
testSendRecv<float>();
testSendRecv<bool>();
testSendRecv<std::int8_t>();
testSendRecv<std::uint8_t>();
testSendRecv<std::int32_t>();
testSendRecv<std::uint32_t>();
testSendRecv<std::int64_t>();
testSendRecv<std::uint64_t>();
}
template <typename T>
void testSendMRecv()
{
auto& comm = mpi::MpiComm::world();
auto const rank = comm.getRank();
auto constexpr expectedValue = static_cast<T>(42);
auto constexpr tag = 0;
if (rank == 0)
{
comm.send(expectedValue, 1, tag);
}
else if (rank == 1)
{
MPI_Message msg;
MPI_Status status;
comm.mprobe(0, tag, &msg, &status);
int count = 0;
MPICHECK(MPI_Get_count(&status, getMpiDtype(mpi::MpiTypeConverter<std::remove_cv_t<T>>::value), &count));
EXPECT_EQ(1, count);
T value{};
MPICHECK(
MPI_Mrecv(&value, count, getMpiDtype(mpi::MpiTypeConverter<std::remove_cv_t<T>>::value), &msg, &status));
EXPECT_EQ(value, expectedValue);
}
}
TEST(MPIUtils, SendMRecv)
{
auto& comm = mpi::MpiComm::world();
if (comm.getSize() < 2)
{
GTEST_SKIP() << "Test requires at least 2 processes";
}
testSendMRecv<float>();
testSendMRecv<bool>();
testSendMRecv<std::int8_t>();
testSendMRecv<std::uint8_t>();
testSendMRecv<std::int32_t>();
testSendMRecv<std::uint32_t>();
testSendMRecv<std::int64_t>();
testSendMRecv<std::uint64_t>();
}
// Exercises the overridable session communicator. The test first checks that the
// session starts out equal to world. It then splits world into two halves and
// installs the result as the session, and finally splits each half so that every
// process ends up in a session of size 1.
// NOTE(review): the session is never reset to world at the end, so later code in
// this process (e.g. a --gtest_repeat run, or tests ordered after this one) that
// expects MpiComm::session() == world will observe the split communicator. Confirm
// whether this is intended.
TEST(MPIUtils, SessionCommunicator)
{
auto& world = mpi::MpiComm::world();
if (world.getSize() < 2)
{
GTEST_SKIP() << "Test requires at least 2 processes";
}
auto const rank = world.getRank();
auto const size = world.getSize();
// ceil(size / 2): ranks [0, sessionSize) get color 0, the remaining ranks color 1.
auto const sessionSize = (size + 1) / 2;
auto const sessionColor = rank / sessionSize;
// Used as the split key, so ranks keep their relative order inside each half.
auto sessionRank = rank % sessionSize;
// `session` is a reference to the object returned by session(); assigning to it
// below replaces what later session() calls return (checked by the EXPECT_EQs).
auto& session = mpi::MpiComm::session();
EXPECT_EQ(session, world);
session = world.split(sessionColor, sessionRank);
EXPECT_EQ(session, mpi::MpiComm::session());
EXPECT_NE(session, world);
EXPECT_EQ(session.getRank(), sessionRank);
// With an odd world size the second half holds one rank fewer than sessionSize.
EXPECT_LE(sessionSize - 1, session.getSize());
EXPECT_LE(session.getSize(), sessionSize);
// Ranks within a half are distinct, so using sessionRank as the color puts every
// process in its own singleton communicator.
session = session.split(sessionRank, 0);
EXPECT_EQ(session, mpi::MpiComm::session());
EXPECT_EQ(session.getRank(), 0);
EXPECT_EQ(session.getSize(), 1);
}