# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from subprocess import CalledProcessError

from defs.conftest import tests_path


def merge_report(base_file, extra_file, output_file, is_retry=False):
    """Merge two pytest JUnit XML reports and write the result.

    Args:
        base_file: Path of the primary report.
        extra_file: Path of the report to merge in. It is deleted after the
            merge.
        output_file: Path the merged report is written to.
        is_retry: When True, ``extra_file`` is a re-run of previously failing
            tests. Its test cases replace same-named cases of the base report,
            and its ``failures``/``errors`` counts replace the base counts.
            When False, both reports' test cases are concatenated and their
            ``errors``/``failures``/``skipped``/``tests`` counters are summed.
            ``time`` is summed in both modes.

    If ``extra_file`` does not exist, the function returns without writing
    ``output_file``.
    """
    import xml.etree.ElementTree as ElementTree

    base = ElementTree.parse(base_file)
    try:
        extra = ElementTree.parse(extra_file)
    except FileNotFoundError:
        # NOTE: output_file is not produced in this case; the caller checks
        # that extra_file exists before calling.
        return

    base_suite = base.getroot().find('testsuite')
    extra_suite = extra.getroot().find('testsuite')

    def merge_attr(name, type_=int):
        # Sum a numeric <testsuite> attribute of both reports into base_suite.
        base_suite.attrib[name] = str(
            type_(base_suite.attrib[name]) + type_(extra_suite.attrib[name]))

    def case_key(element):
        # Identify a testcase by (classname, name). Use .get() for classname:
        # nodes written for crashed workers may lack it, and indexing would
        # raise KeyError.
        return (element.attrib.get('classname'), element.attrib['name'])

    merge_attr("time", type_=float)

    if is_retry:
        # The retry report holds the re-run of the last-failed tests, and
        # those cases replace their base counterparts below. Its failure AND
        # error counts therefore supersede the base counts; keeping the base
        # `errors` would contradict the merged testcase nodes.
        base_suite.attrib['failures'] = extra_suite.attrib['failures']
        base_suite.attrib['errors'] = extra_suite.attrib['errors']
        # pytest may generate testcase node without classname or name attribute when worker crashed catastrophically.
        # Simply ignore these nodes since they are not meaningful.
        extra_suite_nodes = [
            element for element in extra_suite if 'name' in element.attrib
        ]
        case_names = {case_key(element) for element in extra_suite_nodes}
        # Keep base cases that were not re-run, then append the re-run ones.
        base_suite[:] = [
            element for element in base_suite if 'name' in element.attrib
            if case_key(element) not in case_names
        ] + extra_suite_nodes
    else:
        merge_attr("errors")
        merge_attr("failures")
        merge_attr("skipped")
        merge_attr("tests")
        base_suite[:] = list(base_suite) + list(extra_suite)

    os.remove(extra_file)
    base.write(output_file, encoding="UTF-8", xml_declaration=True)


def _get_gpu_name():
    """Return the NVML name of GPU 0 as ``str``, balancing NVML init/shutdown."""
    import pynvml

    pynvml.nvmlInit()
    try:
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        gpu_name = pynvml.nvmlDeviceGetName(gpu_handle)
    finally:
        # Pair the nvmlInit() above so the NVML handle is not leaked.
        pynvml.nvmlShutdown()
    # Depending on the pynvml version the name may come back as bytes.
    if isinstance(gpu_name, bytes):
        gpu_name = gpu_name.decode()
    return gpu_name


