test: Add fixture to skip tests based on MPI world size (#5028)

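Adds a skip_less_mpi_world_size pytest marker plus a matching autouse
fixture, so tests that require a minimum number of MPI ranks are skipped
instead of failing when the job was launched with fewer ranks. The 8-GPU
accuracy tests switch from skip_less_device(8) to the new marker, gating
on the MPI world size rather than the locally visible device count. The
marker is also registered in pytest.ini.
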
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Yi Zhang 2025-06-16 11:25:01 +08:00 committed by GitHub
parent 2848e012ae
commit 9b616db13b
3 changed files with 20 additions and 5 deletions


@@ -365,7 +365,7 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
 class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
 
-    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_mpi_world_size(8)
     def test_auto_dtype_tp8(self):
         model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct"
         with LLM(model_path, tensor_parallel_size=8) as llm:
@@ -412,7 +412,7 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Maverick-17B-128E-Instruct"
 
     @skip_pre_blackwell
-    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("cuda_graph", [False, True])
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
                                                          (8, 1, 8)],
@@ -434,7 +434,7 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
 
     @skip_pre_hopper
-    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_mpi_world_size(8)
     @parametrize_with_ids("cuda_graph", [False, True])
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
                                                          (8, 1, 8)],
@@ -993,7 +993,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
     MODEL_NAME = "deepseek-ai/DeepSeek-R1"
     MODEL_PATH = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1"
 
-    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_blackwell
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
@@ -1048,7 +1048,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
         task.evaluate(llm,
                       extra_evaluator_kwargs=dict(apply_chat_template=True))
 
-    @pytest.mark.skip_less_device(8)
+    @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_hopper
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size",

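(Note: the swap above matters because the number of locally visible GPUs and
the MPI world size are independent quantities; a job can see 8 GPUs on a node
while running as a single MPI rank. A minimal probe of both values, assuming
torch and mpi4py are installed:)

import torch
from mpi4py import MPI

print("visible GPUs:", torch.cuda.device_count())    # devices on this node
print("MPI world size:", MPI.COMM_WORLD.Get_size())  # ranks in the whole MPI job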

@@ -37,6 +37,7 @@ import yaml
 from _pytest.mark import ParameterSet
 
 from tensorrt_llm.bindings import ipc_nvls_supported
+from tensorrt_llm.llmapi.mpi_session import get_mpi_world_size
 
 from .perf.gpu_clock_lock import GPUClockLock
 from .perf.session_data_writer import SessionDataWriter
@@ -1811,6 +1812,19 @@ def skip_by_device_count(request):
                 f'Device count {device_count} is less than {expected_count}')
 
 
+@pytest.fixture(autouse=True)
+def skip_by_mpi_world_size(request):
+    "fixture for skip less mpi world size"
+    if request.node.get_closest_marker('skip_less_mpi_world_size'):
+        mpi_world_size = get_mpi_world_size()
+        expected_count = request.node.get_closest_marker(
+            'skip_less_mpi_world_size').args[0]
+        if expected_count > int(mpi_world_size):
+            pytest.skip(
+                f'MPI world size {mpi_world_size} is less than {expected_count}'
+            )
+
+
 @pytest.fixture(autouse=True)
 def skip_by_device_memory(request):
     "fixture for skip less device memory"

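For reference, the fixture is autouse, so no test changes are needed beyond
the marker itself; a hedged usage sketch (test name and rank count are
illustrative):

import pytest

@pytest.mark.skip_less_mpi_world_size(4)
def test_needs_four_ranks():
    # Skipped by the autouse fixture above unless the suite was launched
    # with an MPI world size of at least 4.
    ...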

@@ -7,6 +7,7 @@ addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*test_list_validati
 norecursedirs = ./triton/perf
 markers =
     skip_less_device: skip when less device detected than the declared
+    skip_less_mpi_world_size: skip when less mpi world size detected than the declared
     skip_less_device_memory: skip when less device memory detected than the requested
     skip_less_host_memory: skip when less host memory detected than the requested
     support_fp8: skip when fp8 is not supported on the device
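
For completeness, the same registration could be done programmatically instead
of via pytest.ini; a sketch using pytest's standard hook (not part of this
commit):

def pytest_configure(config):
    # Registers the marker from conftest.py; like the ini entry above, this
    # keeps pytest from emitting PytestUnknownMarkWarning for the marker.
    config.addinivalue_line(
        "markers",
        "skip_less_mpi_world_size(n): skip when MPI world size is below n")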