mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
test: Add fixture to skip tests based on MPI world size (#5028)
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
This commit is contained in:
parent
2848e012ae
commit
9b616db13b
@ -365,7 +365,7 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
|
||||
class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
|
||||
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@pytest.mark.skip_less_mpi_world_size(8)
|
||||
def test_auto_dtype_tp8(self):
|
||||
model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct"
|
||||
with LLM(model_path, tensor_parallel_size=8) as llm:
|
||||
@ -412,7 +412,7 @@ class TestLlama4MaverickInstruct(LlmapiAccuracyTestHarness):
|
||||
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Maverick-17B-128E-Instruct"
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@pytest.mark.skip_less_mpi_world_size(8)
|
||||
@parametrize_with_ids("cuda_graph", [False, True])
|
||||
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
|
||||
(8, 1, 8)],
|
||||
@ -434,7 +434,7 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
|
||||
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
|
||||
|
||||
@skip_pre_hopper
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@pytest.mark.skip_less_mpi_world_size(8)
|
||||
@parametrize_with_ids("cuda_graph", [False, True])
|
||||
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
|
||||
(8, 1, 8)],
|
||||
@ -993,7 +993,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "deepseek-ai/DeepSeek-R1"
|
||||
MODEL_PATH = f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1"
|
||||
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@pytest.mark.skip_less_mpi_world_size(8)
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size,moe_backend",
|
||||
@ -1048,7 +1048,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
task.evaluate(llm,
|
||||
extra_evaluator_kwargs=dict(apply_chat_template=True))
|
||||
|
||||
@pytest.mark.skip_less_device(8)
|
||||
@pytest.mark.skip_less_mpi_world_size(8)
|
||||
@skip_pre_hopper
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,max_batch_size",
|
||||
|
||||
@ -37,6 +37,7 @@ import yaml
|
||||
from _pytest.mark import ParameterSet
|
||||
|
||||
from tensorrt_llm.bindings import ipc_nvls_supported
|
||||
from tensorrt_llm.llmapi.mpi_session import get_mpi_world_size
|
||||
|
||||
from .perf.gpu_clock_lock import GPUClockLock
|
||||
from .perf.session_data_writer import SessionDataWriter
|
||||
@ -1811,6 +1812,19 @@ def skip_by_device_count(request):
|
||||
f'Device count {device_count} is less than {expected_count}')
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def skip_by_mpi_world_size(request):
    """Skip the test when the MPI world size is smaller than declared.

    Looks for a ``skip_less_mpi_world_size(n)`` marker on the test node and
    skips the test if the current MPI world size is below ``n``.  Runs
    automatically for every test (``autouse=True``).
    """
    # Fetch the marker once instead of repeating the lookup below.
    marker = request.node.get_closest_marker('skip_less_mpi_world_size')
    if marker:
        mpi_world_size = get_mpi_world_size()
        expected_count = marker.args[0]
        # get_mpi_world_size() may not return an int — normalize before comparing.
        if expected_count > int(mpi_world_size):
            pytest.skip(
                f'MPI world size {mpi_world_size} is less than {expected_count}'
            )
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_by_device_memory(request):
|
||||
"fixture for skip less device memory"
|
||||
|
||||
@ -7,6 +7,7 @@ addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*test_list_validati
|
||||
norecursedirs = ./triton/perf
|
||||
markers =
|
||||
        skip_less_device: skip when fewer devices are detected than declared
|
||||
        skip_less_mpi_world_size: skip when the detected MPI world size is smaller than declared
|
||||
        skip_less_device_memory: skip when less device memory is detected than requested
|
||||
        skip_less_host_memory: skip when less host memory is detected than requested
|
||||
support_fp8: skip when fp8 is not supported on the device
|
||||
|
||||
Loading…
Reference in New Issue
Block a user