From 7cf1209a19164ca50606a163c468f933fb35d58c Mon Sep 17 00:00:00 2001 From: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Date: Tue, 1 Jul 2025 09:39:49 +0800 Subject: [PATCH] [fix]: Fix main test skip issue (#5503) Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Co-authored-by: Yanchao Lu --- .../defs/accuracy/test_llm_api_pytorch.py | 10 +++++----- tests/integration/defs/conftest.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 41c0014ec9..c33aabcddc 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -1240,7 +1240,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): True, 32, "CUTLASS", - marks=pytest.mark.skip_less_device(8)), + marks=pytest.mark.skip_less_mpi_world_size(8)), pytest.param(8, 1, 4, @@ -1251,7 +1251,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): True, 32, "TRTLLM", - marks=pytest.mark.skip_less_device(8)), + marks=pytest.mark.skip_less_mpi_world_size(8)), pytest.param(8, 1, 8, @@ -1262,7 +1262,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): True, 32, "CUTLASS", - marks=pytest.mark.skip_less_device(8)), + marks=pytest.mark.skip_less_mpi_world_size(8)), pytest.param(8, 1, 1, @@ -1273,7 +1273,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): True, 32, "CUTLASS", - marks=pytest.mark.skip_less_device(8)), + marks=pytest.mark.skip_less_mpi_world_size(8)), pytest.param(4, 1, 1, @@ -1284,7 +1284,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness): True, 16, "CUTLASS", - marks=pytest.mark.skip_less_device(4)), + marks=pytest.mark.skip_less_mpi_world_size(4)), ], ids=[ "latency", "latency_trtllmgen", "throughput", "throughput_tp8", diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index a5d069f37b..a1300f5798 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -1825,15 +1825,21 @@ def skip_by_device_count(request): @pytest.fixture(autouse=True) def skip_by_mpi_world_size(request): - "fixture for skip less device count" + "fixture for skip less mpi world size" if request.node.get_closest_marker('skip_less_mpi_world_size'): mpi_world_size = get_mpi_world_size() + device_count = get_device_count() + if mpi_world_size == 1: + # For mpi_world_size == 1 case, we only need to check device count since we can spawn mpi workers in the test itself + total_count = device_count + else: + # Otherwise, we follow the mpi world size setting + total_count = mpi_world_size expected_count = request.node.get_closest_marker( 'skip_less_mpi_world_size').args[0] - if expected_count > int(mpi_world_size): + if expected_count > int(total_count): pytest.skip( - f'MPI world size {mpi_world_size} is less than {expected_count}' - ) + f'Total world size {total_count} is less than {expected_count}') @pytest.fixture(autouse=True)