mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-10 13:03:34 +08:00
115 lines
3.6 KiB
Python
115 lines
3.6 KiB
Python
"""Pytest configuration for disagg tests.
|
|
|
|
Only collects tests in this directory when --disagg parameter is provided.
|
|
Can share options like --disagg-test-list defined in this conftest.py.
|
|
"""
|
|
|
|
import pytest
|
|
from utils.logger import logger
|
|
|
|
|
|
def pytest_addoption(parser):
    """Register the command line options used by the disagg tests.

    Two options are added:
      * ``--disagg``: boolean flag (off by default) that enables collection
        of the disaggregated tests.
      * ``--disagg-test-list``: path to a file with one test ID per line
        (no default).

    Args:
        parser: pytest option parser on which the options are registered.
    """
    option_specs = (
        (
            "--disagg",
            {
                "action": "store_true",
                "default": False,
                "help": "Enable disaggregated tests collection. Example: pytest --disagg",
            },
        ),
        (
            "--disagg-test-list",
            {
                "action": "store",
                "default": None,
                "help": "Path to a file containing test IDs (one per line) to run. "
                "Example: pytest --disagg --disagg-test-list=testlist/testlist_gb200.txt",
            },
        ),
    )

    # Options are registered in the order they appear in the table above.
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
|
|
|
|
|
|
def pytest_collect_directory(path, parent):
    """Gate collection of this directory on the --disagg option.

    Args:
        path: Directory path being considered for collection (not used here).
        parent: Parent collector; its ``config`` is queried for --disagg.

    Returns:
        None when --disagg was given, so normal collection proceeds (the
        --disagg-test-list option can then narrow the selection further).
        True otherwise, which is intended to skip collection of this
        directory.

    NOTE(review): pytest >= 8 documents this hook as returning a collector or
    None -- confirm that a True return is treated as "skip" on the pytest
    version in use.
    """
    if parent.config.getoption("--disagg", default=False):
        # --disagg given: defer to pytest's normal collection.
        return None

    # No --disagg: signal that this directory should not be collected.
    return True
|
|
|
|
|
|
def _read_test_ids(test_list_file):
    """Return the set of test IDs listed in *test_list_file*.

    Every line is stripped of surrounding whitespace; blank lines and lines
    starting with "#" are ignored.
    """
    with open(test_list_file, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return {entry for entry in stripped_lines if entry and not entry.startswith("#")}


def pytest_collection_modifyitems(config, items):
    """Filter collected tests down to those named by --disagg-test-list.

    When the option is not given, ``items`` is left untouched. Otherwise the
    file is read (one test ID per line, "#" comments and blank lines ignored)
    and only items whose ``nodeid`` exactly matches a listed ID are kept; the
    rest are reported to pytest as deselected. ``items`` is modified in place.

    The session is aborted through ``pytest.exit`` when the file cannot be
    read or contains no test IDs.

    Args:
        config: pytest config object
        items: list of collected test items
    """
    test_list_file = config.getoption("--disagg-test-list")
    if not test_list_file:
        # No filtering needed if --disagg-test-list is not provided
        return

    # pytest.exit() raises, so nothing after it in a handler would ever run.
    try:
        wanted_tests = _read_test_ids(test_list_file)
    except FileNotFoundError:
        pytest.exit(f"Error: Test list file not found: {test_list_file}")
    except (OSError, ValueError) as e:
        # OSError: permission/IO problems; ValueError: undecodable (non UTF-8) content.
        pytest.exit(f"Error reading test list file {test_list_file}: {e}")

    if not wanted_tests:
        pytest.exit(
            f"Error: Test list file {test_list_file} is empty or contains no valid test IDs"
        )

    # Partition items by exact nodeid match. A nodeid is the full identifier, e.g.
    # "test_disagg_simple.py::TestDisaggBenchmark::test_benchmark[deepseek-r1-fp4:1k1k:...]"
    selected = []
    deselected = []
    for item in items:
        if item.nodeid in wanted_tests:
            selected.append(item)
        else:
            deselected.append(item)

    # Apply the filtering
    if deselected:
        config.hook.pytest_deselected(items=deselected)
        items[:] = selected

    # Requested IDs that matched nothing usually indicate a typo or a stale list.
    unmatched = sorted(wanted_tests - {item.nodeid for item in selected})

    # Log summary
    logger.info(f"\n{'=' * 70}")
    logger.success("Test List Filter Active")
    logger.info(f"File: {test_list_file}")
    logger.info(f"Requested: {len(wanted_tests)} test(s)")
    logger.info(f"Selected: {len(selected)} test(s)")
    logger.info(f"Deselected: {len(deselected)} test(s)")

    if not selected:
        logger.warning("No tests matched the test list!")
        logger.warning(f"Please check that the test IDs in {test_list_file} are correct.")

    if unmatched:
        logger.warning(f"Unmatched: {len(unmatched)} requested test ID(s) not found in collection:")
        for test_id in unmatched:
            logger.warning(f"  - {test_id}")

    logger.info(f"{'=' * 70}\n")
|