mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com> Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
185 lines
6.5 KiB
Python
185 lines
6.5 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import warnings
|
|
from subprocess import CalledProcessError
|
|
|
|
from defs.conftest import tests_path
|
|
|
|
|
|
def merge_report(base_file, extra_file, output_file, is_retry=False):
    """Merge the JUnit XML report ``extra_file`` into ``base_file``.

    The merged report is written to ``output_file``; afterwards ``extra_file``
    is deleted (unless it is the same path as ``output_file``). ``base_file``
    on disk is only changed if it is also ``output_file``.

    Args:
        base_file: Path of the primary report (here: the parallel-phase run).
        extra_file: Path of the report to merge in. If it does not exist the
            function returns without writing anything.
        output_file: Path the merged report is written to.
        is_retry: If True, ``extra_file`` is a re-run of the failed cases of
            ``base_file`` (e.g. via ``pytest --lf``). Cases present in both
            reports are taken from ``extra_file``, and the ``failures`` and
            ``errors`` counters are taken from ``extra_file``. If False, the
            test cases are concatenated and the ``errors``, ``failures``,
            ``skipped`` and ``tests`` counters are summed. ``time`` is summed
            in both modes.
    """
    import xml.etree.ElementTree as ElementTree

    base = ElementTree.parse(base_file)
    try:
        extra = ElementTree.parse(extra_file)
    except FileNotFoundError:
        # Nothing to merge: leave base_file as-is and produce no output.
        return

    def find_suite(tree):
        # Reports are normally <testsuites><testsuite>...; also accept a
        # report whose root element is the <testsuite> itself.
        root = tree.getroot()
        return root if root.tag == 'testsuite' else root.find('testsuite')

    base_suite = find_suite(base)
    extra_suite = find_suite(extra)

    def merge_attr(name, type_=int):
        # Sum a numeric <testsuite> counter; a missing counter counts as 0.
        total = (type_(base_suite.attrib.get(name, '0')) +
                 type_(extra_suite.attrib.get(name, '0')))
        base_suite.attrib[name] = str(total)

    def is_identifiable(element):
        # pytest may generate testcase nodes without a classname or name
        # attribute when a worker crashed catastrophically. Such nodes cannot
        # be matched between reports and are not meaningful, so drop them.
        return 'classname' in element.attrib and 'name' in element.attrib

    def case_key(element):
        return element.attrib['classname'], element.attrib['name']

    merge_attr("time", type_=float)

    if is_retry:
        # The retry re-ran every failed/errored case of the base run, so its
        # failure and error counters are the final ones.
        for counter in ('failures', 'errors'):
            if counter in extra_suite.attrib:
                base_suite.attrib[counter] = extra_suite.attrib[counter]
        extra_suite_nodes = [
            element for element in extra_suite if is_identifiable(element)
        ]
        rerun_keys = {case_key(element) for element in extra_suite_nodes}
        # Keep base cases that were not re-run, then append the retry results.
        base_suite[:] = [
            element for element in base_suite
            if is_identifiable(element) and case_key(element) not in rerun_keys
        ] + extra_suite_nodes
    else:
        for counter in ("errors", "failures", "skipped", "tests"):
            merge_attr(counter)
        base_suite[:] = list(base_suite) + list(extra_suite)

    base.write(output_file, encoding="UTF-8", xml_declaration=True)
    # Delete the merged-in report only once the output exists, so a failed
    # write does not lose its results; never delete what we just wrote.
    if os.path.abspath(extra_file) != os.path.abspath(output_file):
        os.remove(extra_file)
|
|
|
|
|
|
def _lookup_parallel_factor(test_root, case):
    """Return the number of pytest-xdist workers configured for ``case``.

    Looks up the row of ``agg_unit_mem_df.csv`` matching the name of GPU 0 and
    ``case``. The result is an ``int`` capped at 8 and at least 1. If there is
    no usable entry, a warning is emitted and 1 (serial execution) is
    returned.
    """
    import pandas as pd
    import pynvml

    # This dataframe is not manually edited. Infra team will regularly generate
    # this dataframe based on test execution results.
    # If you need to override this policy, please use postprocess code as below.
    agg_unit_mem_df = pd.read_csv(
        f'{test_root}/integration/defs/agg_unit_mem_df.csv')

    pynvml.nvmlInit()
    try:
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        gpu_name = pynvml.nvmlDeviceGetName(gpu_handle)
    finally:
        # Release our NVML reference; only the GPU name is needed afterwards.
        pynvml.nvmlShutdown()
    # Depending on the pynvml version the name is returned as bytes or str.
    if isinstance(gpu_name, bytes):
        gpu_name = gpu_name.decode()

    agg_unit_mem_df = agg_unit_mem_df[agg_unit_mem_df['gpu'] == gpu_name]
    print(agg_unit_mem_df)

    parallel_dict = {(row['gpu'], row['unittest_case_name']):
                     row['parallel_factor']
                     for _, row in agg_unit_mem_df.iterrows()}
    print(parallel_dict)

    factor = parallel_dict.get((gpu_name, case))
    if factor is None or pd.isna(factor):
        warnings.warn(
            f'Cannot find parallel config entry for unittest {case} on "{gpu_name}". Fallback to serial test. Please add config entry to agg_unit_mem_df.csv.'
        )
        return 1
    # pandas yields numpy scalars (float64 when the column has gaps); xdist's
    # "-n" option needs a plain integer such as "4", not "4.0".
    return max(1, min(int(factor), 8))


