[None][feat] Async pp send for PPCommTorch. (#9976)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
commit 7588029763
parent af899d2fe7
@@ -33,6 +33,7 @@ NcclCommunicatorOp::NcclCommunicatorOp(int64_t worldSize, int64_t rank)
 
 void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
 {
+    tensor.record_stream(at::cuda::getCurrentCUDAStream());
     auto ptr = static_cast<std::uint8_t*>(tensor.data_ptr());
     size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
     tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
@@ -41,6 +42,7 @@ void NcclCommunicatorOp::send(th::Tensor tensor, int64_t toRank) const
 
 void NcclCommunicatorOp::recv(th::Tensor& tensor, int64_t fromRank) const
 {
+    tensor.record_stream(at::cuda::getCurrentCUDAStream());
     auto ptr = static_cast<std::uint8_t*>(tensor.data_ptr());
     size_t const size = tensor.numel() * th::elementSize(th::typeMetaToScalarType(tensor.dtype()));
     tensorrt_llm::runtime::CudaStream cudaStream{at::cuda::getCurrentCUDAStream().stream(), mRank, false};
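Note: the tensor.record_stream(at::cuda::getCurrentCUDAStream()) calls added to send and recv make the CUDA caching allocator defer reuse of the buffer until the stream the op runs on has drained. This is what lets the Python side drop its hostfunc-based tensor cache further down in this diff. A minimal standalone sketch of the hazard, illustrative only: side_stream and the x * 2.0 kernel stand in for the PP send stream and the NCCL kernel.

import torch

side_stream = torch.cuda.Stream()

x = torch.randn(1 << 20, device="cuda")        # allocated on the current stream
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    x.record_stream(side_stream)               # defer allocator reuse of x's memory
    y = x * 2.0                                # stands in for the NCCL send kernel
del x                                          # safe: reuse is fenced on side_stream
torch.cuda.current_stream().wait_stream(side_stream)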
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, List
 
 import torch
 import torch.distributed as dist
-from torch.distributed import get_process_group_ranks
+from torch.distributed import ProcessGroup, get_process_group_ranks
 from torch.distributed.device_mesh import init_device_mesh
 
 from tensorrt_llm.logger import logger
@@ -48,27 +48,27 @@ class DeviceMeshTopologyImpl(_MappingBaseForTypeCheck):
     # Access Torch ProcessGroup
     @property
     @require_device_mesh
-    def tp_group_pg(self):
+    def tp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('tp').get_group()
 
     @property
     @require_device_mesh
-    def pp_group_pg(self):
+    def pp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('pp').get_group()
 
     @property
     @require_device_mesh
-    def cp_group_pg(self):
+    def cp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('cp').get_group()
 
     @property
     @require_device_mesh
-    def moe_tp_group_pg(self):
+    def moe_tp_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('moe_tp').get_group()
 
     @property
     @require_device_mesh
-    def moe_ep_group_pg(self):
+    def moe_ep_group_pg(self) -> ProcessGroup:
         return self._get_mesh_dim_by_name('moe_ep').get_group()
 
     # Access rank
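Note: these accessors now advertise that a named mesh dimension resolves to a torch.distributed ProcessGroup. A hypothetical standalone example of the same pattern; it assumes torch.distributed is already initialized with 8 ranks, and the 2 x 4 mesh shape is made up for illustration.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("pp", "tp"))
tp_pg = mesh["tp"].get_group()      # ProcessGroup over the 4 tp ranks
pp_pg = mesh["pp"].get_group()      # ProcessGroup over the 2 pp ranks
print(dist.get_process_group_ranks(tp_pg))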
@@ -16,7 +16,6 @@ try:
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
-from tensorrt_llm._torch.hostfunc import hostfunc
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -783,26 +782,16 @@ class TorchDist(Distributed):
         return ret[0]
 
 
-class PPCommBase:
+class PPCommNCCL:
 
     def __init__(self, global_mapping: Mapping):
         self.mapping = global_mapping
+        self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
+            self.mapping.world_size,
+            self.mapping.rank,
+        )
         self.tensor_ready_event = torch.cuda.Event()
         self.send_stream = torch.cuda.Stream()
-        self.tensor_cache = {}
-
-    def _cache_tensor(self, tensor: torch.Tensor):
-        cache_id = id(tensor)
-        self.tensor_cache[cache_id] = tensor
-
-    @hostfunc
-    def _release_tensor(self, tensor: torch.Tensor):
-        cache_id = id(tensor)
-        del self.tensor_cache[cache_id]
-
-    @abstractmethod
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        raise NotImplementedError("direct_send is not implemented")
 
     def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         if dest is None:
@@ -811,30 +800,13 @@
         # NCCL send kernel in send_stream cannot be captured,
         # so we send in the current stream instead in CUDA graph cases.
         if torch.cuda.is_current_stream_capturing():
-            self.direct_send(tensor, dest)
+            self.nccl_comm.send(tensor, dest)
             return
 
         self.tensor_ready_event.record()
         with torch.cuda.stream(self.send_stream):
             self.tensor_ready_event.wait()
-            # tensor may be released before NCCL send finished,
-            # so we cache it first and release it after send finished.
-            self._cache_tensor(tensor)
-            self.direct_send(tensor, dest)
-            self._release_tensor(tensor)
-
-
-class PPCommNCCL(PPCommBase):
-
-    def __init__(self, global_mapping: Mapping):
-        super().__init__(global_mapping)
-        self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
-            self.mapping.world_size,
-            self.mapping.rank,
-        )
-
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        self.nccl_comm.send(tensor, dest)
+            self.nccl_comm.send(tensor, dest)
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         if src is None:
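Note: the event handshake kept in this hunk is what makes the NCCL path asynchronous: the producer stream records tensor_ready_event, the dedicated send stream waits on it, and the send kernel then runs there without stalling compute; record_stream inside the C++ op (added above) keeps the buffer alive, replacing the removed _cache_tensor/_release_tensor hostfunc bookkeeping. A minimal sketch of the pattern, illustrative only: nccl_send is a placeholder for self.nccl_comm.send.

import torch

tensor_ready_event = torch.cuda.Event()
send_stream = torch.cuda.Stream()

def async_send(tensor, nccl_send, dest):
    tensor_ready_event.record()          # tensor contents are ready on the current stream
    with torch.cuda.stream(send_stream):
        tensor_ready_event.wait()        # order the send after the producer stream
        nccl_send(tensor, dest)          # runs on send_stream; compute continues elsewhere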
@@ -842,10 +814,10 @@
         self.nccl_comm.recv(tensor, src)
 
 
-class PPCommTorch(PPCommBase):
+class PPCommTorch:
 
     def __init__(self, global_mapping: Mapping):
-        super().__init__(global_mapping)
+        self.mapping = global_mapping
         self.pg = self.mapping.pp_group_pg
         self.pg_group = self.mapping.pp_group
@@ -853,21 +825,22 @@
         assert global_rank in self.pg_group
         return self.pg_group.index(global_rank)
 
-    def direct_send(self, tensor: torch.Tensor, dest: int):
-        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
-
-    # TODO: support async pp send for PPCommTorch
     def send(self, tensor: torch.Tensor, dest: Optional[int] = None):
         if dest is None:
             dest = self.mapping.next_pp_rank()
 
-        self.pg.send([tensor], self._global_to_local_rank(dest), tag=0).wait()
+        work = self.pg.send([tensor], self._global_to_local_rank(dest), tag=0)
+        # Send operation cannot be captured without blocking wait,
+        # so we block the current stream in CUDA graph cases.
+        if torch.cuda.is_current_stream_capturing():
+            work.block_current_stream()
 
     def recv(self, tensor: torch.Tensor, src: Optional[int] = None):
         if src is None:
             src = self.mapping.prev_pp_rank()
 
-        self.pg.recv([tensor], self._global_to_local_rank(src), tag=0).wait()
+        work = self.pg.recv([tensor], self._global_to_local_rank(src), tag=0)
+        work.block_current_stream()
 
 
 _pp_comm = None
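Note: ProcessGroup.send/recv return a torch.distributed Work handle, so dropping the eager .wait() is what makes the torch path asynchronous; block_current_stream (a newer Work method, used here because a blocking host-side wait cannot be captured into a CUDA graph) fences only the current CUDA stream instead. A hedged sketch of the same shape via the public point-to-point API; it assumes an initialized NCCL process group, and peer plus the rank-parity scheme are made up for illustration.

import torch
import torch.distributed as dist

def pp_exchange(tensor: torch.Tensor, peer: int):
    if dist.get_rank() % 2 == 0:
        work = dist.isend(tensor, dst=peer)   # queues the send, returns a Work handle
    else:
        work = dist.irecv(tensor, src=peer)   # queues the recv, returns a Work handle
    # Outside graph capture the transfer can overlap compute; wait() (or
    # block_current_stream() during capture) enforces completion/ordering.
    work.wait()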
@@ -16,6 +16,7 @@ from enum import IntEnum
 from typing import List
 
 import torch
+from torch.distributed import ProcessGroup
 
 from tensorrt_llm._torch.device_mesh import DeviceMeshTopologyImpl
 from tensorrt_llm._utils import mpi_disabled
@@ -518,23 +519,23 @@ class Mapping(MappingBase):
 
     # DeviceMesh specific methods
     @property
-    def tp_group_pg(self):
+    def tp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("tp_group_pg is not implemented.")
 
     @property
-    def pp_group_pg(self):
+    def pp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("pp_group_pg is not implemented.")
 
     @property
-    def cp_group_pg(self):
+    def cp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("cp_group_pg is not implemented.")
 
     @property
-    def moe_tp_group_pg(self):
+    def moe_tp_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("moe_tp_group_pg is not implemented.")
 
     @property
-    def moe_ep_group_pg(self):
+    def moe_ep_group_pg(self) -> ProcessGroup:
         raise NotImplementedError("moe_ep_group_pg is not implemented.")
 
     def build_mesh(self):
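Note: the base Mapping keeps raising NotImplementedError while only the DeviceMesh-backed subclass implements these properties; the new -> ProcessGroup annotations give both sides the same typed contract. A toy sketch of the pattern; Base and MeshImpl are made-up names.

from torch.distributed import ProcessGroup

class Base:
    @property
    def pp_group_pg(self) -> ProcessGroup:
        raise NotImplementedError("pp_group_pg is not implemented.")

class MeshImpl(Base):
    def __init__(self, pg: ProcessGroup):
        self._pg = pg

    @property
    def pp_group_pg(self) -> ProcessGroup:
        return self._pg  # concrete group resolved from the device mesh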