/*
 * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/runtime/utils/multiDeviceUtils.h"

#include <cstddef>
#include <cstdint>
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#endif
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cuda_fp16.h>
#include <limits>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#if ENABLE_MULTI_DEVICE
#include <mpi.h>
#else
// Dummy defines to avoid #if in wider places.
typedef void* MPI_Datatype;
typedef void* MPI_Comm;
typedef void* MPI_Request;
typedef void* MPI_Message;
typedef void* MPI_Op;

typedef struct MPI_Status
{
    int dummy;
} MPI_Status;

#define MPI_THREAD_SINGLE 0
#define MPI_THREAD_FUNNELED 1
#define MPI_THREAD_SERIALIZED 2
#define MPI_THREAD_MULTIPLE 3
#define MPI_COMM_WORLD ((MPI_Comm) 0x44000000)
#define MPI_COMM_NULL ((MPI_Comm) 0x04000000)
#endif // ENABLE_MULTI_DEVICE

#include <functional>
#include <string>

#define MPICHECK(cmd) TLLM_MPI_CHECK(cmd)

namespace tensorrt_llm::runtime
{
class IBuffer;
}

// A wrapper module of the MPI library.
namespace tensorrt_llm::mpi
{

// A wrapper of MPI data type. MpiType::{data_type}
enum class MpiType
{
    kBYTE,
    kHALF,
    kFLOAT,
    kDOUBLE,
    kBOOL,
    kINT8,
    kUINT8,
    kINT32,
    kUINT32,
    kINT64,
    kUINT64,
    kFP8,
    kBF16,
    kCHAR,
};

//! \brief Maps a C++ data type to the corresponding MpiType.
template <typename T>
struct MpiTypeConverter
{
};

template <>
struct MpiTypeConverter<std::byte>
{
    static constexpr auto value = MpiType::kBYTE;
};

template <>
struct MpiTypeConverter<half>
{
    static constexpr auto value = MpiType::kHALF;
};

template <>
struct MpiTypeConverter<float>
{
    static constexpr auto value = MpiType::kFLOAT;
};

template <>
struct MpiTypeConverter<double>
{
    static constexpr auto value = MpiType::kDOUBLE;
};

template <>
struct MpiTypeConverter<bool>
{
    static constexpr auto value = MpiType::kBOOL;
};

template <>
struct MpiTypeConverter<std::int8_t>
{
    static constexpr auto value = MpiType::kINT8;
};

template <>
struct MpiTypeConverter<std::uint8_t>
{
    static constexpr auto value = MpiType::kUINT8;
};

template <>
struct MpiTypeConverter<std::int32_t>
{
    static constexpr auto value = MpiType::kINT32;
};

template <>
struct MpiTypeConverter<std::uint32_t>
{
    static constexpr auto value = MpiType::kUINT32;
};

template <>
struct MpiTypeConverter<std::int64_t>
{
    static constexpr auto value = MpiType::kINT64;
};

template <>
struct MpiTypeConverter<std::uint64_t>
{
    static constexpr auto value = MpiType::kUINT64;
};

template <>
struct MpiTypeConverter<char>
{
    static constexpr auto value = MpiType::kCHAR;
};

#ifdef ENABLE_FP8
template <>
struct MpiTypeConverter<__nv_fp8_e4m3>
{
    static constexpr auto value = MpiType::kFP8;
};
#endif

#ifdef ENABLE_BF16
template <>
struct MpiTypeConverter<__nv_bfloat16>
{
    static constexpr auto value = MpiType::kBF16;
};
#endif
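// Illustrative sketch (not part of the original header): MpiTypeConverter maps
// a C++ type to its MpiType tag at compile time, which is how the templated
// helpers below pick a datatype without the caller spelling it out, e.g.
//
//     static_assert(MpiTypeConverter<float>::value == MpiType::kFLOAT);
//     static_assert(MpiTypeConverter<std::int32_t>::value == MpiType::kINT32);
//
// Types without a specialization (e.g. user-defined structs) fall back to
// byte-wise transfers in bcastValue/sendValue/recvValue below.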
// A wrapper of MPI_Op type.
enum class MpiOp
{
    NULLOP,
    MAX,
    MIN,
    SUM,
    PROD,
    LAND,
    BAND,
    LOR,
    BOR,
    LXOR,
    BXOR,
    MINLOC,
    MAXLOC,
    REPLACE,
};

// A wrapper of the level of MPI thread support
enum class MpiThreadSupport : int
{
    THREAD_SINGLE = MPI_THREAD_SINGLE,
    THREAD_FUNNELED = MPI_THREAD_FUNNELED,
    THREAD_SERIALIZED = MPI_THREAD_SERIALIZED,
    THREAD_MULTIPLE = MPI_THREAD_MULTIPLE,
};

class MpiRequest
{
public:
    MpiRequest() = default;
    ~MpiRequest() = default;

    void wait()
    {
#if ENABLE_MULTI_DEVICE
        // TODO: Don't ignore return status
        MPI_Wait(&mRequest, MPI_STATUS_IGNORE);
#else
        TLLM_THROW("Multi device support is disabled.");
#endif
    }

    MPI_Request mRequest{};
};

MPI_Datatype getMpiDtype(MpiType dtype);

class MpiComm
{
public:
    explicit MpiComm(MPI_Comm g, bool freeComm);
    ~MpiComm() noexcept;

    // no copy
    MpiComm(MpiComm const&) = delete;
    MpiComm& operator=(MpiComm const&) = delete;

    // move
    MpiComm(MpiComm&&) noexcept;
    MpiComm& operator=(MpiComm&&) noexcept;

    [[nodiscard]] int getRank() const;
    [[nodiscard]] int getSize() const;

    operator MPI_Comm() const // NOLINT(*-explicit-constructor)
    {
        return mComm;
    }

    //! \brief Returns the MPI world communicator.
    static MpiComm const& world();

    //! \brief Corresponds to `world()` by default, but can be overridden per process.
    static MpiComm const& session()
    {
        return mutableSession();
    }

    //! \brief Returns the MPI local communicator.
    static MpiComm const& localSession()
    {
        return mutableLocalSession();
    }

    static MpiComm const& setSession(MpiComm comm)
    {
        auto& session = mutableSession();
        session = std::move(comm);
        refreshLocalSession();
        return session;
    }

    [[nodiscard]] MpiComm split(int color, int key) const;

    std::shared_ptr<MpiRequest> bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const;
    std::shared_ptr<MpiRequest> bcastAsync(runtime::IBuffer& buf, int root) const;

    void bcast(void* buffer, size_t size, MpiType dtype, int root) const;
    void bcast(runtime::IBuffer& buf, int root) const;

    template <typename T>
    void bcastValue(T& value, int root) const
    {
        if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
        {
            bcast(&value, 1, MpiTypeConverter<std::remove_cv_t<T>>::value, root);
        }
        else
        {
            bcast(&value, sizeof(T), MpiType::kBYTE, root);
        }
    }
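    // Usage sketch (illustrative only): broadcast a single value from rank 0
    // to all ranks in the session communicator. Assumes MPI has been
    // initialized and ENABLE_MULTI_DEVICE is set; computeNumTokens() is a
    // hypothetical helper, not part of this header.
    //
    //     int64_t numTokens = computeNumTokens(); // hypothetical
    //     COMM_SESSION.bcastValue(numTokens, /*root=*/0);
    //
    // Fundamental types use their MpiTypeConverter mapping; other trivially
    // copyable types are transferred as sizeof(T) raw bytes.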
    template <typename T>
    void bcast(std::vector<T>& vec, int root) const
    {
        auto const rank = getRank();
        auto vecSize = (rank == root) ? static_cast<int64_t>(vec.size()) : int64_t(0);
        bcast(&vecSize, 1, MpiType::kINT64, root);
        vec.resize(vecSize);
        if (vec.empty())
        {
            return;
        }

        // Transfer size counts elements for fundamental types, bytes otherwise.
        size_t bcastSize = vec.size() * sizeof(T);
        if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
        {
            bcastSize = vec.size();
        }

        // To prevent overflowing the int32_t limit, broadcast in chunks.
        size_t const maxChunkSize = std::numeric_limits<int32_t>::max();
        for (size_t pos = 0; pos < bcastSize; pos += maxChunkSize)
        {
            auto chunkSize = std::min(bcastSize - pos, maxChunkSize);
            auto intChunkSize = static_cast<int>(chunkSize);
            if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
            {
                bcast(vec.data() + pos, intChunkSize, MpiTypeConverter<std::remove_cv_t<T>>::value, root);
            }
            else
            {
                bcast(reinterpret_cast<char*>(vec.data()) + pos, intChunkSize, MpiType::kBYTE, root);
            }
        }
    }

    std::shared_ptr<MpiRequest> sendAsync(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
    std::shared_ptr<MpiRequest> sendAsync(runtime::IBuffer const& buf, int dest, int tag) const;
    void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
    void send(runtime::IBuffer const& buf, int dest, int tag) const;

    template <typename T>
    void sendValue(T const& value, int dest, int tag) const
    {
        if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
        {
            send(&value, 1, MpiTypeConverter<std::remove_cv_t<T>>::value, dest, tag);
        }
        else
        {
            send(&value, sizeof(T), MpiType::kBYTE, dest, tag);
        }
    }

    MPI_Status recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const;
    MPI_Status recv(runtime::IBuffer& buf, int source, int tag) const;

    template <typename T>
    MPI_Status recvValue(T& value, int source, int tag) const
    {
#if ENABLE_MULTI_DEVICE
        if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
        {
            return recv(&value, 1, MpiTypeConverter<std::remove_cv_t<T>>::value, source, tag);
        }
        else
        {
            return recv(&value, sizeof(T), MpiType::kBYTE, source, tag);
        }
#else
        TLLM_THROW("Multi device support is disabled.");
#endif
    }

    void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
    void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
    void allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
        std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const;
    void barrier() const;

    void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
    bool improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;

    //! \brief Returns whether a message with the specified source and tag is available.
    bool iprobe(int source, int tag, MPI_Status* status) const;

    //! \brief Polls every periodMs until a message is available.
    void recvPoll(int source, int tag, int periodMs) const;

    bool operator==(MpiComm const& rhs) const
    {
        return mComm == rhs.mComm;
    }

    bool operator!=(MpiComm const& rhs) const
    {
        return !(rhs == *this);
    }
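    // Usage sketch (illustrative only): a simple rank-to-rank exchange using
    // the typed helpers above. Assumes at least two ranks and an initialized
    // MPI runtime; error handling is omitted.
    //
    //     auto const& comm = tensorrt_llm::mpi::MpiComm::world();
    //     int token = 42;
    //     if (comm.getRank() == 0)
    //         comm.sendValue(token, /*dest=*/1, /*tag=*/0);
    //     else if (comm.getRank() == 1)
    //         comm.recvValue(token, /*source=*/0, /*tag=*/0);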
private:
    //! \brief Corresponds to `world()` by default, but can be overridden per process.
    static MpiComm& mutableSession();

    //! \brief Returns the MPI local communicator.
    static MpiComm& mutableLocalSession();

    static void refreshLocalSession();

    MPI_Comm mComm;
    bool mFreeComm;
};

std::vector<int> getWorldRanks(MpiComm const& comm);

void initialize(MpiThreadSupport threadMode = MpiThreadSupport::THREAD_MULTIPLE, bool forwardAbortToParent = false);

class MpiWaitThread
{
public:
    explicit MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup = nullptr);
    ~MpiWaitThread();

    void waitStop();
    void notifyStart();

private:
    void sideThread();

    void waitStart();
    void notifyStop();

    std::string mName;
    std::function<void()> mFuncWait;
    std::function<void()> mFuncSetup;
    std::unique_ptr<std::thread> mThread;
    std::mutex mMutex;
    std::condition_variable mCondVar;
    bool mRunning{true};
    std::atomic<bool> mShouldExit{false};
};

} // namespace tensorrt_llm::mpi

#define COMM_SESSION tensorrt_llm::mpi::MpiComm::session()
#define LOCAL_COMM_SESSION tensorrt_llm::mpi::MpiComm::localSession()
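// Usage sketch (illustrative only): typical process startup with this wrapper.
// Every name below is declared in this header; the include path is assumed and
// may differ in your checkout.
//
//     #include "tensorrt_llm/runtime/common/mpiUtils.h" // path assumed
//
//     int main()
//     {
//         tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
//         auto const rank = COMM_SESSION.getRank();
//         auto const size = COMM_SESSION.getSize();
//         COMM_SESSION.barrier(); // synchronize all ranks before proceeding
//         return 0;
//     }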