TensorRT-LLMs/tests/integration/defs/test_list_validation.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
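"""Validation of the integration test lists.

Ensures that every test name in the mako-based .txt test lists and the
YAML-based test-db files under <test_root>/test_lists resolves to a collected
pytest item, and that every dynamically generated perf test name encodes a
loadable PerfTestConfig.
"""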
import glob
import os

import yaml

from .perf.test_perf import PerfTestConfig
from .test_list_parser import get_test_name_corrections, parse_test_list


# A function to convert the mako test list or test-db to a list of test names.
def parse_test_list_or_db(test_list, trt_config):
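    """Convert a mako test list (.txt) or a YAML test-db (.yml) into test names.

    Returns a tuple of (non_perf_test_names, perf_test_names). As an
    illustration (the file and test names below are hypothetical), a test-db
    file `l0_example.yml` is expected to hold a top-level key matching its
    basename, where each entry carries a "tests" list; any other per-entry
    keys are ignored by this parser:

        l0_example:
        - tests:
          - unittest/test_foo.py::test_bar
    """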
    # For mako test list.
    if test_list.endswith("txt"):
        # Mako options that may not always be defined, but are required in
        # test lists. Defined first so that values from `trt_config` can
        # overwrite them.
        mako_opts = {
            "level": 4,
            "idx": 0,
            "priority": 0,
            "host_mem_available_mib": 1e8,  # Very large to cover all tests.
            "gpu_memory": 1e6,  # Very large to cover all tests.
            "system_gpu_count": 128,  # Very large to cover all tests.
        }
        mako_opts.update(trt_config["mako_opts"])
        test_names, _ = parse_test_list(test_list,
                                        print_mako=False,
                                        no_mako=False,
                                        mako_opts=mako_opts,
                                        test_prefix=trt_config["test_prefix"])
    # For yaml-based test db.
    elif test_list.endswith("yml"):
        with open(test_list) as f:
            test_db_data = yaml.load(f, Loader=yaml.Loader)
        test_names = []
        context_name = os.path.basename(test_list).replace(".yml", "")
        for condition in test_db_data[context_name]:
            test_names.extend(condition["tests"])
    else:
        raise ValueError(f"Unexpected test list name: {test_list}")

    # Perf tests are generated dynamically from the test lists, so separate
    # them from the normal tests.
    non_perf_test_names = [
        x for x in test_names if "perf/test_perf.py::test_perf" not in x
    ]
    perf_test_names = [
        x for x in test_names if "perf/test_perf.py::test_perf" in x
    ]

    return non_perf_test_names, perf_test_names


# Validate perf test names, which are generated dynamically from the test lists.
def validate_perf_tests(perf_test_names) -> bool:
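    """Return True if every perf test name encodes a loadable PerfTestConfig.

    Perf test names embed their parameters in the bracketed pytest id, e.g.
    "perf/test_perf.py::test_perf[gpt_350m-float16]" (hypothetical labels);
    the bracketed part must parse via PerfTestConfig.load_from_str().
    """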
    passed = True
    for test_name in perf_test_names:
        config = PerfTestConfig()
        try:
            # Get only the "[...]" part of the test name.
            test_param_labels = test_name.split("[")[-1][:-1]
            # Check that the perf test config can be loaded successfully.
            config.load_from_str(test_param_labels)
        except Exception as e:
            print(f"Perf test name {test_name} is invalid! Error: {e}")
            passed = False
    return passed


def test_list_validation(test_root, all_pytest_items, trt_config,
                         is_trt_environment):
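    """Validate all .txt test lists and .yml test-dbs under <test_root>/test_lists.

    Every non-perf test name must match a collected pytest item (a suggested
    correction is printed when a close match exists), and every perf test name
    must pass validate_perf_tests(). Fails the run if any list is invalid.
    """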
    # Don't run test list validation in the TRT environment because TRT uses
    # a YAML-based test-db for its test lists.
    if is_trt_environment:
        print(
            "Skipped TRT-LLM test list validation because the pipeline is running in TRT environment."
        )
        return
    # Glob all the test list files.
    test_list_path = os.path.join(test_root, "test_lists", "*", "*.txt")
    all_test_lists = glob.glob(test_list_path)
    assert len(all_test_lists) > 0, \
        f"Cannot find any test lists with path {test_list_path}!"

    # Glob all the test db files.
    test_db_path = os.path.join(test_root, "test_lists", "*", "*.yml")
    all_test_dbs = glob.glob(test_db_path)
    assert len(all_test_dbs) > 0, \
        f"Cannot find any test dbs with path {test_db_path}!"
    # Go through the test lists to get test name corrections.
    passed = True
    for test_list in (all_test_lists + all_test_dbs):
        print(f"Validating test list: {test_list} ...")
        non_perf_test_names, perf_test_names = parse_test_list_or_db(
            test_list, trt_config)
        if not validate_perf_tests(perf_test_names):
            passed = False
        corrections = get_test_name_corrections(non_perf_test_names,
                                                all_pytest_items,
                                                trt_config["test_prefix"])
        if corrections:
            err_msg = "{} errors found in test list: {}".format(
                len(corrections), test_list)
            print(err_msg)
            print("Invalid tests:")
            for name, correct in corrections.items():
                if correct is not None:
                    print("\tSUGGESTED CORRECTION: {} -> {}".format(
                        name, correct))
                else:
                    print("\tCORRECTION UNKNOWN: {}".format(name))
            passed = False

    assert passed, "Some test lists contain invalid test names!"