[https://nvbugs/5777041][fix] fix AutoDeploy ep sharding test (#10460)

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Lucas Liebenwein 2026-01-14 21:53:56 -05:00 committed by GitHub
parent 94c7b69048
commit 15b43e8a14
2 changed files with 3 additions and 7 deletions


@@ -322,8 +322,6 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_laten
 triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https://nvbugs/5775205)
 triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223)
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[8-2] SKIP (https://nvbugs/5777041)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)
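
Side note on the two entries dropped above: pytest builds the bracketed test IDs by joining the parametrize values, so test_ep_shard[3-2] is the num_experts=3 / device_count=2 combination. A minimal, hypothetical sketch (a standalone file, not the repo's code; the parameter values are assumptions) that reproduces those IDs:

import pytest

# Assumption: the old get_device_counts() effectively yielded [2] on the CI machine.
@pytest.mark.parametrize("device_count", [2])
@pytest.mark.parametrize("num_experts", [3, 8])
def test_ep_shard(device_count: int, num_experts: int):
    # Placeholder body; only the generated test IDs matter for this sketch.
    assert device_count >= 1

Running pytest --collect-only -q on such a file lists test_ep_shard[3-2] and test_ep_shard[8-2], the exact IDs removed from the waive list now that the underlying test is fixed.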


@@ -31,10 +31,6 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
     ).to(device=device, dtype=torch.bfloat16)
     x = model.get_input(device=device, dtype=torch.bfloat16)
 
-    if world_size > num_experts:
-        print(f"world_size {world_size} > num_experts {num_experts}, skipping test")
-        return
-
     def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int:
         if world_size <= 1:
             return num_p_og
@@ -141,9 +137,11 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) ->
     run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
 
 
-@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize("device_count", get_device_counts([2, 8]))
 @pytest.mark.parametrize("num_experts", [3, 8])
 def test_ep_shard(device_count: int, num_experts: int):
+    if device_count > num_experts:
+        pytest.skip(f"world_size {device_count} > num_experts {num_experts}")
     dist_common.spawn_multiprocess_job(
         job=partial(_run_ep_shard_job, num_experts),
         size=device_count,
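
For context, this second change moves the unsupported-combination check out of the spawned worker (where it only returned early) up to the test itself as a pytest.skip, and widens device_count to [2, 8]. A minimal, self-contained sketch of that pattern, assuming a serial stand-in for dist_common.spawn_multiprocess_job and a placeholder job body (all names except test_ep_shard are hypothetical):

from functools import partial

import pytest


def _run_job(num_experts: int, rank: int, world_size: int) -> None:
    # Placeholder for the real per-rank EP-sharding checks.
    assert world_size <= num_experts


def _spawn(job, size: int) -> None:
    # Serial stand-in for dist_common.spawn_multiprocess_job: call the job
    # once per rank instead of spawning worker processes.
    for rank in range(size):
        job(rank, size)


@pytest.mark.parametrize("device_count", [2, 8])
@pytest.mark.parametrize("num_experts", [3, 8])
def test_ep_shard(device_count: int, num_experts: int):
    # Skip up front so unsupported combinations are reported as SKIPPED
    # instead of silently passing inside the worker.
    if device_count > num_experts:
        pytest.skip(f"world_size {device_count} > num_experts {num_experts}")
    _spawn(job=partial(_run_job, num_experts), size=device_count)

With the skip at the test level, the [3-8] combination shows up as SKIPPED in the report, while the remaining combinations exercise the sharding job on 2 or 8 devices.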