diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py
index bbbdd961ce..dec1dc1c4b 100644
--- a/tensorrt_llm/_torch/distributed/communicator.py
+++ b/tensorrt_llm/_torch/distributed/communicator.py
@@ -20,7 +20,8 @@ from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
-                                 mpi_send_object, torch_pybind11_abi)
+                                 mpi_send_object, mpi_world_size,
+                                 torch_pybind11_abi)
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.bindings.internal.process_group import init_pg
 from tensorrt_llm.logger import logger
@@ -456,6 +457,19 @@ class MPIDist(Distributed):
         self._tp_comm = None
         self._pp_comm = None
 
+    def _validate_world_size(self):
+        """Validate world size before creating sub-communicators to prevent segfaults."""
+
+        if ENABLE_MULTI_DEVICE:
+            actual_world_size = mpi_world_size()
+            max_rank_needed = self.mapping.world_size
+
+            if max_rank_needed > actual_world_size:
+                raise RuntimeError(
+                    f"Mapping requires world_size={max_rank_needed} "
+                    f"(tp_size={self.mapping.tp_size} * pp_size={self.mapping.pp_size} * cp_size={self.mapping.cp_size}), "
+                    f"but MPI world size is only {actual_world_size}. ")
+
     def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = mpi_comm()
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
@@ -493,6 +507,7 @@ class MPIDist(Distributed):
     @property
     def tp_comm(self):
         if self._tp_comm is None:
+            self._validate_world_size()
             mapping = self.mapping
             new_group = mpi_comm().group.Incl(mapping.tp_group)
             self._tp_comm = mpi_comm().Create_group(new_group)
@@ -501,6 +516,7 @@ class MPIDist(Distributed):
     @property
     def pp_comm(self):
         if self._pp_comm is None:
+            self._validate_world_size()
             mapping = self.mapping
             new_group = mpi_comm().group.Incl(mapping.pp_group)
             self._pp_comm = mpi_comm().Create_group(new_group)
@@ -509,6 +525,7 @@ class MPIDist(Distributed):
     @property
     def cp_comm(self):
         if self._cp_comm is None:
+            self._validate_world_size()
             new_group = mpi_comm().group.Incl(self.mapping.cp_group)
             self._cp_comm = mpi_comm().Create_group(new_group)
         return self._cp_comm
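
A minimal sketch of how the new guard is hit, assuming the usual `Mapping(world_size=..., rank=..., tp_size=..., pp_size=...)` constructor and an `MPIDist(mapping)` instantiation (neither appears in this patch, so treat the details below as illustrative):

```python
# Illustrative only: the Mapping constructor arguments and the launch scenario
# are assumptions, not part of this patch.
from tensorrt_llm._torch.distributed.communicator import MPIDist
from tensorrt_llm.mapping import Mapping

# Suppose the job was launched with fewer MPI ranks than the mapping needs,
# e.g. `mpirun -n 2` while the mapping asks for tp_size=4.
mapping = Mapping(world_size=4, rank=0, tp_size=4, pp_size=1)
dist = MPIDist(mapping)

try:
    _ = dist.tp_comm  # _validate_world_size() now runs before group.Incl()
except RuntimeError as err:
    # Previously this could segfault inside MPI group creation; with the
    # patch it fails early with a readable message.
    print(err)
```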