[https://nvbugs/5791900][fix] Fix HelixCpMnnvlMemory init with PP

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
Balaram Buddharaju 2026-01-08 05:25:00 +00:00
parent c5d5af9e7f
commit 6f826eb867
4 changed files with 25 additions and 5 deletions

View File

@ -370,15 +370,32 @@ class HelixCpMnnvlMemory(MnnvlMemory):
if cls.comm is not None:
return cls.comm
comm = mpi_comm().Split(
mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
+ mapping.tp_rank * mapping.moe_tp_size
+ mapping.moe_tp_rank,
mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
mapping.cp_rank,
)
cls.comm = comm
return comm
def init_helix_cp_comm(mapping: Mapping) -> None:
    """Eagerly create the Helix CP communicator while all ranks are in lockstep.

    Creating the communicator relies on an MPI Split, which is a collective
    operation: every rank of the parent communicator has to take part in it.
    With pipeline parallelism (PP), the stages run different parts of the
    model at different times, so a lazy Split triggered by the first forward
    pass may not be reached by all ranks together and can deadlock. Call this
    function during model initialization, before PP stages diverge.

    Nothing is done unless Helix CP is enabled on ``mapping`` and its
    ``cp_config`` sets ``use_nccl_for_alltoall`` to a falsy value (the key is
    treated as ``True`` when absent).

    Args:
        mapping: The mapping object containing parallelism configuration.
    """
    # Guard 1: Helix context parallelism is not in use for this mapping.
    if not mapping.has_cp_helix():
        return
    # Guard 2: the NCCL all-to-all path is selected (also the default), so the
    # communicator is not pre-created here.
    if mapping.cp_config.get("use_nccl_for_alltoall", True):
        return
    # Result is cached on the class by get_comm; the return value is not needed.
    HelixCpMnnvlMemory.get_comm(mapping)
@dataclass
class MoEAlltoallInfo:
local_gather_indices: torch.Tensor

View File

@ -16,6 +16,7 @@ try:
except Exception:
MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
mpi_disabled, mpi_isend, mpi_isend_object,
mpi_recv, mpi_recv_object, mpi_send,
@ -888,6 +889,7 @@ def init_pp_comm(mapping):
_pp_comm = PPCommTorch(mapping)
else:
_pp_comm = PPCommNCCL(mapping)
init_helix_cp_comm(mapping)
@TorchDist.log_op

View File

@ -871,6 +871,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@skip_pre_blackwell
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
(2, 1, 2)],

View File

@ -71,7 +71,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
@ -101,7 +101,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)