/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <numeric>
#include <unordered_set>

#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <csignal>
#include <cstdlib>
#include <limits>
#include <mutex>
#include <type_traits>

#ifndef _WIN32
#include <unistd.h>
#endif

// We rely on SizeType32 being int32_t in some places with weak type checking,
// i.e. we're passing void ptr to some function. To prevent mysterious errors
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);

namespace tensorrt_llm::mpi
{

MPI_Datatype getMpiDtype(MpiType dtype)
{
#if ENABLE_MULTI_DEVICE
    static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
        {MpiType::kBYTE, MPI_BYTE},
        {MpiType::kHALF, MPI_UINT16_T},
        {MpiType::kFLOAT, MPI_FLOAT},
        {MpiType::kDOUBLE, MPI_DOUBLE},
        {MpiType::kBOOL, MPI_C_BOOL},
        {MpiType::kINT8, MPI_INT8_T},
        {MpiType::kUINT8, MPI_UINT8_T},
        {MpiType::kINT32, MPI_INT32_T},
        {MpiType::kUINT32, MPI_UINT32_T},
        {MpiType::kINT64, MPI_INT64_T},
        {MpiType::kUINT64, MPI_UINT64_T},
        {MpiType::kFP8, MPI_UINT8_T},
        {MpiType::kBF16, MPI_UINT16_T},
        {MpiType::kCHAR, MPI_CHAR},
    };
    return dtype_map.at(dtype);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif
}

MPI_Op getMpiOp(MpiOp op)
{
#if ENABLE_MULTI_DEVICE
    static std::unordered_map<MpiOp, MPI_Op> const op_map{
        {MpiOp::NULLOP, MPI_OP_NULL},
        {MpiOp::MAX, MPI_MAX},
        {MpiOp::MIN, MPI_MIN},
        {MpiOp::SUM, MPI_SUM},
        {MpiOp::PROD, MPI_PROD},
        {MpiOp::LAND, MPI_LAND},
        {MpiOp::BAND, MPI_BAND},
        {MpiOp::LOR, MPI_LOR},
        {MpiOp::BOR, MPI_BOR},
        {MpiOp::LXOR, MPI_LXOR},
        {MpiOp::BXOR, MPI_BXOR},
        {MpiOp::MINLOC, MPI_MINLOC},
        {MpiOp::MAXLOC, MPI_MAXLOC},
        {MpiOp::REPLACE, MPI_REPLACE},
    };
    return op_map.at(op);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

namespace
{

bool mpiInitialized = false;
std::recursive_mutex mpiMutex;

MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
    MPI_Comm localComm = nullptr;
    MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
    MpiComm localSession{localComm, false};
#else
    MpiComm localSession{COMM_SESSION, false};
#endif // ENABLE_MULTI_DEVICE
    return localSession;
}

} // namespace

std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
    MPI_Group group = nullptr;
    MPI_Group worldGroup = nullptr;
    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
    MPICHECK(MPI_Comm_group(comm, &group));
    int groupSize = 0;
    MPICHECK(MPI_Group_size(group, &groupSize));
    std::vector<int> ranks(groupSize);
    std::vector<int> worldRanks(groupSize);
    std::iota(ranks.begin(), ranks.end(), 0);
    MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
    MPICHECK(MPI_Group_free(&group));
    MPICHECK(MPI_Group_free(&worldGroup));
#else
    std::vector<int> worldRanks{0};
#endif
    return worldRanks;
}
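// Determines how many physical nodes participate in MPI_COMM_WORLD: the world communicator is
// split into per-host shared-memory communicators, and exactly one contribution per host
// (its local rank 0) is summed with MPI_Allreduce.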
int getNumNodes()
{
#if ENABLE_MULTI_DEVICE
    TLLM_LOG_WARNING("Number of nodes was not provided, using MPI to determine number of nodes");

    // Create a communicator for processes with the same hostname
    MPI_Comm node_comm;
    MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm);

    // Get rank in node_comm
    int node_rank;
    MPI_Comm_rank(node_comm, &node_rank);

    // Count only rank 0 processes
    int local_count = (node_rank == 0) ? 1 : 0;
    int num_nodes = 0;

    MPI_Allreduce(&local_count, &num_nodes, 1, MPI_INT, MPI_SUM, MPI_COMM_WORLD);

    MPI_Comm_free(&node_comm);
    return num_nodes;
#else
    return 1;
#endif
}

void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
{
    // double-checked locking
    if (mpiInitialized)
    {
        return;
    }
    std::lock_guard<std::recursive_mutex> lk(mpiMutex);
    if (mpiInitialized)
    {
        return;
    }
#if ENABLE_MULTI_DEVICE
    int initialized = 0;
    TLLM_MPI_CHECK(MPI_Initialized(&initialized));
    if (!initialized)
    {
        TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
        int providedMode = 0;
        auto requiredMode = static_cast<int>(threadMode);
        MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
        TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
        std::atexit([]() { MPI_Finalize(); });

        /*
         * We only catch SIGABRT and SIGSEGV because most, if not all, errors in the worker will cause one of these 2
         * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers
         * correctly.
         */
        for (int sig : {SIGABRT, SIGSEGV})
        {
            __sighandler_t previousHandler = nullptr;
            if (forwardAbortToParent)
            {
                previousHandler = std::signal(sig,
                    [](int signal)
                    {
#ifndef _WIN32
                        pid_t parentProcessId = getppid();
                        kill(parentProcessId, SIGKILL);
#endif
                        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
                    });
            }
            else
            {
                previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
            }
            TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
        }

        // ensure local MPI communicator is initialized
        MpiComm::localSession();
        TLLM_LOG_INFO("Initialized MPI");
    }
#endif // ENABLE_MULTI_DEVICE
    mpiInitialized = true;
}

void MpiComm::barrier() const
{
    couldUseMPI();
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Barrier(mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}
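// MPI count arguments are plain int, so a single call cannot move more than INT_MAX elements.
// invokeChunked transparently splits larger transfers into several calls of the wrapped MPI
// function and returns the number of calls that were issued.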
#if ENABLE_MULTI_DEVICE
template <typename TMpiFunc, typename TBase, typename... TArgs,
    typename = std::enable_if_t<std::is_same_v<std::remove_cv_t<TBase>, void>>>
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
{
    constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
    if (TLLM_LIKELY(size < maxP1))
    {
        MPICHECK(func(buffer, size, dtype, args...));
        return 1;
    }

    constexpr size_t alignment = 256;
    int elementSize = 1;
    MPICHECK(MPI_Type_size(dtype, &elementSize));
    elementSize = std::min<int>(elementSize, alignment);

    // We cap at max alignment-bytes chunks that can be sent at once.
    auto const step = maxP1 - (alignment / elementSize);

    using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
    size_t count = 0;
    while (size != 0)
    {
        auto currentStep = static_cast<int>(std::min(size, step));
        MPICHECK(func(buffer, currentStep, dtype, args...));
        size -= currentStep;
        size_t diff = static_cast<size_t>(currentStep) * elementSize;
        buffer = static_cast<TCast*>(buffer) + diff;
        ++count;
    }
    return count;
}
#endif // ENABLE_MULTI_DEVICE

std::unique_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
    couldUseMPI();
    std::unique_ptr<MpiRequest> r = std::make_unique<MpiRequest>();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    return r;
}

std::unique_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
    couldUseMPI();
    return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}

void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
    couldUseMPI();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
    couldUseMPI();
    bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}

std::unique_ptr<MpiRequest> MpiComm::sendAsync(
    void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const
{
    couldUseMPI();
    TLLM_LOG_DEBUG("start MPI_Isend with dest %d, tag %d, size %d", dest, static_cast<int>(tag), size);
    std::unique_ptr<MpiRequest> r = std::make_unique<MpiRequest>();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, static_cast<int>(tag), mComm, &r->mRequest);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif
    TLLM_LOG_DEBUG("end MPI_Isend with dest %d, tag %d, size %d", dest, static_cast<int>(tag), size);
    return r;
}

std::unique_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, MpiTag tag) const
{
    couldUseMPI();
    return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}

void MpiComm::sendRawTag(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
    couldUseMPI();
    TLLM_LOG_DEBUG("start MPI_Send with dest %d, tag %d, size %d", dest, tag, size);
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    TLLM_LOG_DEBUG("end MPI_Send with dest %d, tag %d, size %d", dest, tag, size);
}

void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, MpiTag tag) const
{
    couldUseMPI();
    sendRawTag(buffer, size, dtype, dest, static_cast<int>(tag));
}

void MpiComm::send(runtime::IBuffer const& buf, int dest, MpiTag tag) const
{
    couldUseMPI();
    send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}

MPI_Status MpiComm::recvRawTag(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
    couldUseMPI();
    TLLM_LOG_DEBUG("start MPI_Recv with source %d, tag %d, size %d", source, tag, size);
    MPI_Status status{};
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    TLLM_LOG_DEBUG("end MPI_Recv with source %d, tag %d, size %d", source, tag, size);
    return status;
}
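// MpiTag-typed receive overloads: they forward to recvRawTag with the tag converted to a raw int,
// mirroring the send/sendRawTag pair above.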
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, MpiTag tag) const
{
    couldUseMPI();
    return recvRawTag(buffer, size, dtype, source, static_cast<int>(tag));
}

MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, MpiTag tag) const
{
    couldUseMPI();
    return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}

MpiComm MpiComm::split(int color, int key) const
{
    MPI_Comm splitComm = nullptr;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    return MpiComm{splitComm, true};
}

MpiComm const& MpiComm::setRawSessionByFortran(int64_t fortranHandle)
{
#if ENABLE_MULTI_DEVICE
    auto comm = MpiComm{MPI_Comm_f2c(fortranHandle), false};
#else
    TLLM_THROW("Multi device support is disabled.");
    auto comm = MpiComm(nullptr, false);
#endif // ENABLE_MULTI_DEVICE
    return MpiComm::setSession(std::move(comm));
}

void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
    couldUseMPI();
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
    couldUseMPI();
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
    std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
    couldUseMPI();
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
        getMpiDtype(recvtype), mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::mprobeRawTag(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::mprobe(int source, MpiTag tag, MPI_Message* msg, MPI_Status* status) const
{
    mprobeRawTag(source, static_cast<int>(tag), msg, status);
}

bool MpiComm::improbe(int source, MpiTag tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    int flag{0};
    MPICHECK(MPI_Improbe(source, static_cast<int>(tag), mComm, &flag, msg, status));
    return flag != 0;
#else
    TLLM_THROW("Multi device support is disabled.");
    return false;
#endif
}

bool MpiComm::iprobe(int source, MpiTag tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    int flag{0};
    MPICHECK(MPI_Iprobe(source, static_cast<int>(tag), mComm, &flag, status));
    return flag != 0;
#else
    TLLM_THROW("Multi device support is disabled.");
    return false;
#endif
}

void MpiComm::recvPoll(int source, MpiTag tag, int periodMs) const
{
    couldUseMPI();
    MPI_Status status;
    while (!iprobe(source, tag, &status))
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
    }
}

int MpiComm::getRank() const
{
    int rank = 0;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_rank(mComm, &rank));
#endif
    return rank;
}

int MpiComm::getSize() const
{
    int world_size = 1;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_size(mComm, &world_size));
#endif
    return world_size;
}
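// The accessors below return lazily constructed, process-wide communicators: world() wraps
// MPI_COMM_WORLD, mutableSession() the session communicator, and mutableLocalSession() the
// per-node split of the session. A minimal usage sketch (assuming an MPI launch and a build
// with ENABLE_MULTI_DEVICE):
//
//   auto const& comm = tensorrt_llm::mpi::MpiComm::world();
//   int payload = comm.getRank();
//   comm.bcast(&payload, sizeof(payload), MpiType::kBYTE, /*root=*/0);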
MpiComm const& MpiComm::world()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm commWorld{MPI_COMM_WORLD, false};
    initialize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return commWorld;
}

MpiComm& MpiComm::mutableSession()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm commSession{MPI_COMM_WORLD, false};
    initialize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return commSession;
}

MpiComm& MpiComm::mutableLocalSession()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm localSession = initLocalSession();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return localSession;
}

void MpiComm::refreshLocalSession()
{
#if ENABLE_MULTI_DEVICE
    static std::mutex mutex;
    std::unique_lock<std::mutex> lock(mutex);
    auto initSessionRanks = getWorldRanks(MpiComm::session());
    auto localSessionRanks = getWorldRanks(MpiComm::localSession());

    // Add to intersectionRanks in order of initSessionRanks
    std::vector<int> intersectionRanks;
    std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
    for (auto rank : initSessionRanks)
    {
        if (localSessionRanksSet.find(rank) != localSessionRanksSet.end())
        {
            intersectionRanks.push_back(rank);
        }
    }

    MPI_Group worldGroup = nullptr;
    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
    MPI_Group localGroup = nullptr;
    MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
    MPI_Comm localComm = nullptr;
    MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
    MpiComm::mutableLocalSession().mFreeComm = true;
    MpiComm::mutableLocalSession() = MpiComm{localComm, false};
    TLLM_LOG_INFO("Refreshed the MPI local session");
#endif // ENABLE_MULTI_DEVICE
}

MpiComm::MpiComm(MPI_Comm g, bool freeComm)
    : mComm{g}
    , mFreeComm{freeComm}
{
    TLLM_CHECK(mComm != MPI_COMM_NULL);
}

MpiComm::~MpiComm() noexcept
{
#if ENABLE_MULTI_DEVICE
    if (mFreeComm && mComm)
    {
        if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
        {
            TLLM_LOG_ERROR("MPI_Comm_free failed");
        }
    }
#endif // ENABLE_MULTI_DEVICE
}

MpiComm::MpiComm(MpiComm&& comm) noexcept
    : mComm{comm.mComm}
    , mFreeComm{comm.mFreeComm}
{
    comm.mFreeComm = false;
}

MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
{
    this->~MpiComm();
    mComm = comm.mComm;
    mFreeComm = comm.mFreeComm;
    comm.mFreeComm = false;
    return *this;
}

MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
    : mName{name.c_str()}
    , mFuncWait{funcWait}
    , mFuncSetup{funcSetup}
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

MpiWaitThread::~MpiWaitThread()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    waitStop();
    mShouldExit.store(true);
    notifyStart();
    mThread->join();
    mThread.reset(nullptr);
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::sideThread()
{
    if (mFuncSetup)
    {
        mFuncSetup();
    }
    while (!mShouldExit.load())
    {
        notifyStop();
        waitStart();
        mFuncWait();
    }
}

void MpiWaitThread::waitStart()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::unique_lock<std::mutex> lock(mMutex);
    mCondVar.wait(lock, [this] { return mRunning; });
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::waitStop()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::unique_lock<std::mutex> lock(mMutex);
    mCondVar.wait(lock, [this] { return !mRunning; });
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}
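// notifyStart()/notifyStop() flip mRunning under the mutex and signal the condition variable,
// waking either the side thread blocked in waitStart() or a caller blocked in waitStop().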
void MpiWaitThread::notifyStart()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::lock_guard<std::mutex> lock(mMutex);
    mRunning = true;
    mCondVar.notify_one();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::notifyStop()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::lock_guard<std::mutex> lock(mMutex);
    mRunning = false;
    mCondVar.notify_one();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::mpi