Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 02:02:01 +08:00)
[https://nvbugs/5782112][fix] Fix hanging issue for MNNVL Allreduce under PP (#10633)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
This commit is contained in:
parent e2c3373749
commit f001c4946d
@@ -347,9 +347,9 @@ void initBindings(nb::module_& m)
         nb::call_guard<nb::gil_scoped_release>());
 
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
-            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
-            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("device_idx"), nb::arg("mn_nvlink"),
+            nb::arg("mpi_comm_fortran_handle"), nb::call_guard<nb::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
@@ -443,9 +443,9 @@ void initBindings(pybind11::module_& m)
         py::call_guard<py::gil_scoped_release>());
 
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
-            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
-            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("device_idx"), py::arg("mn_nvlink"),
+            py::arg("mpi_comm_fortran_handle"), py::call_guard<py::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             py::call_guard<py::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
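Both binding layers now expose the same six-argument constructor (`buf_size`, `group_size`, `group_rank`, `device_idx`, `mn_nvlink`, `mpi_comm_fortran_handle`). A minimal usage sketch from the Python side, assuming mpi4py is available and that `McastGPUBuffer` is imported from the TensorRT-LLM bindings (the import path is not shown in this diff); the argument order mirrors the `nb::arg()`/`py::arg()` lists above:

```python
# Hypothetical driver for the rebound constructor; everything except the bound
# argument order is illustrative (buffer size, device index, communicator choice).
from mpi4py import MPI

tp_comm = MPI.COMM_WORLD      # in practice: the pre-split tensor-parallel group
buf = McastGPUBuffer(         # noqa: F821 - C++ binding, import omitted here
    64 * 1024 * 1024,         # buf_size in bytes
    tp_comm.Get_size(),       # group_size
    tp_comm.Get_rank(),       # group_rank
    0,                        # device_idx: local CUDA device
    True,                     # mn_nvlink: use fabric handles for multi-node NVLink
    tp_comm.py2f(),           # mpi_comm_fortran_handle as a plain int
)
```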
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace
 
 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink, int64_t mpiCommFortranHandle)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,7 +48,11 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
-    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
+#if ENABLE_MULTI_DEVICE
+    , mGroupComm(MPI_Comm_f2c(mpiCommFortranHandle), false)
+#else
+    , mGroupComm(nullptr, false)
+#endif
 {
 
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -66,9 +70,9 @@ McastDeviceMemory::McastDeviceMemory(
     int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
 
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, isMultiNode: %d, "
         "device_idx: %d, Signal pad offset: %zu",
-        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        world_rank, mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
 
     if (mIsMNNvlink)
     {
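With this change the constructor no longer performs its own `MpiComm::session().split()`; it adopts a communicator the caller already split, reconstructed from the Fortran handle via `MPI_Comm_f2c`. A minimal, self-contained sketch of that handle round trip using only mpi4py (not TensorRT-LLM code):

```python
# py2f() flattens a communicator into a plain integer Fortran handle that can be
# passed through the bindings as int64; MPI_Comm_f2c (C/C++) or MPI.Comm.f2py
# (Python) turns it back into a usable communicator on the other side.
from mpi4py import MPI

world = MPI.COMM_WORLD
group = world.Split(color=0, key=world.Get_rank())  # caller-side split
handle = group.py2f()                               # plain int, crosses the boundary
rebuilt = MPI.Comm.f2py(handle)                     # same underlying communicator
assert rebuilt.Get_rank() == group.Get_rank()
```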
@@ -43,8 +43,8 @@ public:
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;
 
-    McastDeviceMemory(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink,
+        int64_t mpiCommFortranHandle);
 
     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
@@ -34,12 +34,12 @@ public:
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
-    //! \param splitColor The color of the split for topology split.
-    //! \param device The CUDA device for buffer allocation.
+    //! \param deviceIdx The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(
-        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
+    //! \param mpiCommFortranHandle The Fortran handle for the MPI communicator (from Python mpi4py).
+    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t deviceIdx, bool mnNvlink,
+        int64_t mpiCommFortranHandle)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, deviceIdx, mnNvlink, mpiCommFortranHandle)
         , mBufSize(bufSize)
         , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
@@ -65,14 +65,12 @@ def get_or_scale_allreduce_mnnvl_workspace(
     """
 
     NUM_LAMPORT_BUFFERS = 3
-    if not hasattr(_thread_local,
-                   f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
-        setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
-                {})
+    # Use MNNVLAllReduce class to share across threads
+    allreduce_mnnvl_workspaces = MNNVLAllReduce.allreduce_mnnvl_workspaces
 
     # A safe method to get the element size of the dtype
     elem_size = torch.tensor([], dtype=dtype).element_size()
-    allreduce_mnnvl_workspaces = getattr(
-        _thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}')
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"
     use_fabric_handle = force_mn or mapping.is_multi_node()
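The per-thread cache keyed by `pp_rank` on `_thread_local` is replaced by the class-level `MNNVLAllReduce.allreduce_mnnvl_workspaces`, so all threads in the process see the same workspace entries. A small stand-alone sketch of the difference (not TensorRT-LLM code):

```python
# Attributes on threading.local are invisible outside the thread that set them,
# so a per-thread cache forces every thread to allocate its own workspace; a
# class attribute is a single dict shared by all threads in the process.
import threading

_thread_local = threading.local()

class SharedCache:
    workspaces: dict = {}  # one dict per process, shared by every thread

def worker():
    _thread_local.ws = "private"                       # only this thread sees it
    SharedCache.workspaces.setdefault("ws", "shared")  # visible everywhere

t = threading.Thread(target=worker)
t.start()
t.join()
print(hasattr(_thread_local, "ws"))  # False: the main thread never set it
print(SharedCache.workspaces)        # {'ws': 'shared'}
```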
@@ -101,19 +99,19 @@ def get_or_scale_allreduce_mnnvl_workspace(
         # Increase the buffer size in 8 MiB granularity to avoid frequently scaling the buffer
         buffer_size_bytes = math.ceil(req_buffer_size_bytes /
                                       (8 * 1024 * 1024)) * (8 * 1024 * 1024)
-        if mapping.tp_rank == 0:
-            logger.debug(
-                f"[MNNVL] Requested {req_buffer_size_bytes} bytes, is larger than the current workspace size. Scaling workspace for pp_rank {mapping.pp_rank}, tp_size {mapping.tp_size} from {allreduce_mnnvl_workspaces[mapping]['buffer_size_bytes']} to {buffer_size_bytes} bytes"
-            )
+        logger.debug(
+            f"[MNNVL] Requested {req_buffer_size_bytes} bytes, is larger than the current workspace size. Scaling workspace for pp_rank {mapping.pp_rank}, tp_size {mapping.tp_size} from {allreduce_mnnvl_workspaces[mapping]['buffer_size_bytes']} to {buffer_size_bytes} bytes"
+        )
         # Each workspace contains NUM_LAMPORT_BUFFERS buffers.
         workspace_size_bytes = NUM_LAMPORT_BUFFERS * buffer_size_bytes
+        # Pass the pre-split MPI communicator's Fortran handle to avoid redundant splitting in C++
         mcast_buf_handle = McastGPUBuffer(
             workspace_size_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
             mapping.local_rank,
             use_fabric_handle,  # whether to use fabric handle or POSIX FD ipc
+            comm.py2f(),  # Fortran handle for the MPI communicator
         )
 
         # We use per FP32 element in the buffer for lamport sync
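The split color previously forwarded to C++ (`mapping.pp_rank * mapping.cp_size + mapping.cp_rank`) is dropped; instead `comm.py2f()` hands over the Fortran handle of a communicator that was already split outside this hunk. A hedged sketch of how such a pre-split group could be formed with mpi4py, reusing that color and ordering ranks by `tp_rank` (illustrative only; the actual split lives elsewhere in the codebase):

```python
# Assumption-labeled sketch: one communicator per (pp_rank, cp_rank) pair, ranked
# by tp_rank, whose Fortran handle could then be passed to McastGPUBuffer.
from mpi4py import MPI

def make_tp_comm(pp_rank: int, cp_rank: int, cp_size: int, tp_rank: int) -> MPI.Comm:
    color = pp_rank * cp_size + cp_rank  # same color the old code sent to C++
    return MPI.COMM_WORLD.Split(color=color, key=tp_rank)
```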
@@ -517,6 +515,7 @@ class MNNVLAllReduce(nn.Module):
     This class handles the MNNVL-specific allreduce operations, which can be more efficient
     for certain operations when using NVLink for multi-node communication.
     """
+    allreduce_mnnvl_workspaces: Dict[int, Dict] = {}
 
     def __init__(self, mapping: Mapping, dtype: torch.dtype):
         super().__init__()
@@ -530,7 +529,7 @@ class MNNVLAllReduce(nn.Module):
         )
 
         # Initialize the workspace
-        _ = get_or_scale_allreduce_mnnvl_workspace(self.mapping, self.dtype)
+        get_or_scale_allreduce_mnnvl_workspace(self.mapping, self.dtype)
 
     @staticmethod
     def get_supported_dtypes():