[fix]: Fix main test skip issue (#5503)

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
This commit is contained in:
Yi Zhang 2025-07-01 09:39:49 +08:00 committed by GitHub
parent 6ee94c7ac8
commit 7cf1209a19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 9 deletions

View File

@ -1240,7 +1240,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
True,
32,
"CUTLASS",
marks=pytest.mark.skip_less_device(8)),
marks=pytest.mark.skip_less_mpi_world_size(8)),
pytest.param(8,
1,
4,
@ -1251,7 +1251,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
True,
32,
"TRTLLM",
marks=pytest.mark.skip_less_device(8)),
marks=pytest.mark.skip_less_mpi_world_size(8)),
pytest.param(8,
1,
8,
@ -1262,7 +1262,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
True,
32,
"CUTLASS",
marks=pytest.mark.skip_less_device(8)),
marks=pytest.mark.skip_less_mpi_world_size(8)),
pytest.param(8,
1,
1,
@ -1273,7 +1273,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
True,
32,
"CUTLASS",
marks=pytest.mark.skip_less_device(8)),
marks=pytest.mark.skip_less_mpi_world_size(8)),
pytest.param(4,
1,
1,
@ -1284,7 +1284,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
True,
16,
"CUTLASS",
marks=pytest.mark.skip_less_device(4)),
marks=pytest.mark.skip_less_mpi_world_size(4)),
],
ids=[
"latency", "latency_trtllmgen", "throughput", "throughput_tp8",

View File

@ -1825,15 +1825,21 @@ def skip_by_device_count(request):
@pytest.fixture(autouse=True)
def skip_by_mpi_world_size(request):
"fixture for skip less device count"
"fixture for skip less mpi world size"
if request.node.get_closest_marker('skip_less_mpi_world_size'):
mpi_world_size = get_mpi_world_size()
device_count = get_device_count()
if mpi_world_size == 1:
# For mpi_world_size == 1 case, we only need to check device count since we can spawn mpi workers in the test itself
total_count = device_count
else:
# Otherwise, we follow the mpi world size setting
total_count = mpi_world_size
expected_count = request.node.get_closest_marker(
'skip_less_mpi_world_size').args[0]
if expected_count > int(mpi_world_size):
if expected_count > int(total_count):
pytest.skip(
f'MPI world size {mpi_world_size} is less than {expected_count}'
)
f'Total world size {total_count} is less than {expected_count}')
@pytest.fixture(autouse=True)