mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-18 08:45:05 +08:00
[None][fix] Better error message for mismatched MPI world size (#11294)
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
This commit is contained in:
parent
cc4511997a
commit
2450188808
@ -20,7 +20,8 @@ from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
|
||||
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
|
||||
mpi_disabled, mpi_isend, mpi_isend_object,
|
||||
mpi_recv, mpi_recv_object, mpi_send,
|
||||
mpi_send_object, torch_pybind11_abi)
|
||||
mpi_send_object, mpi_world_size,
|
||||
torch_pybind11_abi)
|
||||
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
|
||||
from tensorrt_llm.bindings.internal.process_group import init_pg
|
||||
from tensorrt_llm.logger import logger
|
||||
@ -456,6 +457,19 @@ class MPIDist(Distributed):
|
||||
self._tp_comm = None
|
||||
self._pp_comm = None
|
||||
|
||||
def _validate_world_size(self):
|
||||
"""Validate world size before creating sub-communicators to prevent segfaults."""
|
||||
|
||||
if ENABLE_MULTI_DEVICE:
|
||||
actual_world_size = mpi_world_size()
|
||||
max_rank_needed = self.mapping.world_size
|
||||
|
||||
if max_rank_needed > actual_world_size:
|
||||
raise RuntimeError(
|
||||
f"Mapping requires world_size={max_rank_needed} "
|
||||
f"(tp_size={self.mapping.tp_size} * pp_size={self.mapping.pp_size} * cp_size={self.mapping.cp_size}), "
|
||||
f"but MPI world size is only {actual_world_size}. ")
|
||||
|
||||
def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
|
||||
comm = mpi_comm()
|
||||
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
|
||||
@ -493,6 +507,7 @@ class MPIDist(Distributed):
|
||||
@property
|
||||
def tp_comm(self):
|
||||
if self._tp_comm is None:
|
||||
self._validate_world_size()
|
||||
mapping = self.mapping
|
||||
new_group = mpi_comm().group.Incl(mapping.tp_group)
|
||||
self._tp_comm = mpi_comm().Create_group(new_group)
|
||||
@ -501,6 +516,7 @@ class MPIDist(Distributed):
|
||||
@property
|
||||
def pp_comm(self):
|
||||
if self._pp_comm is None:
|
||||
self._validate_world_size()
|
||||
mapping = self.mapping
|
||||
new_group = mpi_comm().group.Incl(mapping.pp_group)
|
||||
self._pp_comm = mpi_comm().Create_group(new_group)
|
||||
@ -509,6 +525,7 @@ class MPIDist(Distributed):
|
||||
@property
|
||||
def cp_comm(self):
|
||||
if self._cp_comm is None:
|
||||
self._validate_world_size()
|
||||
new_group = mpi_comm().group.Incl(self.mapping.cp_group)
|
||||
self._cp_comm = mpi_comm().Create_group(new_group)
|
||||
return self._cp_comm
|
||||
|
||||
Loading…
Reference in New Issue
Block a user