TensorRT-LLMs/tests/unittest/bindings/test_hostfunc.py
Emma Qiao a74ce266d3
[None][infra] Waive failed tests for release branch 11/07 (#9026)
Signed-off-by: qqiao <qqiao@nvidia.com>
2025-11-09 18:18:49 +08:00

34 lines
759 B
Python

import pytest
import torch
from tensorrt_llm._torch.hostfunc import HOSTFUNC_USER_DATA_HANDLES, hostfunc
@pytest.mark.skip(reason="https://nvbugs/5643631")
def test_hostfunc():
    """Check eager and CUDA-graph-captured host functions on a side stream.

    Accounting for the final value of ``x`` (every element starts at 0):

    * 5 eager calls of ``increase`` enqueued on ``stream``     -> +5
    * 2 calls captured into graph ``g``, replayed 10 times     -> +20

    The expected total of 25 implies that capturing the two calls does not
    itself execute them; only the replays do.

    The test also checks that exactly two user-data handles remain
    registered in ``HOSTFUNC_USER_DATA_HANDLES`` at the end. Presumably
    there is one per captured call, kept alive together with the graph
    (TODO confirm against ``tensorrt_llm._torch.hostfunc``). The count is
    measured relative to the registry size on entry, so handles left
    behind by other tests in the same process do not make this test fail.
    """
    # Snapshot the process-global registry before this test registers
    # anything (including via the decorator below), so the final check
    # does not depend on test ordering.
    baseline_handles = len(HOSTFUNC_USER_DATA_HANDLES)

    @hostfunc
    def increase(x: torch.Tensor):
        # Presumably executed as a host callback ordered on the current
        # CUDA stream; it mutates the tensor in place.
        x.add_(1)

    # No `device` argument, so under default torch settings this is a CPU
    # tensor that the host callback can mutate directly.
    x = torch.zeros(10, dtype=torch.int32)
    stream = torch.cuda.Stream()

    # Eager phase: five host-function invocations enqueued on `stream`.
    with torch.cuda.stream(stream):
        for _ in range(5):
            increase(x)

    # Capture phase: record two invocations into a CUDA graph on `stream`.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, stream=stream):
        increase(x)
        increase(x)
    torch.cuda.synchronize()

    # Replay phase: each replay runs both captured invocations.
    with torch.cuda.stream(stream):
        for _ in range(10):
            g.replay()
    torch.cuda.synchronize()

    assert (x == 25).all().item(), f"expected all elements == 25, got {x.tolist()}"
    new_handles = len(HOSTFUNC_USER_DATA_HANDLES) - baseline_handles
    assert new_handles == 2, f"expected 2 new hostfunc handles, got {new_handles}"