def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request):
    """Run one unit-test ``case`` through pytest in the test virtualenv.

    ``case`` is split with ``shlex`` and appended to the pytest command line.
    The worker count is looked up in ``agg_unit_mem_df.csv`` by (GPU name,
    case) and capped at 8. If no entry exists, the run is serial. The
    ``LLM_TEST_PARALLEL_OVERRIDE`` environment variable overrides the count.

    With more than one worker, the run uses pytest-xdist (``-n``). If that
    run fails, the last-failed tests (``--lf``) are re-run serially, and both
    JUnit reports are merged into
    ``results-sub-unittests-<case>.xml`` under ``output_dir``.

    Raises:
        AssertionError: if a required report is missing or failures remain.
    """
    import pandas as pd

    gpu_name = _get_gpu_name()

    test_root = tests_path()
    dry_run = False
    my_test_prefix = request.config.getoption("--test-prefix")
    if my_test_prefix:
        test_prefix = f"{my_test_prefix}/unittest"
    else:
        test_prefix = "unittest"
    waives_file = request.config.getoption("--waives-file")
    run_ray = request.config.getoption("--run-ray")
    num_workers = 1

    # This dataframe is not manually edited. Infra team will regularly generate this dataframe based on test execution results.
    # If you need to override this policy, please use postprocess code as below.
    agg_unit_mem_df = pd.read_csv(
        f'{test_root}/integration/defs/agg_unit_mem_df.csv')
    agg_unit_mem_df = agg_unit_mem_df[agg_unit_mem_df['gpu'] == gpu_name]
    print(agg_unit_mem_df)

    parallel_dict = {}
    for _, row in agg_unit_mem_df.iterrows():
        key = (row['gpu'], row['unittest_case_name'])
        parallel_dict[key] = row['parallel_factor']

    print(parallel_dict)
    cur_key = (gpu_name, case)
    if cur_key in parallel_dict:
        # Coerce the pandas scalar to a plain int: a float value would be
        # formatted as e.g. "4.0" in the `-n` argument below.
        num_workers = int(parallel_dict[cur_key])
        num_workers = min(num_workers, 8)
    else:
        warnings.warn(
            f'Cannot find parallel config entry for unittest {case} on "{gpu_name}". Fallback to serial test. Please add config entry to agg_unit_mem_df.csv.'
        )

    num_workers = max(1, num_workers)

    if parallel_override := os.environ.get("LLM_TEST_PARALLEL_OVERRIDE", None):
        num_workers = int(parallel_override)

    print('Parallel workers: ', num_workers)

    ignore_opt = f"--ignore={test_root}/integration"

    import shlex
    arg_list = shlex.split(case)

    # Derive a path-safe, length-bounded report file stem from the case.
    case_fn = case.replace('/', '-')
    if len(case_fn) > 80:
        case_fn = case_fn[:80]
    output_xml = os.path.join(output_dir,
                              f'results-sub-unittests-{case_fn}.xml')

    command = [
        '-m', 'pytest', ignore_opt, "-vv", "--tb=short", "-rF",
        "--timeout=2400", "--timeout-method=thread"
    ]
    if test_prefix:
        command += [f"--test-prefix={test_prefix}"]
    if dry_run:
        command += ['--collect-only']
    if waives_file:
        waives_file = os.path.abspath(waives_file)
        command += [f"--waives-file={waives_file}"]
    if run_ray:
        command += ["--run-ray"]
    command += arg_list

    print(f"Running unit test:\"python {' '.join(command)}\"")

    def run_command(cmd, num_workers=1):
        # Run `cmd` in the test venv; return True on success and False
        # (after dumping diagnostics) on a CalledProcessError.
        try:
            pythonpath = os.environ.get("PYTHONPATH", "")
            env = {'PYTHONPATH': f"{llm_root}/tests/unittest:{pythonpath}"}
            if num_workers > 1:
                # Limit inductor compile threads when several workers run.
                env['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
            llm_venv.run_cmd(
                cmd,
                cwd=test_root,
                env=env,
            )
        except CalledProcessError as e:
            print(f"\n{'='*60}")
            print(f"UNITTEST FAILED with exit code: {e.returncode}")
            print(f"Command: {' '.join(cmd)}")
            if hasattr(e, 'stdout') and e.stdout:
                print(
                    f"STDOUT:\n{e.stdout.decode() if isinstance(e.stdout, bytes) else e.stdout}"
                )
            if hasattr(e, 'stderr') and e.stderr:
                print(
                    f"STDERR:\n{e.stderr.decode() if isinstance(e.stderr, bytes) else e.stderr}"
                )
            print(f"{'='*60}\n")
            return False
        return True

    if num_workers == 1:
        # Do not bother with pytest-xdist at all if we don't need parallel execution
        command += ["-p", "no:xdist", f"--junitxml={output_xml}"]
        passed = run_command(command)
    else:
        # Avoid .xml extension to prevent CI from reading failures from it
        parallel_output_xml = os.path.join(
            output_dir,
            f'parallel-sub-results-unittests-{case_fn}.xml.intermediate')
        parallel_command = command + [
            "-n", f"{num_workers}", f"--junitxml={parallel_output_xml}"
        ]
        passed = run_command(parallel_command, num_workers)

        assert os.path.exists(
            parallel_output_xml
        ), "no report generated, fatal failure happened in unittests (parallel phase)"

        if dry_run or passed:
            os.rename(parallel_output_xml, output_xml)
        else:
            # Avoid .xml extension to prevent CI from reading failures from it
            retry_output_xml = os.path.join(
                output_dir,
                f'retry-sub-results-unittests-{case_fn}.xml.intermediate')
            # Run failed case sequentially.
            retry_command = command + [
                "-p", "no:xdist", '--lf', f"--junitxml={retry_output_xml}"
            ]
            passed = run_command(retry_command)

            if os.path.exists(retry_output_xml):
                merge_report(parallel_output_xml, retry_output_xml,
                             output_xml, True)
            else:
                os.rename(parallel_output_xml, output_xml)
                assert False, "no report generated, fatal failure happened in unittests (retry phase)"

    assert passed, "failure reported in unittests"