diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
index 7a698b4eb6..df8120b3ef 100644
--- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
@@ -347,9 +347,9 @@ void initBindings(nb::module_& m)
         nb::call_guard<nb::gil_scoped_release>());
 
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
-            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
-            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("device_idx"), nb::arg("mn_nvlink"),
+            nb::arg("mpi_comm_fortran_handle"), nb::call_guard<nb::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
index ee4303d31b..8f9d0bc025 100644
--- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -443,9 +443,9 @@ void initBindings(pybind11::module_& m)
         py::call_guard<py::gil_scoped_release>());
 
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
-            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
-            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("device_idx"), py::arg("mn_nvlink"),
+            py::arg("mpi_comm_fortran_handle"), py::call_guard<py::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             py::call_guard<py::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
index 9be590c7fc..cb1303fc84 100644
--- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
+++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace
 
 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink, int64_t mpiCommFortranHandle)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,7 +48,11 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
-    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
+#if ENABLE_MULTI_DEVICE
+    , mGroupComm(MPI_Comm_f2c(mpiCommFortranHandle), false)
+#else
+    , mGroupComm(nullptr, false)
+#endif
 {
 
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -66,9 +70,9 @@ McastDeviceMemory::McastDeviceMemory(
     int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
 
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, isMultiNode: %d, "
         "device_idx: %d, Signal pad offset: %zu",
-        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        world_rank, mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
 
     if (mIsMNNvlink)
     {
diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
index d9428b4126..537bf49788 100644
--- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
+++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
@@ -43,8 +43,8 @@ public:
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;
 
-    McastDeviceMemory(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink,
+        int64_t mpiCommFortranHandle);
 
     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
diff --git a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
index 4c011a790b..160ab63f05 100644
--- a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
+++ b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
@@ -34,12 +34,12 @@ public:
     //! \param bufSize The total size of the buffer in bytes.
    //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
-    //! \param splitColor The color of the split for topology split.
-    //! \param device The CUDA device for buffer allocation.
+    //! \param deviceIdx The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
+    //! \param mpiCommFortranHandle The Fortran handle for the MPI communicator (from Python mpi4py).
+    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t deviceIdx, bool mnNvlink,
+        int64_t mpiCommFortranHandle)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, deviceIdx, mnNvlink, mpiCommFortranHandle)
         , mBufSize(bufSize)
         , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index 713d728566..711968a92e 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -65,14 +65,12 @@ def get_or_scale_allreduce_mnnvl_workspace(
     """
 
     NUM_LAMPORT_BUFFERS = 3
-    if not hasattr(_thread_local,
-                   f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
-        setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
-                {})
+
+    # Use MNNVLAllReduce class to share across threads
+    allreduce_mnnvl_workspaces = MNNVLAllReduce.allreduce_mnnvl_workspaces
+
     # A safe method to get the element size of the dtype
     elem_size = torch.tensor([], dtype=dtype).element_size()
-    allreduce_mnnvl_workspaces = getattr(
-        _thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}')
 
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"
     use_fabric_handle = force_mn or mapping.is_multi_node()
@@ -101,19 +99,19 @@ def get_or_scale_allreduce_mnnvl_workspace(
         # Increase the buffer size in 8 MiB granularity to avoid frequently scaling the buffer
         buffer_size_bytes = math.ceil(req_buffer_size_bytes /
                                       (8 * 1024 * 1024)) * (8 * 1024 * 1024)
-        if mapping.tp_rank == 0:
-            logger.debug(
-                f"[MNNVL] Requested {req_buffer_size_bytes} bytes, is larger than the current workspace size. Scaling workspace for pp_rank {mapping.pp_rank}, tp_size {mapping.tp_size} from {allreduce_mnnvl_workspaces[mapping]['buffer_size_bytes']} to {buffer_size_bytes} bytes"
-            )
+        logger.debug(
+            f"[MNNVL] Requested {req_buffer_size_bytes} bytes, is larger than the current workspace size. Scaling workspace for pp_rank {mapping.pp_rank}, tp_size {mapping.tp_size} from {allreduce_mnnvl_workspaces[mapping]['buffer_size_bytes']} to {buffer_size_bytes} bytes"
+        )
 
         # Each workspace contains NUM_LAMPORT_BUFFERS buffers.
         workspace_size_bytes = NUM_LAMPORT_BUFFERS * buffer_size_bytes
 
+        # Pass the pre-split MPI communicator's Fortran handle to avoid redundant splitting in C++
         mcast_buf_handle = McastGPUBuffer(
             workspace_size_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
             mapping.local_rank,
             use_fabric_handle,  # whether to use fabric handle or POSIX FD ipc
+            comm.py2f(),  # Fortran handle for the MPI communicator
         )
 
         # We use per FP32 element in the buffer for lamport sync
@@ -517,6 +515,7 @@ class MNNVLAllReduce(nn.Module):
     This class handles the MNNVL-specific allreduce operations, which can be more
     efficient for certain operations when using NVLink for multi-node communication.
     """
+    allreduce_mnnvl_workspaces: Dict[int, Dict] = {}
 
     def __init__(self, mapping: Mapping, dtype: torch.dtype):
         super().__init__()
@@ -530,7 +529,7 @@ class MNNVLAllReduce(nn.Module):
         )
 
         # Initialize the workspace
-        _ = get_or_scale_allreduce_mnnvl_workspace(self.mapping, self.dtype)
+        get_or_scale_allreduce_mnnvl_workspace(self.mapping, self.dtype)
 
     @staticmethod
     def get_supported_dtypes():
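Usage sketch (illustrative, not part of the patch): with the new signature, the Python caller passes the Fortran handle of an mpi4py communicator that has already been split to the group of ranks sharing the buffer, exactly as ops.py does with comm.py2f(). The helper below is hypothetical; McastGPUBuffer stands for the bound class exposed above, and the argument names mirror the bound keyword arguments.

    from mpi4py import MPI

    def make_mcast_buffer(McastGPUBuffer, buf_size, group_size, group_rank,
                          device_idx, mn_nvlink, comm=MPI.COMM_WORLD):
        # 'comm' must contain exactly the 'group_size' ranks that will share the buffer;
        # py2f() yields the Fortran handle that MPI_Comm_f2c() converts back on the C++ side.
        return McastGPUBuffer(
            buf_size,      # total workspace size in bytes
            group_size,    # number of ranks in 'comm'
            group_rank,    # this process's rank within 'comm'
            device_idx,    # local CUDA device index
            mn_nvlink,     # fabric handle (multi-node NVLink) vs. POSIX FD IPC
            comm.py2f(),   # mpi_comm_fortran_handle
        )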