mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
44 lines
1.1 KiB
Python
44 lines
1.1 KiB
Python
import atexit
|
|
|
|
import torch
|
|
|
|
from ..bindings.internal import runtime as bindings
|
|
|
|
HOSTFUNC_USER_DATA_HANDLES = set()
|
|
|
|
|
|
def launch_hostfunc(hostfunc, *args, **kwargs):
|
|
stream = torch.cuda.current_stream()
|
|
is_capturing = torch.cuda.is_current_stream_capturing()
|
|
handle = bindings.launch_hostfunc(stream.cuda_stream, not is_capturing,
|
|
hostfunc, *args, **kwargs)
|
|
if is_capturing:
|
|
HOSTFUNC_USER_DATA_HANDLES.add(handle)
|
|
else:
|
|
assert handle is None
|
|
return handle
|
|
|
|
|
|
def hostfunc(hostfunc):
|
|
|
|
def wrapper(*args, **kwargs):
|
|
return launch_hostfunc(hostfunc, *args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
|
|
def free_hostfunc_user_data(handle: int):
|
|
if handle not in HOSTFUNC_USER_DATA_HANDLES:
|
|
raise ValueError(f"Hostfunc user data handle {handle} not found.")
|
|
bindings.free_hostfunc_user_data(handle)
|
|
HOSTFUNC_USER_DATA_HANDLES.remove(handle)
|
|
|
|
|
|
def free_all_hostfunc_user_data():
|
|
for handle in HOSTFUNC_USER_DATA_HANDLES:
|
|
bindings.free_hostfunc_user_data(handle)
|
|
HOSTFUNC_USER_DATA_HANDLES.clear()
|
|
|
|
|
|
atexit.register(free_all_hostfunc_user_data)
|