[None][feat] Async pp send for PPCommTorch. (#9976)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2025-12-15 14:03:46 +08:00 committed by GitHub
parent af899d2fe7
commit 7588029763
4 changed files with 30 additions and 54 deletions

View File

@@ -33,6 +33,7 @@ NcclCommunicatorOp::NcclCommunicatorOp(int64_t worldSize, int64_t rank)
 void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
 {
+    tensor.record_stream(at::cuda::getCurrentCUDAStream());
     auto ptr = static_cast<std::uint8_t*>(tensor.data_ptr());
     size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
     tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
@@ -41,6 +42,7 @@ void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
 void NcclCommunicatorOp::recv(th::Tensor& tensor, int64_t fromRank) const
 {
+    tensor.record_stream(at::cuda::getCurrentCUDAStream());
     auto ptr = static_cast<std::uint8_t*>(tensor.data_ptr());
     size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
     tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
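The record_stream calls above register the tensor with PyTorch's caching allocator as in use on the stream the communicator launches its NCCL kernel on, so the buffer is not recycled until that stream's pending work has finished. This is what lets the Python-side tensor caching be dropped further down. A minimal, self-contained Python sketch of the hazard it prevents; the names side_stream and buf are illustrative, not from this commit:

# Illustrative sketch (not part of this commit): keeping a buffer alive while
# another stream still consumes it.
import torch

assert torch.cuda.is_available()
side_stream = torch.cuda.Stream()
buf = torch.randn(1 << 20, device="cuda")      # allocated on the current stream

side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    buf.record_stream(side_stream)             # mark buf as in use on side_stream
    out = buf.sum()                            # stand-in for an async NCCL send
# Without record_stream, freeing buf here could let the allocator hand its
# memory to a new allocation while side_stream is still reading it.
del buf
torch.cuda.synchronize()
print(out.item())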

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List
 import torch
 import torch.distributed as dist
-from torch.distributed import get_process_group_ranks
+from torch.distributed import ProcessGroup, get_process_group_ranks
 from torch.distributed.device_mesh import init_device_mesh
 
 from tensorrt_llm.logger import logger
@@ -48,27 +48,27 @@ class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
     # Access Torch ProcessGroup
     @property
     @require_device_mesh
-    def tp_group_pg(self):
+    def tp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('tp').get_group()
 
     @property
     @require_device_mesh
-    def pp_group_pg(self):
+    def pp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('pp').get_group()
 
     @property
     @require_device_mesh
-    def cp_group_pg(self):
+    def cp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('cp').get_group()
 
     @property
     @require_device_mesh
-    def moe_tp_group_pg(self):
+    def moe_tp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('moe_tp').get_group()
 
     @property
     @require_device_mesh
-    def moe_ep_group_pg(self):
+    def moe_ep_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('moe_ep').get_group()
 
     # Access rank
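The changes above only add the -> ProcessGroup return type; behavior is unchanged. For context, a hedged sketch of how such a property is typically consumed. The mapping argument stands in for a DeviceMeshTopologyImpl instance and is an assumption of the example, not code from this commit:

# Illustrative sketch (not part of this commit).
from torch.distributed import ProcessGroup, get_process_group_ranks


def describe_pp_group(mapping) -> None:
    # `mapping` is assumed to expose the pp_group_pg property annotated above.
    pg: ProcessGroup = mapping.pp_group_pg
    print("pp group size:", pg.size())                   # number of pipeline stages
    print("rank inside the group:", pg.rank())           # group-local rank
    print("global ranks:", get_process_group_ranks(pg))  # group-local -> global mapping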

View File

@@ -16,7 +16,6 @@ try:
 except Exception:
     MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
-from tensorrt_llm._torch.hostfunc import hostfunc
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -783,26 +782,16 @@ class TorchDist(Distributed):
         return ret[0]
 
 
-class PPCommBase:
+class PPCommNCCL:
 
     def __init__(self, global_mapping: Mapping):
         self.mapping = global_mapping
+        self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
+            self.mapping.world_size,
+            self.mapping.rank,
+        )
         self.tensor_ready_event = torch.cuda.Event()
         self.send_stream = torch.cuda.Stream()
-        self.tensor_cache = {}
-
-    def _cache_tensor(self, tensor: torch.Tensor):
-        cache_id = id(tensor)
-        self.tensor_cache[cache_id] = tensor
-
-    @hostfunc
-    def _release_tensor(self, tensor: torch.Tensor):
-        cache_id = id(tensor)
-        del self.tensor_cache[cache_id]
-
-    @abstractmethod
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        raise NotImplementedError("direct_send is not implemented")
 
     def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         if dest is None:
@@ -811,30 +800,13 @@ class PPCommBase:
         # NCCL send kernel in send_stream cannot be captured,
         # so we send in the current stream instead in CUDA graph cases.
         if torch.cuda.is_current_stream_capturing():
-            self.direct_send(tensor, dest)
+            self.nccl_comm.send(tensor, dest)
             return
         self.tensor_ready_event.record()
         with torch.cuda.stream(self.send_stream):
             self.tensor_ready_event.wait()
-            # tensor may be released before NCCL send finished,
-            # so we cache it first and release it after send finished.
-            self._cache_tensor(tensor)
-            self.direct_send(tensor, dest)
-            self._release_tensor(tensor)
-
-
-class PPCommNCCL(PPCommBase):
-    def __init__(self, global_mapping: Mapping):
-        super().__init__(global_mapping)
-        self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
-            self.mapping.world_size,
-            self.mapping.rank,
-        )
-
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        self.nccl_comm.send(tensor, dest)
+            self.nccl_comm.send(tensor, dest)
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         if src is None:
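The two constraints called out in the comments above drive PPCommNCCL.send: during CUDA graph capture the NCCL kernel must stay on the capturing stream, and otherwise the send is offloaded to a dedicated send_stream that first waits on an event recorded on the producer stream, with tensor lifetime now handled by record_stream in the C++ op. A condensed, standalone sketch of that pattern; nccl_send here is a placeholder for the NcclCommunicatorOp binding and is not from this commit:

# Illustrative sketch (not part of this commit) of the event/side-stream send pattern.
import torch


def async_send(tensor: torch.Tensor, dest: int,
               send_stream: torch.cuda.Stream,
               ready_event: torch.cuda.Event,
               nccl_send) -> None:
    # Kernels launched on other streams cannot be captured into a CUDA graph,
    # so fall back to sending on the current (capturing) stream.
    if torch.cuda.is_current_stream_capturing():
        nccl_send(tensor, dest)
        return
    ready_event.record()                 # tensor is ready on the current stream
    with torch.cuda.stream(send_stream):
        ready_event.wait()               # send_stream waits for the producer stream
        # The C++ op records the tensor on this stream, so no Python-side
        # caching is needed to keep the buffer alive until the send finishes.
        nccl_send(tensor, dest)

The caller keeps issuing work on the current stream while the transfer proceeds on send_stream; that overlap is the point of the async path.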
@@ -842,10 +814,10 @@ class PPCommNCCL(PPCommBase):
         self.nccl_comm.recv(tensor, src)
 
 
-class PPCommTorch(PPCommBase):
+class PPCommTorch:
 
     def __init__(self, global_mapping: Mapping):
-        super().__init__(global_mapping)
+        self.mapping = global_mapping
         self.pg = self.mapping.pp_group_pg
         self.pg_group = self.mapping.pp_group
@@ -853,21 +825,22 @@ class PPCommTorch(PPCommBase):
         assert global_rank in self.pg_group
         return self.pg_group.index(global_rank)
 
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
-
-    # TODO: support async pp send for PPCommTorch
     def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         if dest is None:
             dest = self.mapping.next_pp_rank()
-        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
+        work = self.pg.send([tensor], self._global_to_local_rank(dest), tag=0)
+        # Send operation cannot be captured without blocking wait,
+        # so we block the current stream in CUDA graph cases.
+        if torch.cuda.is_current_stream_capturing():
+            work.block_current_stream()
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         if src is None:
             src = self.mapping.prev_pp_rank()
-        self.pg.recv([tensor], self._global_to_local_rank(src), tag=0).wait()
+        work = self.pg.recv([tensor], self._global_to_local_rank(src), tag=0)
+        work.block_current_stream()
 
 
 _pp_comm = None
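With this change PPCommTorch.send returns as soon as the point-to-point op is enqueued and keeps the Work handle instead of calling .wait(); only under CUDA graph capture is the current stream ordered after the send, because a host-side wait would not be replayed by the graph. recv still blocks the current stream so the received tensor is safe to read. A hedged sketch of the resulting call pattern; the process group, local ranks, and tensor are placeholders, and Work.block_current_stream() is the same handle method used in the diff above:

# Illustrative sketch (not part of this commit): non-blocking pp send/recv
# over a torch.distributed ProcessGroup.
import torch
from torch.distributed import ProcessGroup


def pp_send(pg: ProcessGroup, tensor: torch.Tensor, dst_local_rank: int) -> None:
    work = pg.send([tensor], dst_local_rank, tag=0)
    # During CUDA graph capture a host-side wait() cannot be recorded,
    # so order the capturing stream after the send instead.
    if torch.cuda.is_current_stream_capturing():
        work.block_current_stream()


def pp_recv(pg: ProcessGroup, tensor: torch.Tensor, src_local_rank: int) -> torch.Tensor:
    work = pg.recv([tensor], src_local_rank, tag=0)
    work.block_current_stream()   # consumers may read `tensor` on the current stream
    return tensor

Compared with the previous blocking .wait() calls, the send no longer stalls the host, which is the async pp send the commit title refers to.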

View File

@@ -16,6 +16,7 @@ from enum import IntEnum
 from typing import List
 
 import torch
+from torch.distributed import ProcessGroup
 
 from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
 from tensorrt_llm._utils import mpi_disabled
@@ -518,23 +519,23 @@ class Mapping(MappingBase):
     # DeviceMesh specific methods
     @property
-    def tp_group_pg(self):
+    def tp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("tp_group_pg is not implemented.")
 
     @property
-    def pp_group_pg(self):
+    def pp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("pp_group_pg is not implemented.")
 
     @property
-    def cp_group_pg(self):
+    def cp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("cp_group_pg is not implemented.")
 
     @property
-    def moe_tp_group_pg(self):
+    def moe_tp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("moe_tp_group_pg is not implemented.")
 
     @property
-    def moe_ep_group_pg(self):
+    def moe_ep_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("moe_ep_group_pg is not implemented.")
 
     def build_mesh(self):
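These Mapping properties are stubs that the device-mesh topology implementation overrides (see the DeviceMeshTopologyImpl hunk earlier in this commit); giving both sides the same -> ProcessGroup annotation keeps the interface consistent. A minimal sketch of that stub/override shape, with simplified class names that are assumptions of the example:

# Illustrative sketch (not part of this commit) of the stub/override relationship.
from torch.distributed import ProcessGroup


class MappingLike:

    @property
    def pp_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("pp_group_pg is not implemented.")


class DeviceMeshMappingLike(MappingLike):

    def __init__(self, mesh):
        self._mesh = mesh  # a torch.distributed.device_mesh.DeviceMesh with a 'pp' dim

    @property
    def pp_group_pg(self) -> ProcessGroup:
        return self._mesh["pp"].get_group()  # resolve the pp dimension's process group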