commit 6f826eb867
parent c5d5af9e7f
Author: Balaram Buddharaju
Date:   2026-01-08 05:25:00 +00:00

    [https://nvbugs/5791900][fix] Fix HelixCpMnnvlMemory init with PP

    Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>

4 changed files with 25 additions and 5 deletions


@@ -370,15 +370,32 @@ class HelixCpMnnvlMemory(MnnvlMemory):
         if cls.comm is not None:
             return cls.comm
         comm = mpi_comm().Split(
-            mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
-            + mapping.tp_rank * mapping.moe_tp_size
-            + mapping.moe_tp_rank,
+            mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
             mapping.cp_rank,
         )
         cls.comm = comm
         return comm
+
+
+def init_helix_cp_comm(mapping: Mapping) -> None:
+    """Pre-initialize the Helix CP communicator.
+
+    This function MUST be called during model initialization when all ranks
+    are synchronized (before any PP pipeline divergence). The MPI Split operation
+    is collective and requires all ranks in the communicator to participate.
+
+    In PP (pipeline parallel) mode, different PP stages execute different parts
+    of the model at different times. If the communicator is initialized lazily
+    during the first forward pass, ranks in different PP stages may not reach
+    the Split operation at the same time, causing a deadlock.
+
+    Args:
+        mapping: The mapping object containing parallelism configuration.
+    """
+    if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True):
+        HelixCpMnnvlMemory.get_comm(mapping)
 
 
 @dataclass
 class MoEAlltoallInfo:
     local_gather_indices: torch.Tensor
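
In the fixed Split call, the color groups all ranks that share a (pp, tp) position and the key orders each group by cp_rank, so every resulting sub-communicator spans exactly the CP dimension. A minimal, MPI-free sketch of that grouping (the flat rank enumeration below is an assumption for illustration, not necessarily TRT-LLM's actual rank order):

from collections import defaultdict

# Simulates MPI_Comm_split: ranks sharing a color form one communicator,
# ordered by key. Uses the color/key formula from get_comm above.
pp_size, tp_size, cp_size = 2, 2, 2

groups = defaultdict(list)  # color -> [(key, world_rank)]
world_rank = 0
for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        for cp_rank in range(cp_size):
            color = pp_rank * tp_size + tp_rank  # same (pp, tp) -> same comm
            key = cp_rank                        # order within the comm
            groups[color].append((key, world_rank))
            world_rank += 1

# Each color yields one Helix CP communicator: fixed (pp, tp), all cp ranks.
for color in sorted(groups):
    print(f"color {color}: world ranks {[r for _, r in sorted(groups[color])]}")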


@@ -16,6 +16,7 @@ try:
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
+from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -888,6 +889,7 @@ def init_pp_comm(mapping):
         _pp_comm = PPCommTorch(mapping)
     else:
         _pp_comm = PPCommNCCL(mapping)
+    init_helix_cp_comm(mapping)
 
 
 @TorchDist.log_op
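
Hooking the pre-initialization into init_pp_comm works because Comm.Split is collective: every rank of the parent communicator must reach the call before any of them returns. A rough mpi4py illustration of that property (not part of this change; run with something like mpirun -n 4 python split_demo.py):

from mpi4py import MPI

# Split is collective over COMM_WORLD here. If one rank were parked in a
# blocking pipeline recv instead of calling Split, every rank would hang --
# the deadlock that calling init_helix_cp_comm from init_pp_comm avoids.
world = MPI.COMM_WORLD
sub = world.Split(color=world.Get_rank() % 2, key=world.Get_rank())
print(f"world rank {world.Get_rank()} -> sub rank {sub.Get_rank()} of {sub.Get_size()}")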


@@ -871,6 +871,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
     @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
                                                       (2, 1, 2)],
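
All three generation-side layouts multiply out to 4 ranks, which is consistent with the 8-GPU guard if the context server occupies the remaining GPUs (an inference from the test setup, not stated in the diff):

# Quick arithmetic check of the parametrized layouts above.
for gen_pp, gen_tp, gen_cp in [(1, 1, 4), (1, 2, 2), (2, 1, 2)]:
    gen_ranks = gen_pp * gen_tp * gen_cp
    assert gen_ranks == 4  # every combination uses 4 generation ranks
    print(f"pp{gen_pp}tp{gen_tp}cp{gen_cp} -> {gen_ranks} generation ranks")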


@@ -71,7 +71,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
@@ -101,7 +101,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)