mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[fix]: Fix main test skip issue (#5503)
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
This commit is contained in:
parent
6ee94c7ac8
commit
7cf1209a19
@ -1240,7 +1240,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
True,
|
||||
32,
|
||||
"CUTLASS",
|
||||
marks=pytest.mark.skip_less_device(8)),
|
||||
marks=pytest.mark.skip_less_mpi_world_size(8)),
|
||||
pytest.param(8,
|
||||
1,
|
||||
4,
|
||||
@ -1251,7 +1251,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
True,
|
||||
32,
|
||||
"TRTLLM",
|
||||
marks=pytest.mark.skip_less_device(8)),
|
||||
marks=pytest.mark.skip_less_mpi_world_size(8)),
|
||||
pytest.param(8,
|
||||
1,
|
||||
8,
|
||||
@ -1262,7 +1262,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
True,
|
||||
32,
|
||||
"CUTLASS",
|
||||
marks=pytest.mark.skip_less_device(8)),
|
||||
marks=pytest.mark.skip_less_mpi_world_size(8)),
|
||||
pytest.param(8,
|
||||
1,
|
||||
1,
|
||||
@ -1273,7 +1273,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
True,
|
||||
32,
|
||||
"CUTLASS",
|
||||
marks=pytest.mark.skip_less_device(8)),
|
||||
marks=pytest.mark.skip_less_mpi_world_size(8)),
|
||||
pytest.param(4,
|
||||
1,
|
||||
1,
|
||||
@ -1284,7 +1284,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
True,
|
||||
16,
|
||||
"CUTLASS",
|
||||
marks=pytest.mark.skip_less_device(4)),
|
||||
marks=pytest.mark.skip_less_mpi_world_size(4)),
|
||||
],
|
||||
ids=[
|
||||
"latency", "latency_trtllmgen", "throughput", "throughput_tp8",
|
||||
|
||||
@ -1825,15 +1825,21 @@ def skip_by_device_count(request):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_by_mpi_world_size(request):
|
||||
"fixture for skip less device count"
|
||||
"fixture for skip less mpi world size"
|
||||
if request.node.get_closest_marker('skip_less_mpi_world_size'):
|
||||
mpi_world_size = get_mpi_world_size()
|
||||
device_count = get_device_count()
|
||||
if mpi_world_size == 1:
|
||||
# For mpi_world_size == 1 case, we only need to check device count since we can spawn mpi workers in the test itself
|
||||
total_count = device_count
|
||||
else:
|
||||
# Otherwise, we follow the mpi world size setting
|
||||
total_count = mpi_world_size
|
||||
expected_count = request.node.get_closest_marker(
|
||||
'skip_less_mpi_world_size').args[0]
|
||||
if expected_count > int(mpi_world_size):
|
||||
if expected_count > int(total_count):
|
||||
pytest.skip(
|
||||
f'MPI world size {mpi_world_size} is less than {expected_count}'
|
||||
)
|
||||
f'Total world size {total_count} is less than {expected_count}')
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user