Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5489015][fix] Support communicator split in MNNVL allreduce and fix the binding issues. (#7387)
Signed-off-by: Shiyu Li <shili@nvidia.com>
parent a91453de34
commit 8bdbb48264
@@ -361,7 +361,9 @@ void initBindings(nb::module_& m)
            nb::call_guard<nb::gil_scoped_release>());

    nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>(), nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
        .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
            nb::call_guard<nb::gil_scoped_release>())
        .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
@@ -455,7 +455,9 @@ void initBindings(pybind11::module_& m)
            py::call_guard<py::gil_scoped_release>());

    py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>(), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
        .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
            py::call_guard<py::gil_scoped_release>())
        .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
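For reference, a minimal sketch (not part of this diff) of how the rebound constructor is driven from Python. The argument names follow the nb::arg/py::arg list above and the color formula matches the Python change further down; the concrete values are placeholders.

# Hypothetical usage of the new binding; values are illustrative only.
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

mcast_buffer = McastGPUBuffer(
    512 * 1024 * 1024,  # buf_size: buffer size in bytes
    8,                  # group_size: number of ranks in the group
    0,                  # group_rank: this rank's index within the group
    0,                  # split_color: e.g. pp_rank * cp_size + cp_rank
    0,                  # device_idx: local CUDA device index
    True,               # mn_nvlink: multi-node NVLink
)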
@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"

 #include <cstddef>
 #include <cstdint>
 #include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace

 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {

     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
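The new mGroupComm member splits the session communicator with MPI_Comm_split semantics: ranks that pass the same splitColor land in the same sub-communicator, ordered by the key (here the group rank). A minimal mpi4py sketch of those semantics, independent of TensorRT-LLM; the grouping of four consecutive ranks is an arbitrary example.

# Standalone mpi4py illustration of split-by-color (run under mpirun).
from mpi4py import MPI

world = MPI.COMM_WORLD
color = world.Get_rank() // 4   # example: every 4 consecutive ranks form one group
key = world.Get_rank() % 4      # ordering of ranks inside the group
group_comm = world.Split(color, key)
print(f"world rank {world.Get_rank()} -> group rank "
      f"{group_comm.Get_rank()} of {group_comm.Get_size()} (color {color})")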
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
+
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);

     if (mIsMNNvlink)
     {
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()

 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
-
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();

@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)

 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;
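allocNvlsMcastMem previously assumed the group occupied ranks 0..groupSize-1 of the world communicator; after the split it must translate group membership back to world ranks before handing them to ipcNvlsAllocate. The same translation can be sketched in mpi4py by gathering each member's world rank over the group communicator (illustrative only, not the library's code; the split by fours is a placeholder).

# Illustrative mpi4py equivalent of "get the world ranks of this group".
from mpi4py import MPI

world = MPI.COMM_WORLD
group_comm = world.Split(world.Get_rank() // 4, world.Get_rank() % 4)

world_ranks = sorted(group_comm.allgather(world.Get_rank()))
print(f"group of world rank {world.Get_rank()} contains world ranks {world_ranks}")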
@@ -17,6 +17,7 @@

 #include "tensorrt_llm/common/mcastDevMemUtils.h"
 #include "tensorrt_llm/runtime/ipcNvlsMemory.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include <cstddef>
 #include <cstdint>
 #include <cuda.h>
@@ -42,7 +43,8 @@ public:
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;

-    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);

     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
@@ -98,6 +100,8 @@ private:
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;

+    tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group
+
     // Host array of pointers
     std::vector<CUdeviceptr> mUcPtrs;
     std::vector<CUdeviceptr> mSignalPads;
@@ -34,12 +34,14 @@ public:
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
+    //! \param splitColor The color of the split for topology split.
     //! \param device The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
+    McastGPUBuffer(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }

@@ -49,7 +51,7 @@ public:
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -59,7 +61,10 @@
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;

         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
@@ -67,7 +72,7 @@
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -77,7 +82,10 @@
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;

         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

 private:
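getUCBuffer and getMCBuffer now take the shape as an owning std::vector<long int> and wrap it in a c10::IntArrayRef only at the at::for_blob call; passing a non-owning IntArrayRef across the nanobind/pybind boundary is presumably the binding issue the commit title refers to. A hedged sketch of how the exposed methods would be called from Python, reusing mcast_buffer from the earlier sketch; the shapes, dtype, and offsets are placeholders.

# Hypothetical call pattern; shapes, dtype, and offsets are placeholders.
import torch

uc = mcast_buffer.get_uc_buffer(0, [16, 4096], torch.bfloat16, 0)  # rank, sizes, dtype, storage_offset
mc = mcast_buffer.get_mc_buffer([16, 4096], torch.bfloat16, 0)     # sizes, dtype, storage_offset
print(uc.shape, mc.shape)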
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -8,7 +7,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch import nn

-from tensorrt_llm._utils import mpi_barrier
+from tensorrt_llm._utils import mpi_comm
 from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy, MoEAllReduceParams)
@@ -17,7 +16,6 @@ from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 _thread_local = threading.local()
-logger = logging.getLogger(__name__)


 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -55,11 +53,15 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None:
 def get_allreduce_mnnvl_workspace(
     mapping: Mapping, dtype: torch.dtype
 ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]:

     if not hasattr(_thread_local,
                    f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})

+    # Support topology split
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

     allreduce_mnnvl_workspaces = getattr(
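The Split call above groups ranks by their pipeline- and context-parallel coordinates, so each tensor-parallel group gets its own communicator with members ordered by tp_rank. A self-contained sketch of how the color/key assignment partitions a hypothetical topology (no MPI required; the sizes are illustrative and the loop is not Mapping's actual rank layout).

# Illustrative partitioning produced by color = pp_rank * cp_size + cp_rank, key = tp_rank.
tp_size, pp_size, cp_size = 4, 2, 2  # hypothetical topology

groups = {}
for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        for tp_rank in range(tp_size):
            color = pp_rank * cp_size + cp_rank  # same formula as the Split() call above
            key = tp_rank                        # rank order inside the sub-communicator
            groups.setdefault(color, []).append((pp_rank, cp_rank, tp_rank))

for color, members in sorted(groups.items()):
    print(f"color {color}: tp peers {members}")  # each color holds exactly one TP group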
@@ -77,7 +79,9 @@ def get_allreduce_mnnvl_workspace(
             buffer_size_in_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            torch.device("cuda", mapping.local_rank),
+            # Split the communicator according to the topology
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+            mapping.local_rank,
             True,  # mnNvlink
         )
@@ -87,7 +91,7 @@ def get_allreduce_mnnvl_workspace(
         buffer.fill_(-0.0)
         # CPU barrier since we assume this should not be called in cuda graph
         torch.cuda.synchronize()
-        mpi_barrier()
+        comm.Barrier()

         # This is a buffer to maintain the state of this allreduce Op
         # Should have the same lifetime with self._buffer
@@ -458,12 +462,7 @@ class AllReduce(nn.Module):
         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
@@ -771,12 +771,11 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             self.mapping.tp_size,
         )

-        if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
-        ):
+        if tp > self.mapping.gpus_per_node:
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size
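The gcd fallback above now applies whenever TP spans more than one node, regardless of MNNVL support. A quick worked example with hypothetical sizes.

# Worked example of the mlp_tp_size computation above (hypothetical sizes).
import math

tp, gpus_per_node = 16, 8
mlp_tp_size = math.gcd(tp, gpus_per_node) if tp > gpus_per_node else tp
print(mlp_tp_size)  # -> 8: MLP tensor-parallelism stays within a single node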