TensorRT-LLMs/tests/unittest/bindings/test_hostfunc.py
Enwei Zhu 2ce785f39a
[https://nvbugs/5643631][fix] Fix hostfunc seg fault (#10028)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
2025-12-20 07:58:43 -05:00

33 lines
723 B
Python

import torch
from tensorrt_llm._torch.hostfunc import HOSTFUNC_USER_DATA_HANDLES, hostfunc
def test_hostfunc():
    """Check that ``@hostfunc`` callables run eagerly and inside CUDA graphs.

    ``increase`` adds 1 in place to a CPU tensor. It is called 5 times
    eagerly under a side stream, then captured twice into a CUDA graph that
    is replayed 10 times, so every element must end at 5 + 2 * 10 = 25.

    The test also checks ``HOSTFUNC_USER_DATA_HANDLES``. By the end it must
    have grown by exactly 2, the number of captured calls, so the 5 eager
    calls must not leave entries behind. Growth is measured against a
    baseline taken before any call. This keeps the check independent of
    entries that other tests in the same process may have left.
    """
    @hostfunc
    def increase(x: torch.Tensor):
        x.add_(1)

    # Baseline for the handle-count check below. The original test assumed
    # the global container starts empty.
    num_handles_before = len(HOSTFUNC_USER_DATA_HANDLES)

    x = torch.zeros(10, dtype=torch.int32)
    stream = torch.cuda.Stream()

    # Eager phase: after the synchronize, all 5 increments must have landed.
    with torch.cuda.stream(stream):
        for _ in range(5):
            increase(x)
    torch.cuda.synchronize()
    assert (x == 5).all().item()

    # Capture phase: the two calls are only recorded into the graph, so the
    # tensor must be unchanged afterwards.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, stream=stream):
        increase(x)
        increase(x)
    torch.cuda.synchronize()
    assert (x == 5).all().item()

    # Replay phase: each replay re-runs both captured calls (+2 per replay).
    with torch.cuda.stream(stream):
        for _ in range(10):
            g.replay()
    torch.cuda.synchronize()
    assert (x == 25).all().item()
    assert len(HOSTFUNC_USER_DATA_HANDLES) == num_handles_before + 2