[https://nvbugs/5489015][fix] Support communicator split in MNNVL allreduce and fix the binding issues. (#7387)

Signed-off-by: Shiyu Li <shili@nvidia.com>
Authored by Shiyu Li on 2025-09-16 16:43:20 -07:00, committed by GitHub
parent a91453de34
commit 8bdbb48264
7 changed files with 52 additions and 40 deletions
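Below is a rough, standalone mpi4py sketch (illustration only, not part of the diff) of the split this commit introduces: ranks that share a (pp, cp) position get the same color and therefore land in the same sub-communicator, ordered by their TP rank. The 8-rank tp=4 / pp=2 / cp=1 layout and the rank-to-parallelism mapping are made-up example values, not TensorRT-LLM's Mapping.

from mpi4py import MPI

# Toy topology: 8 ranks, tp_size=4, pp_size=2, cp_size=1, TP fastest-varying
# (an assumed layout for illustration only).
world = MPI.COMM_WORLD
tp_size, cp_size = 4, 1
tp_rank = world.Get_rank() % tp_size
cp_rank = (world.Get_rank() // tp_size) % cp_size
pp_rank = world.Get_rank() // (tp_size * cp_size)

# Same color => same sub-communicator; the key orders ranks inside it by tp_rank.
color = pp_rank * cp_size + cp_rank
group_comm = world.Split(color, tp_rank)  # ranks 0-3 -> color 0, ranks 4-7 -> color 1
print(f"world={world.Get_rank()} color={color} group_rank={group_comm.Get_rank()}")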

View File

@@ -361,7 +361,9 @@ void initBindings(nb::module_& m)
         nb::call_guard<nb::gil_scoped_release>());
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>(), nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,

View File

@@ -455,7 +455,9 @@ void initBindings(pybind11::module_& m)
         py::call_guard<py::gil_scoped_release>());
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>(), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             py::call_guard<py::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,

View File

@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include <cstddef>
 #include <cstdint>
 #include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace
 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
     if (mIsMNNvlink)
     {
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()
 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();
@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;
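allocNvlsMcastMem now derives the NVLS rank set from the split communicator via getWorldRanks; conceptually that is an allgather of each member's global rank over the sub-communicator, roughly as in this mpi4py sketch (illustration only, not the TensorRT-LLM API):

from mpi4py import MPI

world = MPI.COMM_WORLD
# Toy split: a single group containing every rank, ordered by world rank.
group_comm = world.Split(0, world.Get_rank())
# Each member contributes its COMM_WORLD rank; the sorted result plays the role
# of the rank set handed to ipcNvlsAllocate in the hunk above.
world_ranks = sorted(group_comm.allgather(world.Get_rank()))
print(world_ranks)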

View File

@@ -17,6 +17,7 @@
 #include "tensorrt_llm/common/mcastDevMemUtils.h"
 #include "tensorrt_llm/runtime/ipcNvlsMemory.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include <cstddef>
 #include <cstdint>
 #include <cuda.h>
@@ -42,7 +43,8 @@ public:
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;
-    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);
     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
@@ -98,6 +100,8 @@ private:
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;
+    tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group
     // Host array of pointers
     std::vector<CUdeviceptr> mUcPtrs;
     std::vector<CUdeviceptr> mSignalPads;

View File

@@ -34,12 +34,14 @@ public:
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
+    //! \param splitColor The color of the split for topology split.
     //! \param device The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
+    McastGPUBuffer(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }
@@ -49,7 +51,7 @@ public:
     //! \param dtype The data type of the tensor elements.
    //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -59,7 +61,10 @@ public:
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }
     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
@@ -67,7 +72,7 @@ public:
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -77,7 +82,10 @@ public:
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }
 private:
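Because sizes is now taken by value as a std::vector and wrapped in a c10::IntArrayRef inside the call, a plain Python list can be passed through the binding; a hypothetical call, continuing the earlier placeholder mcast_buf sketch, might look like:

import torch

# mcast_buf constructed as in the earlier placeholder sketch.
sizes = [128, 4096]                   # plain Python list for the std::vector parameter
uc = mcast_buf.get_uc_buffer(0, sizes, torch.bfloat16, 0)  # rank, sizes, dtype, storage offset
mc = mcast_buf.get_mc_buffer(sizes, torch.bfloat16, 0)     # sizes, dtype, storage offset
assert uc.shape == torch.Size(sizes)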

View File

@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -8,7 +7,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn
-from tensorrt_llm._utils import mpi_barrier
+from tensorrt_llm._utils import mpi_comm
 from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy, MoEAllReduceParams)
@@ -17,7 +16,6 @@ from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 _thread_local = threading.local()
-logger = logging.getLogger(__name__)
 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -55,11 +53,15 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None:
 def get_allreduce_mnnvl_workspace(
     mapping: Mapping, dtype: torch.dtype
 ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]:
     if not hasattr(_thread_local,
                    f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})
+    # Support topology split
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"
     allreduce_mnnvl_workspaces = getattr(
@@ -77,7 +79,9 @@ def get_allreduce_mnnvl_workspace(
             buffer_size_in_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            torch.device("cuda", mapping.local_rank),
+            # Split the communicator according to the topology
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+            mapping.local_rank,
             True,  # mnNvlink
         )
@@ -87,7 +91,7 @@ def get_allreduce_mnnvl_workspace(
         buffer.fill_(-0.0)
         # CPU barrier since we assume this should not be called in cuda graph
         torch.cuda.synchronize()
-        mpi_barrier()
+        comm.Barrier()
         # This is a buffer to maintain the state of this allreduce Op
         # Should have the same lifetime with self._buffer
@@ -458,12 +462,7 @@ class AllReduce(nn.Module):
         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None

View File

@@ -771,12 +771,11 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             self.mapping.tp_size,
         )
-        if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
-        ):
+        if tp > self.mapping.gpus_per_node:
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size
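As a quick worked example of the fallback branch (illustrative numbers only): whenever tp exceeds gpus_per_node, the gcd clamps the MLP TP group to a single node.

import math

tp, gpus_per_node = 16, 8
mlp_tp_size = math.gcd(tp, gpus_per_node)   # -> 8: MLP TP stays within one node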