TensorRT-LLMs/tests/integration/defs/trt_test_alternative.py
HuiGao-NV f4059c6e2e
Add test case for kv memory estimation (#4158)
* Add test case for kv memory estimation
* Dump running log into file and parse kv cache memory size from file
* Set a bigger peak memory size for the mixed precision case and the test_ptp_quickstart_advanced_eagle3 case
* Revert change to usage of fraction
* use context manager to guard temp files

Signed-off-by: Hui Gao <huig@nvidia.com>
2025-05-14 18:39:25 +08:00

278 lines
8.7 KiB
Python

# An alternative library to trt_test that lets TRT-LLM developers run tests with plain pytest commands
import contextlib
import logging
import os
import platform
import signal
import subprocess
import sys
import tempfile
import psutil
# Logger shared by the print_* helpers below; silenced (CRITICAL) by default.
general_logger = logging.getLogger("general")
general_logger.setLevel(logging.CRITICAL)

# Thin aliases kept so code written against trt_test keeps working.
exists = os.path.exists
makedirs = os.makedirs


def is_windows():
    """Return True when the current platform reports itself as Windows."""
    return platform.system() == "Windows"


def is_linux():
    """Return True when the current platform reports itself as Linux."""
    return platform.system() == "Linux"


def is_wsl():
    """Always return False; WSL is not detected by this module."""
    # FIXME: llm cases never on WSL?
    return False


def wsl_to_win_path(x):
    """Return ``x`` unchanged (no WSL path translation is performed)."""
    # FIXME: a hack for llm not run on WSL
    return x


SessionDataWriter = None  # TODO: hope never runs
@contextlib.contextmanager
def altered_env(**kwargs):
    """Temporarily set environment variables for the duration of a ``with``.

    Every keyword argument is written to ``os.environ``. On exit, whether
    normal or via an exception, each variable is restored to its previous
    value, or removed if it did not exist before.

    Args:
        **kwargs: Variable names mapped to the values to set. ``os.environ``
            only accepts ``str`` values; anything else raises TypeError, and
            the variables already set are rolled back.
    """
    # Snapshot every previous value before touching anything, so that a
    # failure part-way through the updates can still be fully undone.
    old = {k: os.environ[k] for k in kwargs if k in os.environ}
    try:
        # Apply the updates inside the try so the finally clause also runs
        # when one of the assignments fails.
        for k, v in kwargs.items():
            os.environ[k] = v
        yield
    finally:
        for k in kwargs:
            if k in old:
                os.environ[k] = old[k]
            else:
                # Default avoids KeyError if the body already deleted it.
                os.environ.pop(k, None)
# Our own versions of the subprocess functions that clean up the whole process tree upon failure.
# This ensures subsequent tests won't be affected by leftover processes from previous test cases.
#
# On Linux we create a new session (start_new_session) when starting a subprocess, so that we can
# trace its descendants even after their parent processes have exited.
# Subprocesses spawned by tests usually create their own process groups, so killpg() is not
# enough here. However, they usually don't create a new session, so we use the session to track them.
#
# On Windows we create a job object and put the subprocess into it. Descendants created by
# a process in the job are also placed in the job, so terminating the job object in turn
# terminates all processes in the job.
if is_linux():
    # subprocess.Popen already accepts start_new_session here, so no wrapper
    # is needed.
    Popen = subprocess.Popen

    def list_process_sid(sid: int):
        """Return the pids of processes in session ``sid`` that carry our uid.

        Only processes whose uids (as reported by psutil) include the current
        user's uid are inspected. Processes that disappear or cannot be
        queried while scanning are skipped.
        """
        current_uid = os.getuid()

        pids = []
        for proc in psutil.process_iter(['pid', 'uids']):
            uids = proc.info['uids']
            # psutil stores None for attributes it could not read (e.g. on
            # AccessDenied). `current_uid in None` would raise TypeError and
            # abort the whole cleanup, so skip such processes instead.
            if uids is None or current_uid not in uids:
                continue
            try:
                if os.getsid(proc.pid) == sid:
                    pids.append(proc.pid)
            except (ProcessLookupError, PermissionError):
                pass

        return pids

    def cleanup_process_tree(p: subprocess.Popen, has_session=False):
        """SIGKILL ``p`` and every leftover descendant that can be found.

        Args:
            p: The process started through ``Popen``.
            has_session: True if ``p`` was started with start_new_session; all
                processes in its session are then targeted as well.
        """
        target_pids = set()
        if has_session:
            # Session ID is the pid of the leader process
            target_pids.update(list_process_sid(p.pid))

        # Backup plan: using ppid to build subprocess tree
        try:
            target_pids.update(
                sub.pid
                for sub in psutil.Process(p.pid).children(recursive=True))
        except psutil.Error:
            # e.g. p has already exited; nothing left to walk.
            pass

        print("Found leftover pids:", target_pids)
        for pid in target_pids:
            try:
                os.kill(pid, signal.SIGKILL)
            except (ProcessLookupError, PermissionError):
                # Already exited, or not ours to kill.
                pass

        p.kill()

elif is_windows():
    import pywintypes
    import win32api
    import win32job

    class MyHandle:
        """Own a Win32 handle and close it when this object is finalized."""

        def __init__(self, handle):
            self.handle = handle

        def __del__(self):
            win32api.CloseHandle(self.handle)

    def Popen(*popenargs, start_new_session, **kwargs):
        """subprocess.Popen that optionally places the child in a job object.

        With ``start_new_session`` the child is assigned to a fresh job object,
        reachable as ``p.job_handle`` (a MyHandle). ``p.job_handle`` is None
        when no job was requested or the assignment failed.
        """
        job = None
        if start_new_session:
            # Wrap right away so the handle is closed even if it never gets
            # attached to a process (Popen failure or assignment failure).
            job = MyHandle(win32job.CreateJobObject(None, ""))

        p = subprocess.Popen(*popenargs, **kwargs)
        p.job_handle = None
        if job is not None:
            # It would be best to start with creationflags=0x04 (CREATE_SUSPENDED),
            # add process to job, and resume the primary thread.
            # However, subprocess.Popen simply discarded the thread handle and tid.
            # Instead, simply hope we add the process early enough.
            try:
                win32job.AssignProcessToJobObject(job.handle, p._handle)
                p.job_handle = job
            except pywintypes.error:
                pass
        return p

    def cleanup_process_tree(p: subprocess.Popen, has_session=False):
        """Terminate ``p``, its job object (if any) and leftover descendants.

        Args:
            p: The process started through ``Popen``.
            has_session: True if ``p`` was started with start_new_session; its
                job object, when present, is then terminated as a whole.
        """
        target_pids = []
        try:
            target_pids = [
                sub.pid
                for sub in psutil.Process(p.pid).children(recursive=True)
            ]
        except psutil.Error:
            pass

        # getattr: tolerate processes that were not created by our Popen.
        job_handle = getattr(p, "job_handle", None)
        if has_session and job_handle is not None:
            process_exit_code = 3600  # Some obvious special exit code
            try:
                win32job.TerminateJobObject(job_handle.handle,
                                            process_exit_code)
            except pywintypes.error:
                pass

        print("Found leftover pids:", target_pids)
        for pid in target_pids:
            try:
                # signal.SIGKILL does not exist on Windows. There, os.kill()
                # with any signal other than CTRL_C_EVENT/CTRL_BREAK_EVENT
                # calls TerminateProcess, so SIGTERM is an unconditional kill.
                os.kill(pid, signal.SIGTERM)
            except OSError:
                # Typically the pid is already gone (e.g. killed with the job
                # above), which Windows reports as a plain OSError.
                pass

        p.kill()
def call(*popenargs,
         timeout=None,
         start_new_session=True,
         suppress_output_info=False,
         **kwargs):
    """Run a command, wait for it, and return its exit code.

    Like subprocess.call(), but the process tree is cleaned up when the
    command exits non-zero (with ``start_new_session``), times out, or the
    wait is interrupted by any exception (including KeyboardInterrupt).

    Args:
        *popenargs: Positional arguments forwarded to ``Popen``.
        timeout: Seconds to wait for the command; None waits indefinitely.
        start_new_session: Start the child in its own session (Linux) or job
            object (Windows) so its descendants can be found and killed.
        suppress_output_info: Skip printing the "Start subprocess" banner.
        **kwargs: Forwarded to ``Popen``. The extra keyword ``running_log`` is
            consumed here: if it is a ``tempfile._TemporaryFileWrapper`` (what
            ``tempfile.NamedTemporaryFile`` returns), the child's stdout is
            written to it, taking precedence over any ``stdout`` keyword.
            Other ``running_log`` values are ignored.

    Returns:
        The command's exit code.

    Raises:
        subprocess.TimeoutExpired: if ``timeout`` elapses. The process tree is
            killed before the exception propagates.
    """
    if not suppress_output_info:
        print(f"Start subprocess with call({popenargs}, {kwargs})")

    # `running_log` is our own keyword, not a Popen one, so always strip it.
    running_log = kwargs.pop("running_log", None)
    if isinstance(running_log, tempfile._TemporaryFileWrapper):
        kwargs["stdout"] = running_log

    with Popen(*popenargs, start_new_session=start_new_session,
               **kwargs) as p:
        try:
            retcode = p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            print("Process timed out.")
            # Kill the tree *before* collecting output: communicate() without
            # a timeout would otherwise block until the hung process exits.
            cleanup_process_tree(p, start_new_session)
            stdout, stderr = p.communicate()
            for label, data in (("STDOUT:", stdout), ("STDERR:", stderr)):
                if data:
                    # Streams are bytes unless Popen was put in text mode.
                    if isinstance(data, bytes):
                        data = data.decode('utf-8', errors='replace')
                    print(label, data)
            raise
        except BaseException:
            cleanup_process_tree(p, start_new_session)
            raise

        if retcode and start_new_session:
            # The child failed; make sure nothing it spawned outlives it.
            cleanup_process_tree(p, True)
        return retcode
def check_call(*popenargs, **kwargs):
    """Run a command through ``call`` and insist that it succeeds.

    Returns:
        0 when the command exits with code 0.

    Raises:
        subprocess.CalledProcessError: on a non-zero exit code. The recorded
            command is the ``args`` keyword when given, otherwise the first
            positional argument.
    """
    print(f"Start subprocess with check_call({popenargs}, {kwargs})")
    exit_code = call(*popenargs, suppress_output_info=True, **kwargs)
    if not exit_code:
        return 0

    command = kwargs.get("args")
    if command is None:
        command = popenargs[0]
    raise subprocess.CalledProcessError(exit_code, command)
def check_output(*popenargs, timeout=None, start_new_session=True, **kwargs):
    """Run a command and return its captured stdout.

    Like subprocess.check_output(), but the process tree is cleaned up on a
    timeout, on any other exception while waiting, and on a non-zero exit
    (with ``start_new_session``).

    Args:
        *popenargs: Positional arguments forwarded to ``Popen``.
        timeout: Seconds to wait for the command; None waits indefinitely.
        start_new_session: Start the child in its own session (Linux) or job
            object (Windows) so its descendants can be found and killed.
        **kwargs: Forwarded to ``Popen``. ``stdout`` is set here to a pipe,
            so it must not be passed.

    Returns:
        The captured stdout. Bytes are decoded with ``bytes.decode()``
        defaults (UTF-8, strict). If Popen was put in text mode (``text=True``,
        ``encoding=...``, ...), the already-decoded str is returned as-is.

    Raises:
        subprocess.TimeoutExpired: if ``timeout`` elapses.
        subprocess.CalledProcessError: on a non-zero exit, carrying the
            captured stdout/stderr.
    """
    print(f"Start subprocess with check_output({popenargs}, {kwargs})")
    with Popen(*popenargs,
               stdout=subprocess.PIPE,
               start_new_session=start_new_session,
               **kwargs) as process:
        try:
            stdout, stderr = process.communicate(None, timeout=timeout)
        except subprocess.TimeoutExpired as exc:
            cleanup_process_tree(process, start_new_session)
            # Same handling as CPython's subprocess.run(): on Windows the
            # partial output must be collected by a second communicate().
            # On POSIX it is already attached to `exc`, so just reap.
            if is_windows():
                exc.stdout, exc.stderr = process.communicate()
            else:
                process.wait()
            raise
        except BaseException:
            # Includes KeyboardInterrupt: never leave the tree running.
            cleanup_process_tree(process, start_new_session)
            raise

        retcode = process.poll()
        if retcode:
            if start_new_session:
                cleanup_process_tree(process, True)
            raise subprocess.CalledProcessError(retcode,
                                                process.args,
                                                output=stdout,
                                                stderr=stderr)

    return stdout.decode() if isinstance(stdout, bytes) else stdout
def make_clean_dirs(path):
    """
    Create the directory @path (including parents). If @path already exists,
    it is removed together with everything under it first, so the result is
    an empty directory.
    """
    from shutil import rmtree

    already_present = os.path.exists(path)
    if already_present:
        rmtree(path)  # wipe previous contents before recreating
    os.makedirs(path)
def print_info(message: str) -> None:
    """
    Echo @message to stdout tagged with "[INFO]", flush stdout, and pass the
    untagged message to general_logger at INFO level.
    """
    tagged = f"[INFO] {message}"
    print(tagged)
    # Flush so the line is not held back in the stdout buffer.
    sys.stdout.flush()
    general_logger.info(message)
def print_warning(message: str) -> None:
    """
    Echo @message to stdout tagged with "[WARNING]", flush stdout, and pass
    the untagged message to general_logger at WARNING level.
    """
    tagged = f"[WARNING] {message}"
    print(tagged)
    # Flush so the line is not held back in the stdout buffer.
    sys.stdout.flush()
    general_logger.warning(message)
def print_error(message: str) -> None:
    """
    Echo @message to stdout tagged with "[ERROR]", flush stdout, and pass
    the untagged message to general_logger at ERROR level.
    """
    tagged = f"[ERROR] {message}"
    print(tagged)
    # Flush so the line is not held back in the stdout buffer.
    sys.stdout.flush()
    general_logger.error(message)
# custom test checker
def check_call_negative_test(*popenargs, **kwargs):
    """Run a command that is expected to FAIL.

    Returns:
        0 when the command exits with a non-zero code (the expected outcome).

    Raises:
        subprocess.CalledProcessError: with return code 1 if the command
            unexpectedly succeeds. The recorded command is the ``args``
            keyword when given, otherwise the first positional argument.
    """
    print(
        f"Start subprocess with check_call_negative_test({popenargs}, {kwargs})"
    )
    exit_code = call(*popenargs, suppress_output_info=True, **kwargs)
    if exit_code:
        return 0

    command = kwargs.get("args")
    if command is None:
        command = popenargs[0]
    print(
        f"Subprocess expected to fail with check_call_negative_test({popenargs}, {kwargs}), but passed."
    )
    raise subprocess.CalledProcessError(1, command)