mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5791900][fix] Fix HelixCpMnnvlMemory init with PP
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
parent
c5d5af9e7f
commit
6f826eb867
@ -370,15 +370,32 @@ class HelixCpMnnvlMemory(MnnvlMemory):
|
||||
if cls.comm is not None:
|
||||
return cls.comm
|
||||
comm = mpi_comm().Split(
|
||||
mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
|
||||
+ mapping.tp_rank * mapping.moe_tp_size
|
||||
+ mapping.moe_tp_rank,
|
||||
mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
|
||||
mapping.cp_rank,
|
||||
)
|
||||
cls.comm = comm
|
||||
return comm
|
||||
|
||||
|
||||
def init_helix_cp_comm(mapping: Mapping) -> None:
    """Eagerly create the Helix CP communicator while all ranks are in lockstep.

    Creating the communicator goes through an MPI ``Split``, which is a
    collective call: every rank of the parent communicator has to enter it.
    With pipeline parallelism (PP), stages run different parts of the model
    at different times, so deferring the split to the first forward pass can
    leave some ranks waiting for others that have not arrived yet, i.e. a
    deadlock. Calling this function during model initialization, before any
    PP divergence, avoids that.

    The communicator is only created when Helix CP is enabled on ``mapping``
    and the ``use_nccl_for_alltoall`` entry of ``mapping.cp_config`` is set
    to a falsy value (the entry is treated as ``True`` when absent).
    Otherwise this function does nothing.

    Args:
        mapping: The mapping object containing parallelism configuration.
    """
    # Checked first so that cp_config is not consulted when Helix CP is off.
    if not mapping.has_cp_helix():
        return

    # A missing key counts as True, so the NCCL all-to-all path is the default
    # and no communicator is set up for it here.
    if mapping.cp_config.get("use_nccl_for_alltoall", True):
        return

    # The returned communicator is not needed here; the call is made purely
    # to trigger its creation ahead of time.
    HelixCpMnnvlMemory.get_comm(mapping)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoEAlltoallInfo:
|
||||
local_gather_indices: torch.Tensor
|
||||
|
||||
@ -16,6 +16,7 @@ try:
|
||||
except Exception:
|
||||
MPI = None # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
|
||||
|
||||
from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
|
||||
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
|
||||
mpi_disabled, mpi_isend, mpi_isend_object,
|
||||
mpi_recv, mpi_recv_object, mpi_send,
|
||||
@ -888,6 +889,7 @@ def init_pp_comm(mapping):
|
||||
_pp_comm = PPCommTorch(mapping)
|
||||
else:
|
||||
_pp_comm = PPCommNCCL(mapping)
|
||||
init_helix_cp_comm(mapping)
|
||||
|
||||
|
||||
@TorchDist.log_op
|
||||
|
||||
@ -871,6 +871,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
|
||||
(2, 1, 2)],
|
||||
|
||||
@ -71,7 +71,7 @@ l0_dgx_b200:
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
|
||||
@ -101,7 +101,7 @@ l0_dgx_b200:
|
||||
backend: pytorch
|
||||
orchestrator: mpi
|
||||
tests:
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
|
||||
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
|
||||
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user