diff --git a/tensorrt_llm/_mnnvl_utils.py b/tensorrt_llm/_mnnvl_utils.py
index 5d168447f9..2436f30c82 100644
--- a/tensorrt_llm/_mnnvl_utils.py
+++ b/tensorrt_llm/_mnnvl_utils.py
@@ -370,15 +370,32 @@ class HelixCpMnnvlMemory(MnnvlMemory):
         if cls.comm is not None:
             return cls.comm
         comm = mpi_comm().Split(
-            mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
-            + mapping.tp_rank * mapping.moe_tp_size
-            + mapping.moe_tp_rank,
+            mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
             mapping.cp_rank,
         )
         cls.comm = comm
         return comm
 
 
+def init_helix_cp_comm(mapping: Mapping) -> None:
+    """Pre-initialize the Helix CP communicator.
+
+    This function MUST be called during model initialization when all ranks
+    are synchronized (before any PP pipeline divergence). The MPI Split operation
+    is collective and requires all ranks in the communicator to participate.
+
+    In PP (pipeline parallel) mode, different PP stages execute different parts
+    of the model at different times. If the communicator is initialized lazily
+    during the first forward pass, ranks in different PP stages may not reach
+    the Split operation at the same time, causing a deadlock.
+
+    Args:
+        mapping: The mapping object containing parallelism configuration.
+    """
+    if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True):
+        HelixCpMnnvlMemory.get_comm(mapping)
+
+
 @dataclass
 class MoEAlltoallInfo:
     local_gather_indices: torch.Tensor
diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py
index 09bbc234ee..20401b5a24 100644
--- a/tensorrt_llm/_torch/distributed/communicator.py
+++ b/tensorrt_llm/_torch/distributed/communicator.py
@@ -16,6 +16,7 @@ try:
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
+from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -888,6 +889,7 @@ def init_pp_comm(mapping):
         _pp_comm = PPCommTorch(mapping)
     else:
         _pp_comm = PPCommNCCL(mapping)
+    init_helix_cp_comm(mapping)
 
 
 @TorchDist.log_op
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index 2ba2ee1bfe..14bb0cb811 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -871,6 +871,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
     @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2), (2, 1, 2)],
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index b7a31e57d9..a914a00c53 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -71,7 +71,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
@@ -101,7 +101,7 @@ l0_dgx_b200:
      backend: pytorch
      orchestrator: mpi
  tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
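Note (not part of the patch): a minimal sketch of what the new Split color/key computation in HelixCpMnnvlMemory.get_comm does, assuming mpi4py and a toy stand-in for the Mapping fields used above (pp_rank, tp_size, tp_rank, cp_rank). Ranks sharing a (pp_rank, tp_rank) position end up in one Helix CP communicator, ordered by cp_rank; because Split is collective over the parent communicator, every rank must reach it, which is why init_pp_comm now calls init_helix_cp_comm eagerly instead of relying on lazy initialization in the first forward pass.

# Illustrative sketch only; not part of the diff. Assumes mpi4py is installed
# and uses a hypothetical ToyMapping in place of tensorrt_llm's Mapping.
from dataclasses import dataclass

from mpi4py import MPI


@dataclass
class ToyMapping:
    pp_rank: int
    tp_size: int
    tp_rank: int
    cp_rank: int


def split_helix_cp_comm(mapping: ToyMapping) -> MPI.Comm:
    # Color: ranks with the same (pp_rank, tp_rank) position join the same
    # communicator, so each Helix CP group spans only the cp_rank dimension.
    color = mapping.pp_rank * mapping.tp_size + mapping.tp_rank
    # Key: orders ranks within the new communicator by cp_rank. Split is a
    # collective call, so every rank of the parent communicator must invoke it;
    # deferring it to the first forward pass risks a deadlock across PP stages.
    return MPI.COMM_WORLD.Split(color, mapping.cp_rank)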