[TRTLLM-9467][fix] Fix PP+CP combination with helix parallelism (#10312)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Balaram Buddharaju 2026-01-01 10:42:53 -08:00 committed by GitHub
parent 5845951538
commit 4a1b742aa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 532 additions and 26 deletions

View File

@@ -1506,7 +1506,7 @@ class AutoTuner:
"""Broadcast tactics from root rank to all other ranks."""
cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
root = 0
cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
cache_data = self._dist.tp_cp_broadcast(obj=cache_data, root=root)
self.profiling_cache.merge_cache_data(cache_data)
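Why a TP-only broadcast is not enough here: with helix parallelism the autotuner's profiling cache must reach every rank of the TP x CP grid, not just the TP group containing the root. A standalone sketch of the coverage difference (illustrative only; it assumes a row-major rank layout rank = cp_rank * tp_size + tp_rank, which need not match the real Mapping):

def ranks_reached(tp_size: int, cp_size: int, chain_cp: bool) -> set:
    # A TP broadcast from the root covers the cp_rank == 0 row of the grid.
    reached = {tp for tp in range(tp_size)}
    if chain_cp:
        # Chaining a CP broadcast forwards each TP column down its CP group.
        reached |= {cp * tp_size + tp for cp in range(cp_size) for tp in range(tp_size)}
    return reached

assert ranks_reached(tp_size=2, cp_size=2, chain_cp=False) == {0, 1}       # old tp_broadcast call
assert ranks_reached(tp_size=2, cp_size=2, chain_cp=True) == {0, 1, 2, 3}  # new tp_cp_broadcast call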

View File

@@ -116,6 +116,26 @@ class Distributed(ABC):
def allgather(self, obj, root=0):
pass
@abstractmethod
def tp_broadcast(self, obj, root=0, **kwargs):
pass
@abstractmethod
def cp_broadcast(self, obj, root=0, **kwargs):
pass
def tp_cp_broadcast(self, obj, root=0, **kwargs):
"""Broadcast object across both TP and CP groups.
This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
First broadcasts within the TP group, then within the CP group.
"""
if self.tp_size > 1:
obj = self.tp_broadcast(obj, root=root, **kwargs)
if self.cp_size > 1:
obj = self.cp_broadcast(obj, root=root, **kwargs)
return obj
def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
"""
@@ -407,6 +427,14 @@ class MPIDist(Distributed):
def cp_allgather(self, obj):
return self.cp_comm.allgather(obj)
def cp_broadcast(self,
obj,
root=0,
chunk_size: int = 4 * 1024 * 1024,
**kwargs):
comm = self.cp_comm
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
def tp_allgather(self, obj):
return self.tp_comm.allgather(obj)
@@ -414,7 +442,11 @@ class MPIDist(Distributed):
comm = self.tp_comm
return safe_gather(comm, obj, root=root, chunk_size=chunk_size)
def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
def tp_broadcast(self,
obj,
root=0,
chunk_size: int = 4 * 1024 * 1024,
**kwargs):
comm = self.tp_comm
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
@@ -699,7 +731,7 @@ class TorchDist(Distributed):
return output_list
@log_op
def tp_broadcast(self, obj, root=0):
def tp_broadcast(self, obj, root=0, **kwargs):
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
return obj
@@ -712,6 +744,20 @@ class TorchDist(Distributed):
device=torch.device("cpu"))
return ret[0]
@log_op
def cp_broadcast(self, obj, root=0, **kwargs):
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
return obj
else:
ret = [obj]
torch.distributed.broadcast_object_list(
ret,
src=root,
group=self.mapping.cp_group_pg,
device=torch.device("cpu"))
return ret[0]
@log_op
def pp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
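A minimal, dependency-free sketch of how the new base-class helper composes the two collectives (BroadcastOrderProbe is illustrative and not part of the codebase; the concrete subclasses above use safe_broadcast over tp_comm/cp_comm or torch process groups instead):

class BroadcastOrderProbe:
    """Mimics Distributed.tp_cp_broadcast while recording which legs run."""

    def __init__(self, tp_size: int, cp_size: int):
        self.tp_size, self.cp_size = tp_size, cp_size
        self.calls = []

    def tp_broadcast(self, obj, root=0, **kwargs):
        self.calls.append("tp")   # stand-in for safe_broadcast(self.tp_comm, ...)
        return obj

    def cp_broadcast(self, obj, root=0, **kwargs):
        self.calls.append("cp")   # stand-in for safe_broadcast(self.cp_comm, ...)
        return obj

    def tp_cp_broadcast(self, obj, root=0, **kwargs):
        # Same control flow as the base-class method: the object returned by the
        # TP leg is what the CP leg forwards, and trivial (size-1) groups are skipped.
        if self.tp_size > 1:
            obj = self.tp_broadcast(obj, root=root, **kwargs)
        if self.cp_size > 1:
            obj = self.cp_broadcast(obj, root=root, **kwargs)
        return obj

probe = BroadcastOrderProbe(tp_size=2, cp_size=2)
probe.tp_cp_broadcast({"payload": 1})
assert probe.calls == ["tp", "cp"]   # helix: TP group first, then CP group

probe = BroadcastOrderProbe(tp_size=1, cp_size=2)
probe.tp_cp_broadcast({"payload": 1})
assert probe.calls == ["cp"]         # pp2tp1cp2 generation ranks: CP leg only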

View File

@@ -587,9 +587,10 @@ class ExecutorRequestQueue:
if not self.dist.has_pp:
return self.dist.broadcast(payloads, root=0)
# Broadcast within first tp group before send/recv chain to other tp groups
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
payloads = self.dist.tp_broadcast(payloads, root=0)
# Broadcast within first PP stage before send/recv chain to other PP stages.
# This needs to cover both TP and CP ranks within the first PP stage.
if self.dist.is_first_pp_rank:
payloads = self.dist.tp_cp_broadcast(payloads, root=0)
# Tag for communication
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
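The guard change above is the heart of the PP+CP fix: the old condition skipped the intra-stage broadcast whenever tp_size == 1, so a pp2tp1cp2 generation server never delivered request payloads to the CP peer on the first PP stage. A hedged, logic-only comparison (no communication; both cases assume is_first_pp_rank is True):

def old_stage_broadcasts(tp_size: int, cp_size: int) -> bool:
    return tp_size > 1                   # tp_broadcast only, and only when TP is non-trivial

def new_stage_broadcasts(tp_size: int, cp_size: int) -> bool:
    return tp_size > 1 or cp_size > 1    # tp_cp_broadcast covers whichever groups exist

# Generation layouts exercised by the updated test below:
assert old_stage_broadcasts(tp_size=1, cp_size=2) is False   # CP peer starved of payloads
assert new_stage_broadcasts(tp_size=1, cp_size=2) is True
assert old_stage_broadcasts(tp_size=2, cp_size=2) is True    # ran, but reached the TP group only
assert new_stage_broadcasts(tp_size=2, cp_size=2) is True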

View File

@@ -871,7 +871,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
ids=["pp1tp2cp2", "pp2tp1cp2"])
@pytest.mark.parametrize("cuda_graph_config", [
None,
{
@@ -888,8 +890,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"cudagraph:with_padding"
])
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
gen_pp, gen_tp, gen_cp):
use_nccl_for_alltoall = comms_medium == "nccl"
gen_ep = gen_tp * gen_cp
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
@@ -898,20 +902,22 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
}
ctx_server_config = {
"pipeline_parallel_size": 1,
"tensor_parallel_size": 2,
"tensor_parallel_size": 4,
"context_parallel_size": 1,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"enable_chunked_prefill": False,
"cuda_graph_config": None,
"cache_transceiver_config": {
"backend": "UCX"
"backend": "UCX",
"max_tokens_in_buffer": 8192,
},
}
gen_server_config = {
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"context_parallel_size": 2,
"tensor_parallel_size": gen_tp,
"pipeline_parallel_size": gen_pp,
"context_parallel_size": gen_cp,
"moe_expert_parallel_size": gen_ep,
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32,
@@ -922,7 +928,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"enable_chunked_prefill": False,
"cuda_graph_config": cuda_graph_config,
"cache_transceiver_config": {
"backend": "UCX"
"backend": "UCX",
"max_tokens_in_buffer": 8192,
},
}
disaggregated_server_config = {
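For sizing context, a worked check of the new parametrization (assuming the context and generation servers occupy disjoint GPUs, the usual disaggregated-serving setup): the context server now uses 1*4*1 = 4 GPUs and either generation layout uses 4 GPUs, which is why the marker moves from skip_less_device(4) to skip_less_device(8).

ctx_gpus = 1 * 4 * 1                           # pp * tp * cp for the context server
for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
    gen_gpus = gen_pp * gen_tp * gen_cp        # 4 GPUs for either generation layout
    gen_ep = gen_tp * gen_cp                   # moe_expert_parallel_size, as set in the test
    assert gen_gpus == 4 and ctx_gpus + gen_gpus == 8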

View File

@@ -540,14 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]

View File

@@ -66,6 +66,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)

View File

@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
- condition:
ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]

View File

@@ -523,6 +523,8 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_sof
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] SKIP (https://nvbugs/5774869)
triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https://nvbugs/5775205)
triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223)
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=False] SKIP (https://nvbugs/5777044)
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4[torch_compile=True] SKIP (https://nvbugs/5777044)
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=True] SKIP (https://nvbugs/5740377)
unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256)
unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041)

View File

@@ -0,0 +1,265 @@
"""
Tests for cp_broadcast functionality in both MPIDist and TorchDist.
This module tests the context-parallelism (CP) broadcast operation, which is used
when CP is enabled (e.g., in helix parallelism).
For MPIDist tests, run with mpirun:
mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v
For TorchDist tests, see test_ops.py which uses Ray for distributed testing.
"""
import numpy as np
import pytest
from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm.mapping import Mapping
def get_mpi_info():
"""Get MPI rank and world size, returns (0, 1) if MPI is not available."""
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
return comm.Get_rank(), comm.Get_size()
except ImportError:
return 0, 1
def skip_if_not_mpi():
"""Skip test if not running under MPI with sufficient ranks."""
rank, world_size = get_mpi_info()
if world_size < 2:
pytest.skip("Test requires at least 2 MPI ranks (run with mpirun -n 2)")
class TestMPIDistCpBroadcast:
"""Tests for MPIDist.cp_broadcast functionality."""
@pytest.fixture(autouse=True)
def setup(self):
"""Set up MPI environment and mapping for each test."""
skip_if_not_mpi()
self.rank, self.world_size = get_mpi_info()
# Set up mapping with CP enabled (cp_size = world_size, tp_size = 1)
self.mapping = Mapping(
world_size=self.world_size,
rank=self.rank,
tp_size=1,
cp_size=self.world_size,
pp_size=1,
)
self.dist = MPIDist(mapping=self.mapping)
def test_broadcast_numpy_array(self):
"""Test broadcasting a numpy array via cp_broadcast."""
root = 0
shape = (64, 128)
if self.mapping.cp_rank == root:
# Root rank creates the data to broadcast
data = np.random.randn(*shape).astype(np.float32)
else:
# Non-root ranks have empty/zero data
data = np.zeros(shape, dtype=np.float32)
# Store original data from root for verification
from mpi4py import MPI
expected = np.zeros(shape, dtype=np.float32)
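# Out-of-band COMM_WORLD broadcast so non-root ranks learn the expected values
# independently of the code under test; valid here because the CP group spans
# COMM_WORLD when tp_size == 1 and pp_size == 1.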
MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected, root=root)
if self.mapping.cp_rank == root:
expected = data.copy()
# Perform cp_broadcast
result = self.dist.cp_broadcast(data, root=root)
# Verify all ranks have the same data
np.testing.assert_array_almost_equal(result, expected)
def test_broadcast_python_dict(self):
"""Test broadcasting a Python dictionary via cp_broadcast."""
root = 0
if self.mapping.cp_rank == root:
obj = {
"model_name": "llama",
"batch_size": 32,
"tokens": [1, 2, 3, 4, 5],
"config": {"hidden_size": 4096, "num_layers": 32},
}
else:
obj = None
result = self.dist.cp_broadcast(obj, root=root)
# Verify all ranks received the correct object
assert result["model_name"] == "llama"
assert result["batch_size"] == 32
assert result["tokens"] == [1, 2, 3, 4, 5]
assert result["config"]["hidden_size"] == 4096
assert result["config"]["num_layers"] == 32
def test_broadcast_python_list(self):
"""Test broadcasting a Python list via cp_broadcast."""
root = 0
if self.mapping.cp_rank == root:
obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
else:
obj = None
result = self.dist.cp_broadcast(obj, root=root)
assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
def test_broadcast_from_non_zero_root(self):
"""Test broadcasting from a non-zero root rank."""
if self.world_size < 2:
pytest.skip("Need at least 2 ranks to test non-zero root")
root = 1 # Broadcast from rank 1
if self.mapping.cp_rank == root:
obj = {"source": "rank1", "value": 42}
else:
obj = None
result = self.dist.cp_broadcast(obj, root=root)
assert result["source"] == "rank1"
assert result["value"] == 42
def test_broadcast_large_object(self):
"""Test broadcasting a large object that may require chunking."""
root = 0
# Create a large list to test chunking behavior
large_size = 100000
if self.mapping.cp_rank == root:
obj = list(range(large_size))
else:
obj = None
result = self.dist.cp_broadcast(obj, root=root)
assert len(result) == large_size
assert result[0] == 0
assert result[-1] == large_size - 1
def test_broadcast_string(self):
"""Test broadcasting a simple string via cp_broadcast."""
root = 0
if self.mapping.cp_rank == root:
obj = "Hello from root rank!"
else:
obj = None
result = self.dist.cp_broadcast(obj, root=root)
assert result == "Hello from root rank!"
# Additional integration-style test that can be run standalone
def test_mpi_cp_broadcast_integration():
"""
Integration test for MPIDist cp_broadcast.
"""
rank, world_size = get_mpi_info()
if world_size < 2:
pytest.skip("Test requires at least 2 MPI ranks")
# Create mapping with CP enabled
mapping = Mapping(
world_size=world_size,
rank=rank,
tp_size=1,
cp_size=world_size,
pp_size=1,
)
dist = MPIDist(mapping=mapping)
# Test 1: Broadcast dict
if mapping.cp_rank == 0:
payload = {"requests": [{"id": i} for i in range(10)]}
else:
payload = None
result = dist.cp_broadcast(payload, root=0)
assert len(result["requests"]) == 10
assert result["requests"][0]["id"] == 0
# Test 2: Broadcast numpy array
shape = (32, 64)
if mapping.cp_rank == 0:
arr = np.ones(shape, dtype=np.float32) * (rank + 1)
else:
arr = np.zeros(shape, dtype=np.float32)
result = dist.cp_broadcast(arr, root=0)
expected_val = 1.0 # From rank 0
np.testing.assert_array_almost_equal(result, np.ones(shape) * expected_val)
if __name__ == "__main__":
# Allow running directly with mpirun
pytest.main([__file__, "-v"])
class TestMPIDistTpCpBroadcast:
"""Tests for MPIDist.tp_cp_broadcast functionality."""
@pytest.fixture(autouse=True)
def setup(self):
"""Set up MPI environment and mapping for each test."""
skip_if_not_mpi()
self.rank, self.world_size = get_mpi_info()
# Set up mapping so tp_cp_broadcast can be exercised with the available ranks.
# With 2 ranks: tp_size=1, cp_size=2, so tp_cp_broadcast only performs the CP broadcast.
self.mapping = Mapping(
world_size=self.world_size,
rank=self.rank,
tp_size=1,
cp_size=self.world_size,
pp_size=1,
)
self.dist = MPIDist(mapping=self.mapping)
def test_tp_cp_broadcast_python_dict(self):
"""Test broadcasting a Python dictionary via tp_cp_broadcast."""
root = 0
# Only rank 0 in both TP and CP groups should have the object
if self.mapping.tp_rank == root and self.mapping.cp_rank == root:
obj = {
"model_name": "llama",
"batch_size": 32,
"tokens": [1, 2, 3, 4, 5],
}
else:
obj = None
result = self.dist.tp_cp_broadcast(obj, root=root)
# Verify all ranks received the correct object
assert result["model_name"] == "llama"
assert result["batch_size"] == 32
assert result["tokens"] == [1, 2, 3, 4, 5]
def test_tp_cp_broadcast_python_list(self):
"""Test broadcasting a Python list via tp_cp_broadcast."""
root = 0
if self.mapping.tp_rank == root and self.mapping.cp_rank == root:
obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
else:
obj = None
result = self.dist.tp_cp_broadcast(obj, root=root)
assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]

View File

@@ -258,3 +258,192 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size):
])
for r in results:
assert r is True
@ray.remote(num_gpus=1)
class CpBroadcastTest:
"""Test worker for cp_broadcast operations with context parallelism."""
def __init__(self, rank, world_size, tp_size, cp_size):
self.rank = rank
self.world_size = world_size
self.tp_size = tp_size
self.cp_size = cp_size
self.master_address = os.environ["MASTER_ADDR"]
assert len(ray.get_gpu_ids()) == 1
self.gpu = int(ray.get_gpu_ids()[0])
from tensorrt_llm.executor.ray_gpu_worker import RayWorkerWrapper
local_gpu = RayWorkerWrapper.physical_to_local_id(self.gpu)
torch.cuda.set_device(local_gpu)
def _create_tcp_store(self,
port: Optional[int] = None
) -> torch.distributed.TCPStore:
actual_port = port if port is not None else 0
return torch.distributed.TCPStore(host_name=self.master_address,
port=actual_port,
world_size=self.world_size,
is_master=(self.rank == 0),
wait_for_workers=False)
def setup_tcp_store(self):
if self.rank != 0:
raise RuntimeError("Only the master worker can setup TCP store")
self.store = self._create_tcp_store()
return self.store.port
def setup_distributed_env(self, port: int):
if self.rank != 0:
self.store = self._create_tcp_store(port)
torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
store=self.store,
world_size=self.world_size,
rank=self.rank)
self.mapping = Mapping(world_size=self.world_size,
gpus_per_node=self.world_size,
tp_size=self.tp_size,
cp_size=self.cp_size,
rank=self.rank)
self.dist = TorchDist(self.mapping)
def run_tensor_broadcast(self, root_tensor: torch.Tensor, root: int = 0):
"""Test broadcasting a tensor via cp_broadcast."""
cp_rank = self.mapping.cp_rank
if cp_rank == root:
# Root rank has the tensor to broadcast.
tensor = root_tensor.cuda()
else:
# Non-root ranks start with zeros.
tensor = torch.zeros_like(root_tensor).cuda()
result = self.dist.cp_broadcast(tensor, root=root)
# After broadcast, all CP ranks should have the same tensor.
expected = root_tensor.cuda()
return torch.allclose(result, expected)
def run_object_broadcast(self, root_obj, root: int = 0):
"""Test broadcasting a non-tensor object via cp_broadcast."""
cp_rank = self.mapping.cp_rank
if cp_rank == root:
obj = root_obj
else:
obj = None
result = self.dist.cp_broadcast(obj, root=root)
# After broadcast, all CP ranks should have the same object.
return result == root_obj
def run_tp_cp_broadcast(self, root_obj, root: int = 0):
"""Test broadcasting an object via tp_cp_broadcast."""
# For tp_cp_broadcast, only rank 0 in both TP and CP should have the object.
tp_rank = self.mapping.tp_rank
cp_rank = self.mapping.cp_rank
if tp_rank == root and cp_rank == root:
obj = root_obj
else:
obj = None
result = self.dist.tp_cp_broadcast(obj, root=root)
# After broadcast, all TP and CP ranks should have the same object.
return result == root_obj
@pytest.mark.gpu2
@pytest.mark.parametrize("hidden_size", [128, 512], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 32], ids=lambda x: f"seqlen:{x}")
def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
"""Test TorchDist.cp_broadcast with tensor data."""
torch.manual_seed(42)
dtype = torch.bfloat16
world_size = 2
tp_size = 1
cp_size = 2 # Enable context parallelism.
# Create tensor to broadcast from root.
root_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
runtime_env = ray.runtime_env.RuntimeEnv()
runtime_env["env_vars"] = os.environ.copy()
runtime_env["env_vars"].update({
"TLLM_DISABLE_MPI": "1",
"MASTER_ADDR": "127.0.0.1",
})
remote_tests = []
for rank in range(world_size):
remote_tests.append(
CpBroadcastTest.options(runtime_env=runtime_env).remote(
rank, world_size, tp_size, cp_size))
ray.get([test.__ray_ready__.remote() for test in remote_tests])
port = ray.get(remote_tests[0].setup_tcp_store.remote())
ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
# Test broadcasting from root=0.
results = ray.get([
test.run_tensor_broadcast.remote(root_tensor, root=0)
for test in remote_tests
])
for r in results:
assert r is True, "Tensor broadcast from root=0 failed"
@pytest.mark.gpu2
@pytest.mark.parametrize("test_object", [
{
"key1": "value1",
"key2": [1, 2, 3]
},
["item1", "item2", {
"nested": True
}],
"simple_string",
],
ids=["dict", "list", "string"])
@pytest.mark.parametrize("broadcast_method", [
"run_object_broadcast",
"run_tp_cp_broadcast",
],
ids=["cp_broadcast", "tp_cp_broadcast"])
def test_cp_tp_broadcast_object(setup_ray_cluster, test_object,
broadcast_method):
"""Test TorchDist.cp_broadcast and tp_cp_broadcast with non-tensor objects.
This tests both cp_broadcast (for context parallelism only) and tp_cp_broadcast
(for combined TP+CP broadcast used in helix parallelism).
"""
world_size = 2
tp_size = 1
cp_size = 2 # Enable context parallelism.
runtime_env = ray.runtime_env.RuntimeEnv()
runtime_env["env_vars"] = os.environ.copy()
runtime_env["env_vars"].update({
"TLLM_DISABLE_MPI": "1",
"MASTER_ADDR": "127.0.0.1",
})
remote_tests = []
for rank in range(world_size):
remote_tests.append(
CpBroadcastTest.options(runtime_env=runtime_env).remote(
rank, world_size, tp_size, cp_size))
ray.get([test.__ray_ready__.remote() for test in remote_tests])
port = ray.get(remote_tests[0].setup_tcp_store.remote())
ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
# Test broadcasting object from root=0 using the specified method.
results = ray.get([
getattr(test, broadcast_method).remote(test_object, root=0)
for test in remote_tests
])
for r in results:
assert r is True, f"{broadcast_method} from root=0 failed for {type(test_object)}"
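For reference, the non-tensor path of TorchDist.cp_broadcast relies on torch.distributed.broadcast_object_list over the CP process group. Below is a single-process sketch of that pattern (illustrative only; it uses the default gloo group with world_size=1 instead of a real CP group, the function name and port are assumptions, and it must be called explicitly):

import os

import torch
import torch.distributed as torch_dist


def broadcast_object_list_demo():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")   # arbitrary free port assumed
    torch_dist.init_process_group(backend="gloo", rank=0, world_size=1)
    try:
        # Non-tensor payloads are wrapped in a single-element list, as in the diff above.
        payload = [{"requests": [1, 2, 3]}]
        torch_dist.broadcast_object_list(payload, src=0, device=torch.device("cpu"))
        assert payload[0] == {"requests": [1, 2, 3]}
    finally:
        torch_dist.destroy_process_group()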