[None][fix] Better error message for mismatched MPI world size (#11294)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
This commit is contained in:
jthomson04 2026-02-16 15:37:49 -08:00 committed by GitHub
parent cc4511997a
commit 2450188808
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -20,7 +20,8 @@ from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
mpi_disabled, mpi_isend, mpi_isend_object,
mpi_recv, mpi_recv_object, mpi_send,
mpi_send_object, torch_pybind11_abi)
mpi_send_object, mpi_world_size,
torch_pybind11_abi)
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.bindings.internal.process_group import init_pg
from tensorrt_llm.logger import logger
@ -456,6 +457,19 @@ class MPIDist(Distributed):
self._tp_comm = None
self._pp_comm = None
def _validate_world_size(self):
"""Validate world size before creating sub-communicators to prevent segfaults."""
if ENABLE_MULTI_DEVICE:
actual_world_size = mpi_world_size()
max_rank_needed = self.mapping.world_size
if max_rank_needed > actual_world_size:
raise RuntimeError(
f"Mapping requires world_size={max_rank_needed} "
f"(tp_size={self.mapping.tp_size} * pp_size={self.mapping.pp_size} * cp_size={self.mapping.cp_size}), "
f"but MPI world size is only {actual_world_size}. ")
def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = mpi_comm()
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
@ -493,6 +507,7 @@ class MPIDist(Distributed):
@property
def tp_comm(self):
if self._tp_comm is None:
self._validate_world_size()
mapping = self.mapping
new_group = mpi_comm().group.Incl(mapping.tp_group)
self._tp_comm = mpi_comm().Create_group(new_group)
@ -501,6 +516,7 @@ class MPIDist(Distributed):
@property
def pp_comm(self):
if self._pp_comm is None:
self._validate_world_size()
mapping = self.mapping
new_group = mpi_comm().group.Incl(mapping.pp_group)
self._pp_comm = mpi_comm().Create_group(new_group)
@ -509,6 +525,7 @@ class MPIDist(Distributed):
@property
def cp_comm(self):
if self._cp_comm is None:
self._validate_world_size()
new_group = mpi_comm().group.Incl(self.mapping.cp_group)
self._cp_comm = mpi_comm().Create_group(new_group)
return self._cp_comm