[https://nvbugs/5782112][fix] Fix hanging issue for MNNVL Allreduce under PP (#10633)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Authored by Yukun He on 2026-01-16 13:03:36 +08:00, committed by GitHub
parent e2c3373749
commit f001c4946d
6 changed files with 32 additions and 29 deletions
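
In short: the communicator split moves out of the C++ constructor (which previously called MpiComm::session().split(...), a collective over the whole session) and into Python, which splits once and passes the resulting communicator's Fortran handle across the binding; the session-wide collective is presumably what could hang under pipeline parallelism. Below is a minimal mpi4py sketch of that hand-off, assuming the same color expression as the call site further down; none of this is the repository's code.

# Minimal sketch of the handle hand-off with mpi4py; names are illustrative.
from mpi4py import MPI

world = MPI.COMM_WORLD

# Hypothetical parallel layout, mirroring pp_rank * cp_size + cp_rank below.
pp_rank, cp_rank, cp_size = 0, 0, 1
color = pp_rank * cp_size + cp_rank

# Split once on the Python side; every rank of `world` must take part.
tp_comm = world.Split(color=color, key=world.Get_rank())

# py2f() yields a plain integer that can cross the language boundary ...
fortran_handle = tp_comm.py2f()

# ... and be rebuilt into a communicator on the other side.
# C++ uses MPI_Comm_f2c; mpi4py's round-trip equivalent is Comm.f2py.
rebuilt = MPI.Comm.f2py(fortran_handle)
assert rebuilt.Get_size() == tp_comm.Get_size()

Because py2f() returns a plain integer (an MPI_Fint), the new binding can accept it as an int64_t and reconstruct the communicator with MPI_Comm_f2c on the C++ side.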


@@ -347,9 +347,9 @@ void initBindings(nb::module_& m)
             nb::call_guard<nb::gil_scoped_release>());
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
-            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
-            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("device_idx"), nb::arg("mn_nvlink"),
+            nb::arg("mpi_comm_fortran_handle"), nb::call_guard<nb::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,


@@ -443,9 +443,9 @@ void initBindings(pybind11::module_& m)
             py::call_guard<py::gil_scoped_release>());
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
-            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
-            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("device_idx"), py::arg("mn_nvlink"),
+            py::arg("mpi_comm_fortran_handle"), py::call_guard<py::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             py::call_guard<py::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,


@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace
 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink, int64_t mpiCommFortranHandle)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,7 +48,11 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
-    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
+#if ENABLE_MULTI_DEVICE
+    , mGroupComm(MPI_Comm_f2c(mpiCommFortranHandle), false)
+#else
+    , mGroupComm(nullptr, false)
+#endif
 {
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -66,9 +70,9 @@ McastDeviceMemory::McastDeviceMemory(
     int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, isMultiNode: %d, "
         "device_idx: %d, Signal pad offset: %zu",
-        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        world_rank, mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
     if (mIsMNNvlink)
     {
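
With the handle passed in, the constructor only wraps it (MPI_Comm_f2c under ENABLE_MULTI_DEVICE) instead of performing a split of its own. The hazard this removes: MPI_Comm_split is collective over its parent communicator, so if lazily-constructed workspaces cause some ranks to reach the split later, fewer times, or not at all, the ranks that did call it block indefinitely. That is my reading of the hang this commit fixes; the sketch below reproduces the failure mode with mpi4py and is not code from the repository. Run it with mpirun -np 2 and EVERYONE_SPLITS=0 to see the block, or with the default EVERYONE_SPLITS=1 to see it complete.

# Demonstrates why a collective split hidden inside a lazily-built object is
# risky. Illustrative only; run with: mpirun -np 2 python split_hang_demo.py
import os

from mpi4py import MPI

world = MPI.COMM_WORLD
rank = world.Get_rank()

# With EVERYONE_SPLITS=0, rank 1 never calls Split, so rank 0 waits forever
# inside the collective. With EVERYONE_SPLITS=1 (default), both ranks
# participate and the program finishes.
everyone_splits = os.environ.get("EVERYONE_SPLITS", "1") == "1"

if everyone_splits or rank == 0:
    sub = world.Split(color=0, key=rank)
    print(f"rank {rank}: Split returned, subgroup size = {sub.Get_size()}")
else:
    print(f"rank {rank}: skipped the collective")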


@@ -43,8 +43,8 @@ public:
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;
-    McastDeviceMemory(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink,
+        int64_t mpiCommFortranHandle);
     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.


@@ -34,12 +34,12 @@ public:
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
-    //! \param splitColor The color of the split for topology split.
-    //! \param device The CUDA device for buffer allocation.
+    //! \param deviceIdx The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
+    //! \param mpiCommFortranHandle The Fortran handle for the MPI communicator (from Python mpi4py).
+    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t deviceIdx, bool mnNvlink,
+        int64_t mpiCommFortranHandle)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, deviceIdx, mnNvlink, mpiCommFortranHandle)
         , mBufSize(bufSize)
         , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {


@@ -65,14 +65,12 @@ def get_or_scale_allreduce_mnnvl_workspace(
     """
     NUM_LAMPORT_BUFFERS = 3
-    if not hasattr(_thread_local,
-                   f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
-        setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
-                {})
+    # Use MNNVLAllReduce class to share across threads
+    allreduce_mnnvl_workspaces = MNNVLAllReduce.allreduce_mnnvl_workspaces
     # A safe method to get the element size of the dtype
     elem_size = torch.tensor([], dtype=dtype).element_size()
-    allreduce_mnnvl_workspaces = getattr(
-        _thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}')
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"
     use_fabric_handle = force_mn or mapping.is_multi_node()
@@ -101,19 +99,19 @@
         # Increase the buffer size in 8 MiB granularity to avoid frequently scaling the buffer
         buffer_size_bytes = math.ceil(req_buffer_size_bytes /
                                       (8 * 1024 * 1024)) * (8 * 1024 * 1024)
-        if mapping.tp_rank == 0:
-            logger.debug(
-                f"[MNNVL] Requested {req_buffer_size_bytes} bytes, is larger than the current workspace size. Scaling workspace for pp_rank {mapping.pp_rank}, tp_size {mapping.tp_size} from {allreduce_mnnvl_workspaces[mapping]['buffer_size_bytes']} to {buffer_size_bytes} bytes"
-            )
+        logger.debug(
+            f"[MNNVL] Requested {req_buffer_size_bytes} bytes, is larger than the current workspace size. Scaling workspace for pp_rank {mapping.pp_rank}, tp_size {mapping.tp_size} from {allreduce_mnnvl_workspaces[mapping]['buffer_size_bytes']} to {buffer_size_bytes} bytes"
+        )
         # Each workspace contains NUM_LAMPORT_BUFFERS buffers.
         workspace_size_bytes = NUM_LAMPORT_BUFFERS * buffer_size_bytes
+        # Pass the pre-split MPI communicator's Fortran handle to avoid redundant splitting in C++
         mcast_buf_handle = McastGPUBuffer(
             workspace_size_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
             mapping.local_rank,
             use_fabric_handle, # whether to use fabric handle or POSIX FD ipc
+            comm.py2f(), # Fortran handle for the MPI communicator
         )
         # We use per FP32 element in the buffer for lamport sync
@@ -517,6 +515,7 @@ class MNNVLAllReduce(nn.Module):
     This class handles the MNNVL-specific allreduce operations, which can be more efficient
     for certain operations when using NVLink for multi-node communication.
     """
+    allreduce_mnnvl_workspaces: Dict[int, Dict] = {}
     def __init__(self, mapping: Mapping, dtype: torch.dtype):
         super().__init__()
@@ -530,7 +529,7 @@
         )
         # Initialize the workspace
-        _ = get_or_scale_allreduce_mnnvl_workspace(self.mapping, self.dtype)
+        get_or_scale_allreduce_mnnvl_workspace(self.mapping, self.dtype)
     @staticmethod
     def get_supported_dtypes():
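
On the Python side, the per-thread, per-pp_rank dicts stored on _thread_local are replaced by a single class-level dict, MNNVLAllReduce.allreduce_mnnvl_workspaces, so every thread resolves the same workspace for a given mapping. The sketch below contrasts the two storage choices with made-up names; a real implementation would likely also guard creation and resizing with a lock, which is not shown here.

# Illustrative comparison of the two storage strategies; every name here is
# made up, nothing is the repository's code.
import threading
from typing import Dict

_thread_local = threading.local()


class SharedWorkspaces:
    # Class attribute: one process-wide registry, visible to every thread.
    workspaces: Dict[int, dict] = {}


def get_thread_local_ws(key: int) -> dict:
    # Old-style storage: each thread lazily builds its own dict, so two threads
    # asking for the same key end up with two independent workspaces.
    if not hasattr(_thread_local, "workspaces"):
        _thread_local.workspaces = {}
    return _thread_local.workspaces.setdefault(key, {})


def get_shared_ws(key: int) -> dict:
    # New-style storage: all threads resolve the same object for a given key.
    return SharedWorkspaces.workspaces.setdefault(key, {})


results = []


def worker(key: int) -> None:
    results.append((get_thread_local_ws(key), get_shared_ws(key)))


threads = [threading.Thread(target=worker, args=(0,)) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

(tl_a, shared_a), (tl_b, shared_b) = results
print("shared workspace is one object:", shared_a is shared_b)  # True
print("thread-local workspaces are one object:", tl_a is tl_b)  # False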