TensorRT-LLMs/tensorrt_llm/_torch/hostfunc.py

import atexit

import torch

from ..bindings.internal import runtime as bindings

# Handles returned by bindings.launch_hostfunc during CUDA graph capture.
# They are retained here until explicitly freed or released at interpreter exit.
HOSTFUNC_USER_DATA_HANDLES = set()


def launch_hostfunc(hostfunc, *args, **kwargs):
    """Launch ``hostfunc(*args, **kwargs)`` as a CUDA host function on the current stream."""
    stream = torch.cuda.current_stream()
    is_capturing = torch.cuda.is_current_stream_capturing()
    handle = bindings.launch_hostfunc(stream.cuda_stream, not is_capturing,
                                      hostfunc, *args, **kwargs)
    if is_capturing:
        # During graph capture, keep the user-data handle so it can be freed
        # explicitly later (or at interpreter exit).
        HOSTFUNC_USER_DATA_HANDLES.add(handle)
    else:
        # Eager launches do not return a user-data handle.
        assert handle is None
    return handle


def hostfunc(hostfunc):
    """Decorator that routes every call of the wrapped callable through ``launch_hostfunc``."""
    def wrapper(*args, **kwargs):
        return launch_hostfunc(hostfunc, *args, **kwargs)
    return wrapper
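

# Illustrative usage sketch (not part of this module): ``notify`` below is a
# hypothetical callback. Decorating it makes each call enqueue the Python
# function as a CUDA host function on the current stream, so it runs on the
# host after the work already queued on that stream rather than inline.
#
#     @hostfunc
#     def notify(step):
#         print(f"step {step} done")
#
#     notify(0)  # returns immediately; prints once the stream reaches this point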


def free_hostfunc_user_data(handle: int):
    """Free the user data behind a handle returned during graph capture."""
    if handle not in HOSTFUNC_USER_DATA_HANDLES:
        raise ValueError(f"Hostfunc user data handle {handle} not found.")
    bindings.free_hostfunc_user_data(handle)
    HOSTFUNC_USER_DATA_HANDLES.remove(handle)


def free_all_hostfunc_user_data():
    """Free all outstanding user-data handles; registered to run at interpreter exit."""
    for handle in HOSTFUNC_USER_DATA_HANDLES:
        bindings.free_hostfunc_user_data(handle)
    HOSTFUNC_USER_DATA_HANDLES.clear()


atexit.register(free_all_hostfunc_user_data)
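
# Capture-time lifecycle sketch (illustrative; ``notify`` as above, and the
# exact replay setup is an assumption): a call made while the current stream
# is capturing returns a user-data handle instead of None. The owner of the
# captured graph should free the handle once the graph will no longer be
# replayed; anything still registered is released by the atexit hook above.
#
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g):
#         handle = notify(1)            # recorded into the graph; handle is kept
#     g.replay()                        # the host function runs on each replay
#     free_hostfunc_user_data(handle)   # release once replays are done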