# -*- coding: utf-8 -*-
"""Logic for parsing test lists."""

import array
import json
import os
import re
import time
import traceback
from collections import defaultdict
from textwrap import dedent
from typing import Any, Dict, List, Optional, Sequence

import pytest
from mako.template import Template

from .trt_test_alternative import print_info, print_warning

# from misc.reorder_venv_tests import reorder_tests

kVALID_TEST_LIST_MARKERS = ["XFAIL", "SKIP", "UNSTABLE", "TIMEOUT"]

# When True, invalid test names found during validation are recorded to a JSON
# file in the output directory (see record_invalid_tests() below).
kRECORD_INVALID_TESTS = True

kSTRIP_PARENS_PAT = re.compile(r'\((.*?)\)')


def strip_parens(s):
    """Strip the outer parentheses from the given string and return the inner
    text, or None if the string does not start with a parenthesized group.
    Does not check for balanced parentheses."""
    m = kSTRIP_PARENS_PAT.match(s)
    return m.group(1) if m else None
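

# Illustrative behavior (shown as doctest-style comments so nothing runs at
# import time):
#
#   >>> strip_parens("(flaky on H100)")
#   'flaky on H100'
#   >>> strip_parens("no parens") is None
#   True
#
# The non-greedy match stops at the first closing parenthesis:
#
#   >>> strip_parens("(a)(b)")
#   'a'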


def preprocess_test_list_lines(test_list, lines, mako_opts=None):
    """Apply Mako (https://www.makotemplates.org/) preprocessing to the test
    list before parsing it.

    Args:
        test_list: Name of the test list, used for debugging.
        lines: The lines of the raw test list file.
        mako_opts: Dict of key-value pairs to forward to Mako.

    Returns:
        The processed lines if successful. Raises a RuntimeError on failure.
    """
    # Copy to avoid mutating the caller's dict and to avoid the
    # mutable-default-argument pitfall of a shared `mako_opts={}` default.
    mako_opts = dict(mako_opts or {})
    mako_tmpl_text = "".join(lines)

    try:

        def applyMarkerIf(marker, cond, reason):
            if cond:
                return "{} ({})".format(marker, reason)
            else:
                return ""

        def applyMarkerIfPrefix(marker, prefix, reason):
            return applyMarkerIf(marker, prefix == mako_opts["test_prefix"],
                                 reason)

        mako_opts["skipIf"] = lambda c, r: applyMarkerIf("SKIP", c, r)
        mako_opts["xfailIf"] = lambda c, r: applyMarkerIf("XFAIL", c, r)
        mako_opts["unstableIf"] = lambda c, r: applyMarkerIf("UNSTABLE", c, r)

        mako_opts["skipIfPrefix"] = lambda pfx, r: applyMarkerIfPrefix(
            "SKIP", pfx, r)
        mako_opts["xfailIfPrefix"] = lambda pfx, r: applyMarkerIfPrefix(
            "XFAIL", pfx, r)
        mako_opts["unstableIfPrefix"] = lambda pfx, r: applyMarkerIfPrefix(
            "UNSTABLE", pfx, r)

        template = Template(mako_tmpl_text)
        new_text = template.render(**mako_opts)
        if isinstance(new_text, bytes):
            new_text = new_text.decode()
        lines = new_text.splitlines()

        # Strip extra whitespace characters from test names before returning them
        lines = [line.strip() for line in lines]
    except Exception:
        raise RuntimeError("Mako preprocessing of file {} failed: {}".format(
            test_list, traceback.format_exc()))

    return lines
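

# Example of how a Mako-templated test list line renders (illustrative;
# ``is_linux`` stands in for a hypothetical key passed via mako_opts):
#
#   template line: examples/test_gpt.py::test_gpt ${skipIf(not is_linux, "Windows unsupported")}
#   if the condition is True:  examples/test_gpt.py::test_gpt SKIP (Windows unsupported)
#   if the condition is False: examples/test_gpt.py::test_gpt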


def parse_test_list_lines(test_list,
                          lines,
                          test_prefix,
                          convert_unittest=True):
    """Parse the lines of a test list. Test names returned contain all values
    within square brackets; individual test id values are not processed further.

    Args:
        test_list: Name of the test list, used for debugging.
        lines: The lines of the test list file.
        test_prefix: The value of the --test-prefix option, or None if this
            option isn't set.
        convert_unittest: If True, rewrite "unittest/..." entries into the
            corresponding test_unittests.py::test_unittests_v2[...] test name.

    Returns:
        A tuple (test_names, test_name_to_marker_dict).
        test_names: List of test names parsed from the test list file,
            ordered by their appearance in the list.
        test_name_to_marker_dict: Dictionary mapping each test name to a tuple
            (marker, reason, timeout).
    """

    def parse_test_name(s):
        if s.startswith("full:"):
            # Remove the literal "full:" prefix. Note that str.lstrip() would
            # be wrong here, since it strips *characters*, not a prefix.
            s = s[len("full:"):]
            if test_prefix:
                # Check for SM version pattern (e.g., sm90, sm89, sm100)
                sm_match = re.match(r'sm(\d+)/', s)
                if sm_match:
                    sm_version = int(sm_match.group(1))
                    # Get current SM version
                    try:
                        from .conftest import get_sm_version
                        current_sm = get_sm_version()
                        # If SM versions match, replace with test_prefix
                        if sm_version == current_sm:
                            s = s.replace(f'sm{sm_version}',
                                          test_prefix.split('-')[0])
                    except Exception:
                        # If we can't get the SM version, skip SM-based filtering
                        pass
                elif test_prefix.split('-')[0] in s:
                    s = s.replace(test_prefix.split('-')[0], test_prefix)
            return s
        elif test_prefix:
            return "/".join([test_prefix, s])
        else:
            return s

    def parse_test_line(enumerated_line):
        lineno, line = enumerated_line
        lineno += 1
        # Strip comments and whitespace
        line = line.partition("#")[0].strip()
        if len(line) == 0:
            return (None, None, None, None)

        # test_name [MARKER] [REASON]
        test_name = line
        marker = None
        reason = None
        timeout = None
        for tmp_marker in kVALID_TEST_LIST_MARKERS:
            if f" {tmp_marker}" in line:
                test_name, marker, reason_raw = line.partition(f" {tmp_marker}")
                test_name = test_name.strip()
                marker = marker.strip()
                if marker == "TIMEOUT":
                    # Extract the timeout value (in minutes) from parentheses
                    timeout = strip_parens(reason_raw.strip())
                    print_info(f"Timeout setting for {test_name}: {timeout}")
                    if not timeout or not timeout.isdigit():
                        raise ValueError(
                            f'{test_list}:{lineno}: Invalid syntax for TIMEOUT value: "{reason_raw}". '
                            "Expected a numeric value in parentheses.")
                    # Convert minutes to seconds for pytest.mark.timeout.
                    timeout = int(timeout) * 60
                elif len(reason_raw) > 0:
                    reason = strip_parens(reason_raw.strip())
                    if not reason:
                        raise ValueError(
                            ('{}:{}: Invalid syntax for reason: "{}". '
                             "Did you forget to add parentheses?").format(
                                 test_list, lineno, reason_raw))
                break

        if convert_unittest:
            # extract the full:XXX/ prefix, if present
            full_prefix = ""
            match = re.match(r'(full:.*?/)(.+)', test_name)
            if match:
                full_prefix = match.group(1)
                test_name = match.group(2)

            # convert a unittest path to the actual test name
            if test_name.startswith("unittest/"):
                test_name = f"test_unittests.py::test_unittests_v2[{test_name}]"

            # combine back
            test_name = full_prefix + test_name

        test_name = parse_test_name(test_name)

        return (test_name, marker, reason, timeout)

    parsed_test_list = map(parse_test_line, enumerate(lines))
    parsed_test_list = list(
        filter(lambda x: x[0] is not None, parsed_test_list))
    test_names = [x[0] for x in parsed_test_list]
    test_name_to_marker_dict = {
        x[0]: (x[1], x[2], x[3])
        for x in parsed_test_list
    }

    return (test_names, test_name_to_marker_dict)
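

# Example of the accepted test list syntax (illustrative names, with
# test_prefix=None). Each name maps to a (marker, reason, timeout) tuple:
#
#   examples/test_a.py::test_x                 -> (None, None, None)
#   examples/test_b.py::test_y SKIP (flaky)    -> ("SKIP", "flaky", None)
#   examples/test_c.py::test_z TIMEOUT (90)    -> ("TIMEOUT", None, 5400)
#
# Note that TIMEOUT values are written in minutes but stored in seconds.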


def parse_test_list(test_list, test_prefix):
    with open(test_list, "r") as f:
        lines = f.readlines()

    lines = preprocess_test_list_lines(test_list, lines)
    return parse_test_list_lines(test_list, lines, test_prefix)


def split_test_name_into_components(test_name):
    """
    Splits a fully-qualified test name (with file name) into components.

    Args:
        test_name (str): A test name (with or without parameters).

    Returns:
        Tuple[str, str, str]: A tuple containing:
            - The name of the file containing the test ("" if not found)
            - The name of the test function
            - The parameters of the test as a single string ("" if none)
    """
    test_file, sep, test_basename = test_name.partition("::")

    if not sep:
        test_basename = test_file
        test_file = ""

    test_function, _, params = test_basename.partition("[")

    if params:
        params = params.rstrip("]")

    return test_file, test_function, params
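

# Illustrative behavior:
#
#   >>> split_test_name_into_components("test_foo.py::test_add[int8-cuda]")
#   ('test_foo.py', 'test_add', 'int8-cuda')
#   >>> split_test_name_into_components("test_add")
#   ('', 'test_add', '')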


def join_test_name_components(test_file: str, test_function: str,
                              test_params: str):
    """Performs the inverse of split_test_name_into_components()."""
    name = ""
    if test_file:
        name += f"{test_file}::"
    if test_function:
        name += test_function
    if test_params:
        name += f"[{test_params}]"
    return name
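

# The two helpers round-trip (illustrative):
#
#   >>> join_test_name_components('test_foo.py', 'test_add', 'int8-cuda')
#   'test_foo.py::test_add[int8-cuda]'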


# Global cache for storing previously computed edit distances
_edit_distance_cache: "dict[tuple[str, str], int]" = {}


def edit_distance(s0: str, s1: str):
    """Compute the Levenshtein edit distance
    (https://en.wikipedia.org/wiki/Levenshtein_distance) between two strings."""

    if s0 == s1:
        return 0

    # Ensure that s0 <= s1, since edit_distance(s0, s1) == edit_distance(s1, s0)
    if s1 < s0:
        s0, s1 = s1, s0

    m, n = len(s0), len(s1)

    # If either string is empty, the distance is the length of the other.
    if not s1:
        return m
    if not s0:
        return n

    if (s0, s1) not in _edit_distance_cache:
        # Classic two-row Wagner-Fischer dynamic program.
        prev_ed = array.array("l", range(n + 1))
        cur_ed = array.array("l", [0] * (n + 1))

        for i in range(m):
            cur_ed[0] = i + 1
            for j in range(n):
                del_cost = prev_ed[j + 1] + 1
                ins_cost = cur_ed[j] + 1
                sub_cost = prev_ed[j] if s0[i] == s1[j] else prev_ed[j] + 1
                cur_ed[j + 1] = min(del_cost, ins_cost, sub_cost)
            cur_ed, prev_ed = prev_ed, cur_ed

        _edit_distance_cache[(s0, s1)] = prev_ed[n]

    return _edit_distance_cache[(s0, s1)]
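

# Classic Levenshtein sanity check (illustrative): "kitten" -> "sitting" takes
# two substitutions (k->s, e->i) plus one insertion (the trailing g):
#
#   >>> edit_distance("kitten", "sitting")
#   3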


def strip_prefix(test_prefix, test_name):
    """Strips the test prefix (as provided by the --test-prefix option) from
    the given test name, if applicable."""

    if test_prefix is None:
        return test_name

    if test_name.startswith(test_prefix):
        _, _, test_name = test_name.partition(test_prefix)

    # Strip away any leading slashes left over after removing the test prefix
    test_name = test_name.lstrip("/")

    return test_name
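

# Illustrative behavior ("A100" stands in for a hypothetical --test-prefix):
#
#   >>> strip_prefix("A100", "A100/examples/test_gpt.py::test_gpt")
#   'examples/test_gpt.py::test_gpt'
#   >>> strip_prefix(None, "examples/test_gpt.py::test_gpt")
#   'examples/test_gpt.py::test_gpt'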


def get_test_name_corrections(
        test_names: Sequence[str],
        items: Sequence[Any],
        test_prefix: Optional[str] = None) -> Dict[str, Optional[str]]:
    """Provided for backwards compatibility with tests which rely on this
    interface. Use get_test_name_corrections_v2 instead."""

    all_valid_tests = set(
        strip_prefix(test_prefix, test.nodeid) for test in items)
    no_pfx_test_names = set(
        strip_prefix(test_prefix, test_name) for test_name in test_names)

    corr = get_test_name_corrections_v2(no_pfx_test_names, all_valid_tests,
                                        TestCorrectionMode.EXACT_MATCH)

    ret = {}
    for test_name, corrections in corr.items():
        if corrections:
            ret[test_name] = corrections[0]
        else:
            ret[test_name] = None

    return ret


class TestCorrectionMode:
    """Enum-like namespace describing the test correction modes supported by
    get_test_name_corrections_v2."""

    EXACT_MATCH = 0  # Filters specify exact matches
    SUBSTRING = 1  # Filters specify substring matches
    REGEX = 2  # Filters specify regex matches


def get_test_name_corrections_v2(
        test_filters: "set[str]", valid_test_names: "set[str]",
        mode: TestCorrectionMode) -> "dict[str, list[str]]":
    """
    Given a set of user-provided test filters and a set of valid test names,
    suggests corrections for any invalid filters.

    Expects test_filters and valid_test_names to be fully-qualified, e.g.
    `test_foo.py::test_foo[bar-baz]`.

    Args:
        test_filters (set[str]): Test name filters that may or may not be valid.
        valid_test_names (set[str]): Known valid test names.
        mode (TestCorrectionMode): How the test name filters are used to filter
            tests. Used to determine which filters are invalid.

    Returns:
        dict[str, list[str]]: A mapping of invalid test names to a list of
        candidate corrections.
    """

    print_info("Checking for invalid test name filters and corrections")
    corrections_start_time = time.time()

    invalid_filters = set()

    compiled_regex_map = {}

    if mode == TestCorrectionMode.REGEX:
        # Pre-compile the regular expressions to save time.
        compiled_regex_map = {r: re.compile(r) for r in test_filters}

    if mode == TestCorrectionMode.EXACT_MATCH:
        invalid_filters = test_filters - valid_test_names
    elif mode in (TestCorrectionMode.SUBSTRING, TestCorrectionMode.REGEX):
        if mode == TestCorrectionMode.SUBSTRING:
            matches_func = lambda f, t: f in t
        else:
            matches_func = lambda f, t: compiled_regex_map[f].search(
                t) is not None
        # A filter is invalid only if it matches no valid test at all.
        invalid_filters = set(test_filters)
        for v in valid_test_names:
            found_filters = set()
            for f in invalid_filters:
                if matches_func(f, v):
                    found_filters.add(f)
            invalid_filters -= found_filters
    else:
        raise NotImplementedError()

    # Bail out of corrections if there are too many invalid filters, as it
    # becomes prohibitively slow to check every invalid filter.
    MAX_INVALID_FILTER_THRESHOLD = 50

    # Maximum number of candidates to check across all filters. If corrections
    # are computed, the number of candidates allowed per filter is then in the
    # range [MAX_NB_CANDIDATES / MAX_INVALID_FILTER_THRESHOLD, MAX_NB_CANDIDATES].
    #
    # Too many candidates can significantly impact the algorithm's runtime, so
    # keep this number constrained.
    MAX_NB_CANDIDATES = 10000

    nb_invalid_filters = len(invalid_filters)
    nb_candidates_per_filter = (MAX_NB_CANDIDATES //
                                nb_invalid_filters if nb_invalid_filters else 0)

    def build_valid_test_buckets() -> "dict[str, dict[str, set[str]]]":
        """Bucket valid tests by file and test name to save time when looking
        for matches."""
        valid_test_buckets = defaultdict(lambda: defaultdict(set))
        for valid_test in valid_test_names:
            (
                valid_test_file,
                valid_test_name,
                valid_test_params,
            ) = split_test_name_into_components(valid_test)
            valid_test_buckets[valid_test_file][valid_test_name].add(
                valid_test_params)
        return valid_test_buckets  # pyright: ignore

    def suggest_correction(valid_test_buckets, test):
        """Attempt to find corrections for a given invalid test name.

        Works well for exact matches and probably substrings, but likely won't
        work well for regex patterns.
        """
        test_file, test_name, test_params = split_test_name_into_components(
            test)

        # Maximum number of corrections to suggest per invalid filter
        MAX_NB_CORRECTIONS = 3

        # Maximum edit distance for a test name for which we consider this a
        # possible match
        MAX_TEST_NAME_EDIT_DISTANCE = 6

        # Maximum edit distance for a test parameter for which we consider
        # this a possible match. Be more generous here than for test names,
        # since it's easier to get the parameter naming wrong.
        MAX_TEST_PARAMETER_EDIT_DISTANCE = 12

        candidates: "list[tuple[tuple[str, int], str]]" = []

        # Only consider candidates in the same file as the invalid test filter,
        # as it is prohibitively expensive to check every test in every file.
        tests_in_same_file = valid_test_buckets[test_file]
        # First add tests with identical file and test name (if any) to the
        # candidates list. Arbitrarily prune the list if there are too many
        # candidates.
        params_with_same_name = tests_in_same_file[test_name]
        candidates.extend(
            ((test_name, 0), p)
            for p in sorted(params_with_same_name)[:nb_candidates_per_filter])
        if len(candidates) < min(nb_candidates_per_filter, MAX_NB_CORRECTIONS):
            # Next, include tests with identical file, but differing test name.
            # Prioritize test names with smaller edit distance.
            # We can skip this step if we already have at least
            # MAX_NB_CORRECTIONS candidates, as the only candidates selected
            # will have identical test names (see the pigeonhole discussion
            # below).
            ctn = sorted(
                ((c, edit_distance(test_name, c))
                 for c in tests_in_same_file if c != test_name),
                key=lambda x: x[1],
            )
            for cname, cedit_distance in ctn:
                if cedit_distance > MAX_TEST_NAME_EDIT_DISTANCE:
                    continue
                cparams_bucket = tests_in_same_file[cname]
                for cparam in cparams_bucket:
                    if len(candidates) >= nb_candidates_per_filter:
                        break
                    candidates.append(((cname, cedit_distance), cparam))

        # Prune candidates based on the pigeonhole principle. Let
        # N = MAX_NB_CORRECTIONS. If we already have N or more candidates in
        # the list (sorted by test name edit distance) then any candidate with
        # edit distance worse than the Nth candidate will not be selected and
        # hence can be pruned from the list.
        #
        # For example, given the below list of scores (N=5):
        #
        #   [ 2 3 4 4 4 4 4 4 5 5 6 6 ]
        #
        # The 5th candidate has score 4, so we need to consider all candidates
        # with score 4 or below. No candidate with score 5 or 6 will be
        # selected (by the pigeonhole principle) and can be ignored.
        if len(candidates) >= MAX_NB_CORRECTIONS:
            threshold_candidate = candidates[MAX_NB_CORRECTIONS - 1]
            threshold_name_edit_distance = threshold_candidate[0][1]
            candidates = [
                c for c in candidates if c[0][1] <= threshold_name_edit_distance
            ]

        # For the remaining candidates, compute parameter edit distance and
        # the final match score.
        possible_matches = []
        for (cname, cedit_distance), cparam in candidates:
            param_edit_distance = edit_distance(cparam, test_params)
            if param_edit_distance > MAX_TEST_PARAMETER_EDIT_DISTANCE:
                continue
            match_score = (cedit_distance, param_edit_distance)
            ctest_name = join_test_name_components(test_file, cname, cparam)
            possible_matches.append((ctest_name, match_score))

        return list(m[0] for m in sorted(
            possible_matches, key=lambda m: m[1]))[:MAX_NB_CORRECTIONS]

    print_info(
        f"Computing corrections for {nb_invalid_filters} invalid filters")

    if nb_invalid_filters > MAX_INVALID_FILTER_THRESHOLD:
        print_info(
            f"Bailing out of corrections as there are more than {MAX_INVALID_FILTER_THRESHOLD} invalid filters."
        )
        # Just return the invalid filters, with no suggested corrections.
        ret = {f: [] for f in invalid_filters}
    elif nb_invalid_filters > 0:
        valid_test_buckets = build_valid_test_buckets()
        ret = {
            f: suggest_correction(valid_test_buckets, f)
            for f in invalid_filters
        }
    else:
        ret = {}

    corrections_dt = time.time() - corrections_start_time
    print_info(
        f"Finished checking for corrections in {corrections_dt:.3f} seconds.")
    return ret
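

# Sketch of a typical call (illustrative names; progress lines are also
# printed as a side effect):
#
#   get_test_name_corrections_v2(
#       {"test_foo.py::test_add[int8]"},
#       {"test_foo.py::test_add[int8-cuda]", "test_foo.py::test_add[fp16-cuda]"},
#       TestCorrectionMode.EXACT_MATCH)
#   -> {"test_foo.py::test_add[int8]": ["test_foo.py::test_add[int8-cuda]",
#                                       "test_foo.py::test_add[fp16-cuda]"]}
#
# The candidate whose parameters are closest by edit distance is listed first.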


def apply_test_list_corrections(test_list,
                                corrections,
                                items,
                                test_prefix=None):
    """
    Attempt to correct invalid test names in a test list.

    Args:
        test_list (str): The path to a test list.
        corrections (Dict[str, List[str]]): A mapping of invalid test names to
            candidate corrections, as returned by get_test_name_corrections_v2().
        items: The list of collected pytest items; cleared before exiting.
        test_prefix (Optional[str]): The value of the --test-prefix option, or
            None if this option isn't set.
    """
    print_info("Applying corrections to: {}".format(test_list))
    with open(test_list, "r") as f:
        contents = f.read()

    for invalid_test, correction_list in corrections.items():
        if correction_list:
            correction = correction_list[0]
            if test_prefix:
                # Strip the test prefix from the correction and invalid test
                prefix = f"{test_prefix}/"
                if correction.startswith(prefix):
                    correction = correction[len(prefix):]
                if invalid_test.startswith(prefix):
                    invalid_test = invalid_test[len(prefix):]
            print_info(f"Correcting {invalid_test} to {correction}")
            contents = contents.replace(invalid_test, correction)
        else:
            print_info(
                "Could not automatically correct: {}".format(invalid_test))

    # We don't want to correct the test list automatically in L0 tests
    #with open(test_list, "w") as f:
    #    f.write(contents)

    # Clear the items list to prevent pytest from listing collected tests
    items.clear()

    raise pytest.UsageError(
        "Exiting early since --apply-test-list-correction was specified.")


def generate_correction_error_message(corrections: Dict[str, List[str]],
                                      prefix: Optional[str]) -> Optional[str]:
    """Returns a string error message reporting any corrections, or None if
    there are no corrections to report."""
    if not corrections:
        return None
    ret = dedent("""
        !!!!! INVALID TEST NAME FILTERS !!!!!

        Some filter strings do not correspond to any known test, and will be ignored.
        Please correct the test filters to use valid test names.

        Hint: If you encounter this in automation, this is likely due to a bad test name in test-db.
        Please correct the test name in the test-db configuration YAML file.

        Below are the invalid filter strings, as well as any suggested corrections that were identified.

        """)

    for filter_str in sorted(corrections):
        corr_list = corrections[filter_str]

        ret += f"- {strip_prefix(prefix, filter_str)}\n"
        for correction in corr_list:
            ret += f"  - correction: {strip_prefix(prefix, correction)}\n"

        ret += "\n"

    return ret


def handle_corrections(corrections, test_prefix):
    corr_err_msg = generate_correction_error_message(corrections,
                                                     prefix=test_prefix)
    if corr_err_msg is None:
        return

    for line in corr_err_msg.splitlines():
        print(line)


def record_invalid_tests(output_file, corrections):
    """Append the invalid test names from corrections to output_file as a
    single JSON line."""
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "a") as f:
        invalid_tests = {"invalid": list(corrections.keys())}
        json.dump(invalid_tests, f)
        f.write("\n")
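

# Each call appends a single JSON line to the output file, e.g. (illustrative):
#
#   {"invalid": ["test_foo.py::test_add[int8]"]}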


def parse_and_validate_test_list(
    test_list,
    config,
    items,
    check_for_corrections,
):
    test_prefix = config.getoption("--test-prefix")
    test_names, test_name_to_marker_dict = parse_test_list(
        test_list, test_prefix)

    if check_for_corrections:
        corrections = get_test_name_corrections_v2(
            set(test_names),
            set(it.nodeid for it in items),
            TestCorrectionMode.EXACT_MATCH,
        )

        apply_test_list_correction = config.getoption(
            "--apply-test-list-correction")
        if apply_test_list_correction and corrections:
            apply_test_list_corrections(test_list, corrections, items,
                                        test_prefix)

        output_dir = config.getoption("--output-dir")
        if kRECORD_INVALID_TESTS and corrections:
            record_invalid_tests(
                os.path.join(output_dir, "invalid_tests.json"),
                corrections,
            )

        handle_corrections(corrections, test_prefix)

    return test_names, test_name_to_marker_dict


def modify_by_test_list(test_list, items, config):
    """Filter out tests based on the test names specified by the given
    test_list. Also ensure the test order matches the order specified by the
    test_list, and add any custom markers specified by the test_list."""
    all_test_names = []
    full_test_name_to_marker_dict = {}

    test_names, test_name_to_marker_dict = parse_and_validate_test_list(
        test_list,
        config,
        items,
        check_for_corrections=True,
    )
    all_test_names.extend(test_names)
    full_test_name_to_marker_dict.update(test_name_to_marker_dict)

    found_items = {}
    deselected = []

    # Figure out which items have names specified by the filter
    for item in items:
        if item.nodeid in full_test_name_to_marker_dict:
            found_items[item.nodeid] = item
        else:
            deselected.append(item)

    # Construct a list of tests based on the ordering given in the file
    selected = []
    for name in all_test_names:
        if name in found_items:
            item = found_items[name]
            selected.append(item)
            # Also update the item based on the marker specified in the file
            marker, reason, timeout = full_test_name_to_marker_dict[name]
            if marker:
                if marker == "TIMEOUT" and timeout:
                    item.add_marker(pytest.mark.timeout(timeout))
                else:
                    mark_func = getattr(pytest.mark, marker.lower())
                    mark = mark_func(reason=reason)
                    item.add_marker(mark)

    if deselected:
        config.hook.pytest_deselected(items=deselected)
    items[:] = selected


def apply_waives(waives_file, items, config):
    """Apply waives based on the waive state specified by the given waives_file."""

    # Corrections don't make sense for the waives file, as it specifies global
    # negative filters that may or may not be applicable to the current
    # platform (i.e., the test names being waived may not be generated on the
    # current platform).
    ret = parse_and_validate_test_list(
        waives_file,
        config,
        items,
        check_for_corrections=False,
    )
    if not ret:
        return
    _, test_name_to_marker_dict = ret

    # For each item in the list, apply waives if a waive entry exists
    for item in items:
        if item.nodeid in test_name_to_marker_dict:
            marker, reason, _ = test_name_to_marker_dict[item.nodeid]
            if marker:
                mark_func = getattr(pytest.mark, marker.lower())
                mark = mark_func(reason=reason)
                item.add_marker(mark)


def uniquify_test_items(items):
    """Remove duplicate test items (by nodeid), keeping the first occurrence."""
    nodeid_set = set()
    duplication_set = set()
    items_unique = []
    for item in items:
        if item.nodeid not in nodeid_set:
            items_unique.append(item)
            nodeid_set.add(item.nodeid)
        else:
            duplication_set.add(item.nodeid)

    if duplication_set:
        print_warning("Test item duplication: " +
                      ",".join(list(duplication_set)))

    items[:] = items_unique
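

# Illustrative behavior: given items with nodeids ["a", "b", "a"], the list is
# rewritten in place to the items for ["a", "b"] (the first occurrence wins)
# and a "Test item duplication: a" warning is printed.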