mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
140 lines
5.3 KiB
Python
140 lines
5.3 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import glob
|
|
import os
|
|
|
|
import yaml
|
|
|
|
from .perf.test_perf import PerfTestConfig
|
|
from .test_list_parser import get_test_name_corrections, parse_test_list
|
|
|
|
|
|
# Convert a mako test list or a YAML test-db file into test names, returned as separate non-perf and perf lists.
|
|
def parse_test_list_or_db(test_list, trt_config):
    """Collect the test names from a mako test list or a YAML test-db file.

    Args:
        test_list: Path of the test list. A path ending in "txt" is parsed as
            a mako test list through `parse_test_list`; a path ending in "yml"
            is read as a YAML test-db.
        trt_config: Mapping providing "mako_opts" and "test_prefix". Both keys
            are only read when `test_list` is a mako test list.

    Returns:
        A `(non_perf_test_names, perf_test_names)` tuple of lists. A name goes
        to the perf list when it contains "perf/test_perf.py::test_perf";
        the original order is kept inside each list.

    Raises:
        ValueError: If `test_list` ends in neither "txt" nor "yml".
    """
    perf_marker = "perf/test_perf.py::test_perf"

    # For mako test list.
    if test_list.endswith("txt"):
        # Mako options that may not always be defined, but are required in test lists.
        # Defined first so that values from `trt_config` can overwrite values.
        mako_opts = {
            "level": 4,
            "idx": 0,
            "priority": 0,
            "host_mem_available_mib": 1e8,  # Set to very large to cover all tests.
            "gpu_memory": 1e6,  # Set to very large to cover all tests.
            "system_gpu_count": 128,  # Set to very large to cover all tests.
        }
        mako_opts.update(trt_config["mako_opts"])

        test_names, _ = parse_test_list(test_list,
                                        print_mako=False,
                                        no_mako=False,
                                        mako_opts=mako_opts,
                                        test_prefix=trt_config["test_prefix"])

    # For yaml-based test db.
    elif test_list.endswith("yml"):
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # That is acceptable only while test-db files are trusted repository
        # content; consider yaml.safe_load if no custom tags are needed.
        with open(test_list, encoding="utf-8") as f:
            test_db_data = yaml.load(f, Loader=yaml.Loader)

        # The top-level key of a test-db is the file name without its ".yml"
        # extension. Strip only the trailing extension so that a ".yml"
        # appearing elsewhere in the name is left intact.
        context_name = os.path.basename(test_list)
        if context_name.endswith(".yml"):
            context_name = context_name[:-len(".yml")]

        test_names = []
        for condition in test_db_data[context_name]:
            test_names.extend(condition["tests"])

    else:
        raise ValueError(f"Unexpected test list name: {test_list}")

    # Perf tests are generated based on test lists dynamically, so separate
    # them out from normal tests (single pass, order preserved).
    non_perf_test_names = []
    perf_test_names = []
    for name in test_names:
        if perf_marker in name:
            perf_test_names.append(name)
        else:
            non_perf_test_names.append(name)

    return non_perf_test_names, perf_test_names
|
|
|
|
|
|
# Validate perf test names, which are generated dynamically based on test lists.
|
|
def validate_perf_tests(perf_test_names) -> bool:
    """Check that each perf test name carries a loadable parameter label.

    For every name, the text after the last "[" (minus its final character)
    is passed to `load_from_str` on a fresh `PerfTestConfig`. Any exception
    raised while extracting or loading the label is printed and marks the
    whole run as failed; the remaining names are still checked.

    Args:
        perf_test_names: Iterable of perf test names to check.

    Returns:
        True when no name raised an error, False otherwise.
    """
    all_valid = True

    for perf_name in perf_test_names:
        perf_config = PerfTestConfig()
        try:
            # Keep what follows the last "[" and drop the final character,
            # which is the closing "]" of a well-formed name.
            labels = perf_name.rpartition("[")[2][:-1]
            # Loading succeeds only if the label is a valid perf test config.
            perf_config.load_from_str(labels)
        except Exception as err:
            print(f"Perf test name {perf_name} is invalid! Error: {err}")
            all_valid = False

    return all_valid
|
|
|
|
|
|
def test_list_validation(test_root, all_pytest_items, trt_config,
|
|
is_trt_environment):
|
|
|
|
# Don't run test list validation in TRT environment because TRT uses
|
|
# YAML-based test-db for test lists.
|
|
if is_trt_environment:
|
|
print(
|
|
"Skipped TRT-LLM test list validation because the pipeline is running in TRT environment."
|
|
)
|
|
return
|
|
|
|
# Glob all the test list files.
|
|
test_list_path = os.path.join(test_root, "test_lists", "*", "*.txt")
|
|
all_test_lists = glob.glob(test_list_path)
|
|
assert len(all_test_lists
|
|
) > 0, f"Cannot find any test lists with path {test_list_path}!"
|
|
|
|
# Glob all the test db files.
|
|
test_db_path = os.path.join(test_root, "test_lists", "*", "*.yml")
|
|
all_test_dbs = glob.glob(test_db_path)
|
|
assert len(all_test_dbs
|
|
) > 0, f"Cannot find any test lists with path {test_db_path}!"
|
|
|
|
# Go through test lists to get test name corrections.
|
|
passed = True
|
|
for test_list in (all_test_lists + all_test_dbs):
|
|
print(f"Validating test list: {test_list} ...")
|
|
non_perf_test_names, perf_test_names = parse_test_list_or_db(
|
|
test_list, trt_config)
|
|
|
|
if not validate_perf_tests(perf_test_names):
|
|
passed = False
|
|
|
|
corrections = get_test_name_corrections(non_perf_test_names,
|
|
all_pytest_items,
|
|
trt_config["test_prefix"])
|
|
if corrections:
|
|
err_msg = "{} errors found in test list: {}".format(
|
|
len(corrections), test_list)
|
|
print(err_msg)
|
|
print("Invalid tests:")
|
|
for name, correct in corrections.items():
|
|
if correct is not None:
|
|
print("\tSUGGESTED CORRECTION: {} -> {}".format(
|
|
name, correct))
|
|
else:
|
|
print("\tCORRECTION UNKNOWN: {}".format(name))
|
|
passed = False
|
|
|
|
assert passed, "Some test lists contain invalid test names!"
|