def _unittest_case_filename(case):
    """Turn a unittest case spec into a filename fragment of at most 80 chars."""
    return case.replace('/', '-')[:80]


def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request):
    """Run the unittests selected by ``case`` in a pytest subprocess.

    ``case`` is split with ``shlex`` and appended to the pytest command line.
    The number of xdist workers comes from ``agg_unit_mem_df.csv``. The
    ``LLM_TEST_PARALLEL_OVERRIDE`` environment variable overrides it.

    With one worker, pytest runs without xdist and writes the JUnit report
    directly. Otherwise the tests run in parallel. If that phase fails, the
    failed cases are re-run serially with ``--lf``, and the retry report is
    merged into the parallel one with ``merge_report``.

    ``llm_root`` is unused here. NOTE(review): presumably ``llm_venv.run_cmd``
    runs the Python interpreter of the venv with the given arguments and
    raises ``CalledProcessError`` on a non-zero exit (as handled below);
    confirm against the fixture.
    """
    import shlex

    test_root = tests_path()
    # Local debugging switch: set to True to only collect the tests.
    dry_run = False

    my_test_prefix = request.config.getoption("--test-prefix")
    test_prefix = (f"{my_test_prefix}/unittest"
                   if my_test_prefix else "unittest")

    num_workers = _lookup_parallel_factor(test_root, case)
    if parallel_override := os.environ.get("LLM_TEST_PARALLEL_OVERRIDE", None):
        num_workers = int(parallel_override)
    print('Parallel workers: ', num_workers)

    case_fn = _unittest_case_filename(case)
    output_xml = os.path.join(output_dir,
                              f'results-sub-unittests-{case_fn}.xml')

    command = [
        '-m', 'pytest', f"--ignore={test_root}/integration", "-v",
        "--timeout=2400", "--timeout-method=thread",
        f"--test-prefix={test_prefix}"
    ]
    if dry_run:
        command += ['--collect-only']
    command += shlex.split(case)

    print(f"Running unit test:'{command}'")

    def run_command(cmd):
        """Run ``cmd`` in the venv; return False if it failed, else True."""
        try:
            llm_venv.run_cmd(cmd, cwd=test_root)
        except CalledProcessError:
            return False
        return True

    if num_workers == 1:
        # Do not bother with pytest-xdist at all if we don't need parallel execution
        command += ["-p", "no:xdist", f"--junitxml={output_xml}"]
        passed = run_command(command)
    else:
        # Use a non-.xml final extension so CI does not read failures from
        # this intermediate report.
        parallel_output_xml = os.path.join(
            output_dir,
            f'parallel-sub-results-unittests-{case_fn}.xml.intermediate')
        parallel_command = command + [
            "-n", f"{num_workers}", f"--junitxml={parallel_output_xml}"
        ]
        passed = run_command(parallel_command)

        assert os.path.exists(
            parallel_output_xml
        ), "no report generated, fatal failure happened in unittests (parallel phase)"

        if dry_run or passed:
            os.replace(parallel_output_xml, output_xml)
        else:
            # Use a non-.xml final extension so CI does not read failures
            # from this intermediate report.
            retry_output_xml = os.path.join(
                output_dir,
                f'retry-sub-results-unittests-{case_fn}.xml.intermediate')
            # Run failed cases sequentially.
            retry_command = command + [
                "-p", "no:xdist", '--lf', f"--junitxml={retry_output_xml}"
            ]
            passed = run_command(retry_command)

            if os.path.exists(retry_output_xml):
                merge_report(parallel_output_xml, retry_output_xml,
                             output_xml, True)
            else:
                # Still publish the parallel-phase report before failing.
                os.replace(parallel_output_xml, output_xml)
                assert False, "no report generated, fatal failure happened in unittests (retry phase)"

    assert passed, "failure reported in unittests"
|