[https://nvbugs/5791900][fix] Fix HelixCpMnnvlMemory init with PP

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>

parent c5d5af9e7f
commit 6f826eb867

@@ -370,15 +370,32 @@ class HelixCpMnnvlMemory(MnnvlMemory):
         if cls.comm is not None:
             return cls.comm
         comm = mpi_comm().Split(
-            mapping.pp_rank * mapping.tp_size * mapping.moe_tp_size
-            + mapping.tp_rank * mapping.moe_tp_size
-            + mapping.moe_tp_rank,
+            mapping.pp_rank * mapping.tp_size + mapping.tp_rank,
             mapping.cp_rank,
         )
         cls.comm = comm
         return comm
 
 
+def init_helix_cp_comm(mapping: Mapping) -> None:
+    """Pre-initialize the Helix CP communicator.
+
+    This function MUST be called during model initialization when all ranks
+    are synchronized (before any PP pipeline divergence). The MPI Split operation
+    is collective and requires all ranks in the communicator to participate.
+
+    In PP (pipeline parallel) mode, different PP stages execute different parts
+    of the model at different times. If the communicator is initialized lazily
+    during the first forward pass, ranks in different PP stages may not reach
+    the Split operation at the same time, causing a deadlock.
+
+    Args:
+        mapping: The mapping object containing parallelism configuration.
+    """
+    if mapping.has_cp_helix() and not mapping.cp_config.get("use_nccl_for_alltoall", True):
+        HelixCpMnnvlMemory.get_comm(mapping)
+
+
 @dataclass
 class MoEAlltoallInfo:
     local_gather_indices: torch.Tensor
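
A minimal sketch, not part of the commit, of how the new Split(color, key) arguments group ranks: ranks with equal color land in the same sub-communicator and are ordered by key, so the new color of pp_rank * tp_size + tp_rank gives every (pp, tp) coordinate its own Helix CP group spanning the cp ranks. The pp=2, tp=1, cp=2 layout below is a hypothetical example chosen to mirror the pp2tp1cp2 test configuration; no MPI is needed to see the grouping.

# Minimal sketch (not from the commit): enumerate hypothetical ranks for a
# pp=2, tp=1, cp=2 layout and show which Helix CP sub-communicator each one
# would land in under the new Split(color, key) arguments.
from itertools import product

pp_size, tp_size, cp_size = 2, 1, 2  # hypothetical layout matching pp2tp1cp2

for pp_rank, tp_rank, cp_rank in product(range(pp_size), range(tp_size),
                                         range(cp_size)):
    color = pp_rank * tp_size + tp_rank  # one group per (pp_rank, tp_rank)
    key = cp_rank                        # rank order inside the group
    print(f"pp={pp_rank} tp={tp_rank} cp={cp_rank} -> color={color} key={key}")

# Ranks that share a (pp_rank, tp_rank) coordinate get the same color, so each
# Helix CP communicator spans exactly the cp_size ranks of one pipeline stage
# and TP slice.
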
@@ -16,6 +16,7 @@ try:
 except Exception:
     MPI = None  # deferred; functions will error if used when ENABLE_MULTI_DEVICE is True
 
+from tensorrt_llm._mnnvl_utils import init_helix_cp_comm
 from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_comm,
                                  mpi_disabled, mpi_isend, mpi_isend_object,
                                  mpi_recv, mpi_recv_object, mpi_send,
@@ -888,6 +889,7 @@ def init_pp_comm(mapping):
         _pp_comm = PPCommTorch(mapping)
     else:
         _pp_comm = PPCommNCCL(mapping)
+    init_helix_cp_comm(mapping)
 
 
 @TorchDist.log_op
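
The docstring's deadlock argument can be seen in a small standalone mpi4py sketch (hypothetical, not part of the commit): Split is collective over the parent communicator, so it only completes once every rank has called it. Wiring init_helix_cp_comm into init_pp_comm keeps the call at a point where all ranks still execute the same code; deferring it into stage-specific forward code would leave some ranks never reaching it.

# Hypothetical mpi4py sketch (not from the commit) of the collective-Split
# ordering that init_helix_cp_comm relies on. Run under mpirun with >= 2 ranks.
from mpi4py import MPI

world = MPI.COMM_WORLD

# Eager, synchronized path: every rank reaches Split at the same program point,
# before any pipeline-stage-specific work begins, so the collective completes.
cp_group = world.Split(color=world.rank // 2, key=world.rank % 2)

# If this Split were instead issued lazily inside the forward pass of one
# pipeline stage, ranks belonging to other stages would never call it, and the
# ranks that did call it would block inside Split forever.
print(f"rank {world.rank}: joined a sub-communicator of size {cp_group.size}")
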
@@ -871,6 +871,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @skip_pre_blackwell
     @pytest.mark.skip_less_device(8)
     @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
                                                        (2, 1, 2)],
@@ -71,7 +71,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
@@ -101,7 +101,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)