[TRTLLM-9467][fix] Fix PP+CP combination with helix parallelism (#10312)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
parent 5845951538
commit 4a1b742aa0
@@ -1506,7 +1506,7 @@ class AutoTuner:
         """Broadcast tactics from root rank to all other ranks."""
         cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
         root = 0
-        cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
+        cache_data = self._dist.tp_cp_broadcast(obj=cache_data, root=root)

         self.profiling_cache.merge_cache_data(cache_data)

@@ -116,6 +116,26 @@ class Distributed(ABC):
     def allgather(self, obj, root=0):
         pass

+    @abstractmethod
+    def tp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    @abstractmethod
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        pass
+
+    def tp_cp_broadcast(self, obj, root=0, **kwargs):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, **kwargs)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, **kwargs)
+        return obj
+
+
 def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
     """
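Why the staged order works with helix parallelism: after the TP stage, the object lives on every rank of the root's TP group; the CP stage then fans it out along each CP group, so all tp_size * cp_size ranks converge. Without the CP stage (the old tp_broadcast call in AutoTuner above), the CP peers of the root's TP group would never receive the profiling cache. A minimal standalone sketch of that data-availability propagation, assuming a hypothetical rank layout with TP as the inner dimension (the real group construction lives in tensorrt_llm.mapping.Mapping):

    # Model which ranks hold the object after each stage of tp_cp_broadcast
    # on a 2x2 (tp x cp) grid; root is global rank 0.
    tp_size, cp_size = 2, 2
    have_obj = {rank: rank == 0 for rank in range(tp_size * cp_size)}

    # Stage 1: tp_broadcast within each TP group (ranks sharing a cp index).
    # Only the root's TP group contains a source, so ranks {0, 1} get the object.
    for cp in range(cp_size):
        group = [cp * tp_size + tp for tp in range(tp_size)]
        if any(have_obj[r] for r in group):
            for r in group:
                have_obj[r] = True

    # Stage 2: cp_broadcast within each CP group (ranks sharing a tp index).
    # Each CP group now contains exactly one rank that got the object in stage 1.
    for tp in range(tp_size):
        group = [cp * tp_size + tp for cp in range(cp_size)]
        if any(have_obj[r] for r in group):
            for r in group:
                have_obj[r] = True

    assert all(have_obj.values())  # all four ranks hold the broadcast object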
@@ -407,6 +427,14 @@ class MPIDist(Distributed):
     def cp_allgather(self, obj):
         return self.cp_comm.allgather(obj)

+    def cp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
+        comm = self.cp_comm
+        return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
@@ -414,7 +442,11 @@ class MPIDist(Distributed):
         comm = self.tp_comm
         return safe_gather(comm, obj, root=root, chunk_size=chunk_size)

-    def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+    def tp_broadcast(self,
+                     obj,
+                     root=0,
+                     chunk_size: int = 4 * 1024 * 1024,
+                     **kwargs):
         comm = self.tp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
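Both MPI overrides delegate to safe_broadcast with a 4 MiB default chunk_size; its body is outside this diff. For intuition only, a chunked, pickle-based broadcast along these lines (an illustrative sketch under that assumption, not the library's implementation; a real version would likely use buffer-based Bcast to avoid pickling twice):

    import pickle

    def chunked_broadcast_sketch(comm, obj, root=0, chunk_size=4 * 1024 * 1024):
        """Illustrative only: broadcast a picklable object in fixed-size chunks
        so that very large payloads never exceed per-message size limits."""
        rank = comm.Get_rank()
        if rank == root:
            data = pickle.dumps(obj)
            chunks = [data[i:i + chunk_size]
                      for i in range(0, len(data), chunk_size)]
        else:
            chunks = None
        num_chunks = comm.bcast(len(chunks) if chunks is not None else 0, root=root)
        received = []
        for i in range(num_chunks):
            piece = chunks[i] if rank == root else None
            received.append(comm.bcast(piece, root=root))
        return obj if rank == root else pickle.loads(b"".join(received))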
@@ -699,7 +731,7 @@ class TorchDist(Distributed):
         return output_list

     @log_op
-    def tp_broadcast(self, obj, root=0):
+    def tp_broadcast(self, obj, root=0, **kwargs):
         if isinstance(obj, torch.Tensor):
             dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
             return obj
@@ -712,6 +744,20 @@ class TorchDist(Distributed):
                 device=torch.device("cpu"))
             return ret[0]

+    @log_op
+    def cp_broadcast(self, obj, root=0, **kwargs):
+        if isinstance(obj, torch.Tensor):
+            dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
+            return obj
+        else:
+            ret = [obj]
+            torch.distributed.broadcast_object_list(
+                ret,
+                src=root,
+                group=self.mapping.cp_group_pg,
+                device=torch.device("cpu"))
+            return ret[0]
+
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):
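A usage sketch for the two TorchDist paths above (hypothetical variable names; assumes torch.distributed is initialized on every rank and the Mapping has cp_size > 1): tensors are broadcast in place on the device through the CP process group, while arbitrary picklable objects round-trip through broadcast_object_list on CPU.

    # Tensor path: every CP rank passes a tensor of matching shape/dtype and
    # the root's contents overwrite the others in place.
    buf = torch.empty(16, 128, device="cuda")
    if dist_impl.mapping.cp_rank == 0:
        buf.normal_()  # root provides the payload
    buf = dist_impl.cp_broadcast(buf, root=0)

    # Object path: non-root ranks may pass None; the root's pickled object is
    # returned on every CP rank.
    payload = {"requests": [1, 2, 3]} if dist_impl.mapping.cp_rank == 0 else None
    payload = dist_impl.cp_broadcast(payload, root=0)
    assert payload["requests"] == [1, 2, 3]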
@@ -587,9 +587,10 @@ class ExecutorRequestQueue:
         if not self.dist.has_pp:
             return self.dist.broadcast(payloads, root=0)

-        # Broadcast within first tp group before send/recv chain to other tp groups
-        if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            payloads = self.dist.tp_broadcast(payloads, root=0)
+        # Broadcast within first PP stage before send/recv chain to other PP stages.
+        # This needs to cover both TP and CP ranks within the first PP stage.
+        if self.dist.is_first_pp_rank:
+            payloads = self.dist.tp_cp_broadcast(payloads, root=0)

         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts
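The payload distribution as a whole is: fan out across all TP and CP ranks of PP stage 0, then relay stage by stage over point-to-point send/recv. A condensed sketch of that control flow (recv_object/send_object and the pp-rank attributes are hypothetical stand-ins for the queue's actual helpers):

    def distribute_payloads_sketch(dist, payloads):
        if not dist.has_pp:
            return dist.broadcast(payloads, root=0)

        # Every TP and CP rank of the first PP stage must hold the payloads
        # before the relay starts; with helix parallelism a TP-only broadcast
        # would leave stage 0's CP peers empty-handed.
        if dist.is_first_pp_rank:
            payloads = dist.tp_cp_broadcast(payloads, root=0)

        # Relay along the pipeline: receive from the previous stage, forward
        # to the next one, using pp_size as the message tag.
        if not dist.is_first_pp_rank:
            payloads = dist.recv_object(src=dist.prev_pp_rank, tag=dist.pp_size)
        if not dist.is_last_pp_rank:
            dist.send_object(payloads, dest=dist.next_pp_rank, tag=dist.pp_size)
        return payloads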
@@ -871,7 +871,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
+                             ids=["pp1tp2cp2", "pp2tp1cp2"])
     @pytest.mark.parametrize("cuda_graph_config", [
         None,
         {
@@ -888,8 +890,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         "cudagraph:with_padding"
     ])
     @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
-    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
+    def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
+                                   gen_pp, gen_tp, gen_cp):
         use_nccl_for_alltoall = comms_medium == "nccl"
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -898,20 +902,22 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
             "enable_chunked_prefill": False,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32,
@@ -922,7 +928,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             "enable_chunked_prefill": False,
             "cuda_graph_config": cuda_graph_config,
             "cache_transceiver_config": {
-                "backend": "UCX"
+                "backend": "UCX",
+                "max_tokens_in_buffer": 8192,
             },
         }
         disaggregated_server_config = {
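The GPU budget follows directly from these configs: the context server needs pp * tp * cp = 1 * 4 * 1 = 4 GPUs, and either generation variant needs 1 * 2 * 2 = 4 or 2 * 1 * 2 = 4, for 8 in total, which is what the skip_less_device(8) marker above enforces. A quick sanity check:

    # GPU budget implied by the parametrization above.
    for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
        ctx_gpus = 1 * 4 * 1                # ctx server: pp1, tp4, cp1
        gen_gpus = gen_pp * gen_tp * gen_cp
        gen_ep = gen_tp * gen_cp            # moe_expert_parallel_size: 4 or 2
        assert ctx_gpus + gen_gpus == 8     # matches skip_less_device(8)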
@@ -540,14 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
@@ -66,6 +66,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
       backend: pytorch
       orchestrator: mpi
   tests:
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
   - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
 - condition:
     ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
       stage: post_merge
       backend: pytorch
   tests:
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
-  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
@@ -523,6 +523,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_sof
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] SKIP (https://nvbugs/5774869)
 triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https://nvbugs/5775205)
 triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223)
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False] SKIP (https://nvbugs/5777044)
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] SKIP (https://nvbugs/5777044)
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5740377)
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256)
 unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041)
tests/unittest/_torch/distributed/test_cp_broadcast.py (new file, 265 lines)
@@ -0,0 +1,265 @@
"""
|
||||
Tests for cp_broadcast functionality in both MPIDist and TorchDist.
|
||||
|
||||
This module tests the context parallelism broadcast operation which is used
|
||||
when CP (context parallelism) is enabled (e.g., in Helix parallelism).
|
||||
|
||||
For MPIDist tests, run with mpirun:
|
||||
mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v
|
||||
|
||||
For TorchDist tests, see test_ops.py which uses Ray for distributed testing.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm._torch.distributed import MPIDist
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
|
||||
def get_mpi_info():
|
||||
"""Get MPI rank and world size, returns (0, 1) if MPI is not available."""
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
return comm.Get_rank(), comm.Get_size()
|
||||
except ImportError:
|
||||
return 0, 1
|
||||
|
||||
|
||||
def skip_if_not_mpi():
|
||||
"""Skip test if not running under MPI with sufficient ranks."""
|
||||
rank, world_size = get_mpi_info()
|
||||
if world_size < 2:
|
||||
pytest.skip("Test requires at least 2 MPI ranks (run with mpirun -n 2)")
|
||||
|
||||
|
||||
class TestMPIDistCpBroadcast:
|
||||
"""Tests for MPIDist.cp_broadcast functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
"""Set up MPI environment and mapping for each test."""
|
||||
skip_if_not_mpi()
|
||||
self.rank, self.world_size = get_mpi_info()
|
||||
|
||||
# Set up mapping with CP enabled (cp_size = world_size, tp_size = 1)
|
||||
self.mapping = Mapping(
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
tp_size=1,
|
||||
cp_size=self.world_size,
|
||||
pp_size=1,
|
||||
)
|
||||
self.dist = MPIDist(mapping=self.mapping)
|
||||
|
||||
def test_broadcast_numpy_array(self):
|
||||
"""Test broadcasting a numpy array via cp_broadcast."""
|
||||
root = 0
|
||||
shape = (64, 128)
|
||||
|
||||
if self.mapping.cp_rank == root:
|
||||
# Root rank creates the data to broadcast
|
||||
data = np.random.randn(*shape).astype(np.float32)
|
||||
else:
|
||||
# Non-root ranks have empty/zero data
|
||||
data = np.zeros(shape, dtype=np.float32)
|
||||
|
||||
# Store original data from root for verification
|
||||
from mpi4py import MPI
|
||||
|
||||
expected = np.zeros(shape, dtype=np.float32)
|
||||
MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected, root=root)
|
||||
if self.mapping.cp_rank == root:
|
||||
expected = data.copy()
|
||||
|
||||
# Perform cp_broadcast
|
||||
result = self.dist.cp_broadcast(data, root=root)
|
||||
|
||||
# Verify all ranks have the same data
|
||||
np.testing.assert_array_almost_equal(result, expected)
|
||||
|
||||
def test_broadcast_python_dict(self):
|
||||
"""Test broadcasting a Python dictionary via cp_broadcast."""
|
||||
root = 0
|
||||
|
||||
if self.mapping.cp_rank == root:
|
||||
obj = {
|
||||
"model_name": "llama",
|
||||
"batch_size": 32,
|
||||
"tokens": [1, 2, 3, 4, 5],
|
||||
"config": {"hidden_size": 4096, "num_layers": 32},
|
||||
}
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.cp_broadcast(obj, root=root)
|
||||
|
||||
# Verify all ranks received the correct object
|
||||
assert result["model_name"] == "llama"
|
||||
assert result["batch_size"] == 32
|
||||
assert result["tokens"] == [1, 2, 3, 4, 5]
|
||||
assert result["config"]["hidden_size"] == 4096
|
||||
assert result["config"]["num_layers"] == 32
|
||||
|
||||
def test_broadcast_python_list(self):
|
||||
"""Test broadcasting a Python list via cp_broadcast."""
|
||||
root = 0
|
||||
|
||||
if self.mapping.cp_rank == root:
|
||||
obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.cp_broadcast(obj, root=root)
|
||||
|
||||
assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
|
||||
|
||||
def test_broadcast_from_non_zero_root(self):
|
||||
"""Test broadcasting from a non-zero root rank."""
|
||||
if self.world_size < 2:
|
||||
pytest.skip("Need at least 2 ranks to test non-zero root")
|
||||
|
||||
root = 1 # Broadcast from rank 1
|
||||
|
||||
if self.mapping.cp_rank == root:
|
||||
obj = {"source": "rank1", "value": 42}
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.cp_broadcast(obj, root=root)
|
||||
|
||||
assert result["source"] == "rank1"
|
||||
assert result["value"] == 42
|
||||
|
||||
def test_broadcast_large_object(self):
|
||||
"""Test broadcasting a large object that may require chunking."""
|
||||
root = 0
|
||||
# Create a large list to test chunking behavior
|
||||
large_size = 100000
|
||||
|
||||
if self.mapping.cp_rank == root:
|
||||
obj = list(range(large_size))
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.cp_broadcast(obj, root=root)
|
||||
|
||||
assert len(result) == large_size
|
||||
assert result[0] == 0
|
||||
assert result[-1] == large_size - 1
|
||||
|
||||
def test_broadcast_string(self):
|
||||
"""Test broadcasting a simple string via cp_broadcast."""
|
||||
root = 0
|
||||
|
||||
if self.mapping.cp_rank == root:
|
||||
obj = "Hello from root rank!"
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.cp_broadcast(obj, root=root)
|
||||
|
||||
assert result == "Hello from root rank!"
|
||||
|
||||
|
||||
# Additional integration-style test that can be run standalone
|
||||
def test_mpi_cp_broadcast_integration():
|
||||
"""
|
||||
Integration test for MPIDist cp_broadcast.
|
||||
"""
|
||||
rank, world_size = get_mpi_info()
|
||||
if world_size < 2:
|
||||
pytest.skip("Test requires at least 2 MPI ranks")
|
||||
|
||||
# Create mapping with CP enabled
|
||||
mapping = Mapping(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
tp_size=1,
|
||||
cp_size=world_size,
|
||||
pp_size=1,
|
||||
)
|
||||
dist = MPIDist(mapping=mapping)
|
||||
|
||||
# Test 1: Broadcast dict
|
||||
if mapping.cp_rank == 0:
|
||||
payload = {"requests": [{"id": i} for i in range(10)]}
|
||||
else:
|
||||
payload = None
|
||||
|
||||
result = dist.cp_broadcast(payload, root=0)
|
||||
assert len(result["requests"]) == 10
|
||||
assert result["requests"][0]["id"] == 0
|
||||
|
||||
# Test 2: Broadcast numpy array
|
||||
shape = (32, 64)
|
||||
if mapping.cp_rank == 0:
|
||||
arr = np.ones(shape, dtype=np.float32) * (rank + 1)
|
||||
else:
|
||||
arr = np.zeros(shape, dtype=np.float32)
|
||||
|
||||
result = dist.cp_broadcast(arr, root=0)
|
||||
expected_val = 1.0 # From rank 0
|
||||
np.testing.assert_array_almost_equal(result, np.ones(shape) * expected_val)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running directly with mpirun
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
|
||||
class TestMPIDistTpCpBroadcast:
|
||||
"""Tests for MPIDist.tp_cp_broadcast functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
"""Set up MPI environment and mapping for each test."""
|
||||
skip_if_not_mpi()
|
||||
self.rank, self.world_size = get_mpi_info()
|
||||
|
||||
# Set up mapping with both TP and CP enabled
|
||||
# For 2 ranks: tp_size=1, cp_size=2 (tp_cp_broadcast will only do cp_broadcast)
|
||||
self.mapping = Mapping(
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
tp_size=1,
|
||||
cp_size=self.world_size,
|
||||
pp_size=1,
|
||||
)
|
||||
self.dist = MPIDist(mapping=self.mapping)
|
||||
|
||||
def test_tp_cp_broadcast_python_dict(self):
|
||||
"""Test broadcasting a Python dictionary via tp_cp_broadcast."""
|
||||
root = 0
|
||||
|
||||
# Only rank 0 in both TP and CP groups should have the object
|
||||
if self.mapping.tp_rank == root and self.mapping.cp_rank == root:
|
||||
obj = {
|
||||
"model_name": "llama",
|
||||
"batch_size": 32,
|
||||
"tokens": [1, 2, 3, 4, 5],
|
||||
}
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.tp_cp_broadcast(obj, root=root)
|
||||
|
||||
# Verify all ranks received the correct object
|
||||
assert result["model_name"] == "llama"
|
||||
assert result["batch_size"] == 32
|
||||
assert result["tokens"] == [1, 2, 3, 4, 5]
|
||||
|
||||
def test_tp_cp_broadcast_python_list(self):
|
||||
"""Test broadcasting a Python list via tp_cp_broadcast."""
|
||||
root = 0
|
||||
|
||||
if self.mapping.tp_rank == root and self.mapping.cp_rank == root:
|
||||
obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
|
||||
else:
|
||||
obj = None
|
||||
|
||||
result = self.dist.tp_cp_broadcast(obj, root=root)
|
||||
|
||||
assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
|
||||
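As the setup comment notes, with tp_size=1 the TP stage of tp_cp_broadcast is a no-op, so these two-rank tests exercise the dispatch logic plus a single cp_broadcast. A minimal illustration of that degenerate path:

    # Which stages of tp_cp_broadcast run when tp_size=1, cp_size=2.
    tp_size, cp_size = 1, 2
    stages = []
    if tp_size > 1:
        stages.append("tp_broadcast")
    if cp_size > 1:
        stages.append("cp_broadcast")
    assert stages == ["cp_broadcast"]  # only the CP stage executes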
@@ -258,3 +258,192 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size):
(the three context lines below are unchanged; everything after them is newly added)
    ])
    for r in results:
        assert r is True


@ray.remote(num_gpus=1)
class CpBroadcastTest:
    """Test worker for cp_broadcast operations with context parallelism."""

    def __init__(self, rank, world_size, tp_size, cp_size):
        self.rank = rank
        self.world_size = world_size
        self.tp_size = tp_size
        self.cp_size = cp_size
        self.master_address = os.environ["MASTER_ADDR"]

        assert len(ray.get_gpu_ids()) == 1
        self.gpu = int(ray.get_gpu_ids()[0])
        from tensorrt_llm.executor.ray_gpu_worker import RayWorkerWrapper
        local_gpu = RayWorkerWrapper.physical_to_local_id(self.gpu)
        torch.cuda.set_device(local_gpu)

    def _create_tcp_store(self,
                          port: Optional[int] = None
                          ) -> torch.distributed.TCPStore:
        actual_port = port if port is not None else 0
        return torch.distributed.TCPStore(host_name=self.master_address,
                                          port=actual_port,
                                          world_size=self.world_size,
                                          is_master=(self.rank == 0),
                                          wait_for_workers=False)

    def setup_tcp_store(self):
        if self.rank != 0:
            raise RuntimeError("Only the master worker can setup TCP store")
        self.store = self._create_tcp_store()
        return self.store.port

    def setup_distributed_env(self, port: int):
        if self.rank != 0:
            self.store = self._create_tcp_store(port)

        torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
                                             store=self.store,
                                             world_size=self.world_size,
                                             rank=self.rank)
        self.mapping = Mapping(world_size=self.world_size,
                               gpus_per_node=self.world_size,
                               tp_size=self.tp_size,
                               cp_size=self.cp_size,
                               rank=self.rank)
        self.dist = TorchDist(self.mapping)

    def run_tensor_broadcast(self, root_tensor: torch.Tensor, root: int = 0):
        """Test broadcasting a tensor via cp_broadcast."""
        cp_rank = self.mapping.cp_rank
        if cp_rank == root:
            # Root rank has the tensor to broadcast.
            tensor = root_tensor.cuda()
        else:
            # Non-root ranks start with zeros.
            tensor = torch.zeros_like(root_tensor).cuda()

        result = self.dist.cp_broadcast(tensor, root=root)

        # After broadcast, all CP ranks should have the same tensor.
        expected = root_tensor.cuda()
        return torch.allclose(result, expected)

    def run_object_broadcast(self, root_obj, root: int = 0):
        """Test broadcasting a non-tensor object via cp_broadcast."""
        cp_rank = self.mapping.cp_rank
        if cp_rank == root:
            obj = root_obj
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        # After broadcast, all CP ranks should have the same object.
        return result == root_obj

    def run_tp_cp_broadcast(self, root_obj, root: int = 0):
        """Test broadcasting an object via tp_cp_broadcast."""
        # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object.
        tp_rank = self.mapping.tp_rank
        cp_rank = self.mapping.cp_rank
        if tp_rank == root and cp_rank == root:
            obj = root_obj
        else:
            obj = None

        result = self.dist.tp_cp_broadcast(obj, root=root)

        # After broadcast, all TP and CP ranks should have the same object.
        return result == root_obj


@pytest.mark.gpu2
@pytest.mark.parametrize("hidden_size", [128, 512], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 32], ids=lambda x: f"seqlen:{x}")
def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
    """Test TorchDist.cp_broadcast with tensor data."""
    torch.manual_seed(42)
    dtype = torch.bfloat16
    world_size = 2
    tp_size = 1
    cp_size = 2  # Enable context parallelism.

    # Create tensor to broadcast from root.
    root_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)

    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
    })

    remote_tests = []
    for rank in range(world_size):
        remote_tests.append(
            CpBroadcastTest.options(runtime_env=runtime_env).remote(
                rank, world_size, tp_size, cp_size))

    ray.get([test.__ray_ready__.remote() for test in remote_tests])

    port = ray.get(remote_tests[0].setup_tcp_store.remote())
    ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])

    # Test broadcasting from root=0.
    results = ray.get([
        test.run_tensor_broadcast.remote(root_tensor, root=0)
        for test in remote_tests
    ])
    for r in results:
        assert r is True, "Tensor broadcast from root=0 failed"


@pytest.mark.gpu2
@pytest.mark.parametrize("test_object", [
    {
        "key1": "value1",
        "key2": [1, 2, 3]
    },
    ["item1", "item2", {
        "nested": True
    }],
    "simple_string",
],
                         ids=["dict", "list", "string"])
@pytest.mark.parametrize("broadcast_method", [
    "run_object_broadcast",
    "run_tp_cp_broadcast",
],
                         ids=["cp_broadcast", "tp_cp_broadcast"])
def test_cp_tp_broadcast_object(setup_ray_cluster, test_object,
                                broadcast_method):
    """Test TorchDist.cp_broadcast and tp_cp_broadcast with non-tensor objects.

    This tests both cp_broadcast (for context parallelism only) and tp_cp_broadcast
    (for combined TP+CP broadcast used in helix parallelism).
    """
    world_size = 2
    tp_size = 1
    cp_size = 2  # Enable context parallelism.

    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
    })

    remote_tests = []
    for rank in range(world_size):
        remote_tests.append(
            CpBroadcastTest.options(runtime_env=runtime_env).remote(
                rank, world_size, tp_size, cp_size))

    ray.get([test.__ray_ready__.remote() for test in remote_tests])

    port = ray.get(remote_tests[0].setup_tcp_store.remote())
    ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])

    # Test broadcasting object from root=0 using the specified method.
    results = ray.get([
        getattr(test, broadcast_method).remote(test_object, root=0)
        for test in remote_tests
    ])
    for r in results:
        assert r is True, f"{broadcast_method} from root=0 failed for {type(test_object)}"
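The Ray workers rendezvous through a TCPStore instead of env:// initialization: the rank-0 actor binds an OS-assigned port (port=0), publishes the real port via store.port, and the remaining ranks connect to it before init_process_group. The same handshake, reduced to a single-process sketch (assumes this standalone TCPStore usage is available in the installed PyTorch):

    import torch.distributed as c10d

    # Rank 0 binds an ephemeral port and reports the one actually chosen;
    # other ranks would construct the store with this port and is_master=False.
    store = c10d.TCPStore(host_name="127.0.0.1", port=0, world_size=1,
                          is_master=True, wait_for_workers=False)
    print(store.port)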