TensorRT-LLMs/tests/unittest/python_plugin/test_plugin_wrapper.py
Kaiyu Xie 3aa6b11d13
Update TensorRT-LLM (#2936)
* Update TensorRT-LLM

---------

Co-authored-by: changcui <cuichang147@gmail.com>
2025-03-18 21:25:19 +08:00

172 lines
5.0 KiB
Python

import os
import sys
from typing import Sequence
import pytest
import torch
import tensorrt_llm
from tensorrt_llm import PluginBase, Tensor
from tensorrt_llm._utils import (TensorWrapper, torch_dtype_to_str,
torch_dtype_to_trt)
from tensorrt_llm.python_plugin import SymTensor, trtllm_plugin
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from plugin_wrapper_utils import DummyPlugin
from utils.util import create_session, run_session
sys.path.append(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples",
"python_plugin"))
from plugin_lib import LookUpPlugin
@pytest.fixture(scope="function", autouse=True)
def use_cuda_as_default_device():
    """Give every test in this module CUDA as torch's default device.

    Also switches TensorRT-LLM logging to ``"verbose"`` for the duration of
    the test. Both settings are snapshotted beforehand and put back
    afterwards, even if the test body raises.
    """
    # Snapshot current global state (logger level first, then device).
    saved_level, saved_device = (tensorrt_llm.logger.level,
                                 torch.get_default_device())

    tensorrt_llm.logger.set_level("verbose")
    torch.set_default_device("cuda")
    try:
        yield
    finally:
        # Undo in reverse order of application.
        torch.set_default_device(saved_device)
        tensorrt_llm.logger.set_level(saved_level)
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.bfloat16, torch.float16],
                         ids=torch_dtype_to_str)
@pytest.mark.parametrize("to_torch", [False, True])
@pytest.mark.parametrize("fp32_output", [False, True])
def test_triton_plugin(dtype, to_torch, fp32_output):
    """Check ``LookUpPlugin`` against a ``torch.nn.Embedding`` reference.

    The plugin is exercised in two ways:

    * built into a TensorRT-LLM network and executed through a session;
    * called directly on torch tensors.

    Both results must match the reference lookup. When ``fp32_output`` is
    set, the plugin output (and therefore the reference) is ``float32``.
    Otherwise the output keeps the weight ``dtype``.
    """
    # Problem size.
    num_indices, vocab_size, hidden_dim = 10, 1000, 1024
    # Dtype the plugin is configured to emit.
    out_dtype = torch.float32 if fp32_output else dtype

    # Random inputs: int32 row indices into a (vocab_size, hidden_dim) table.
    indices = torch.randint(0, vocab_size, (num_indices, ), dtype=torch.int32)
    table = torch.rand(vocab_size, hidden_dim, dtype=dtype)
    reference_embedding = torch.nn.Embedding.from_pretrained(table)

    # Plugin instance that is later called eagerly on torch tensors.
    eager_plugin = LookUpPlugin(to_torch, fp32_output)

    builder = tensorrt_llm.Builder()
    builder.strongly_typed = True
    network = builder.create_network()
    with tensorrt_llm.net_guard(network):
        x = Tensor(name="x",
                   shape=(num_indices, ),
                   dtype=tensorrt_llm.str_dtype_to_trt("int32"))
        y = Tensor(name="y",
                   shape=(vocab_size, hidden_dim),
                   dtype=torch_dtype_to_trt(dtype))
        # The graph-building path uses its own plugin instance, created
        # inside the network guard.
        graph_out = LookUpPlugin(to_torch, fp32_output)(x, y)
        graph_out.mark_output("output", torch_dtype_to_str(out_dtype))

    session = create_session(builder,
                             network,
                             precision=torch_dtype_to_str(dtype))
    engine_result = run_session(session, {"x": indices, "y": table})["output"]
    eager_result = eager_plugin(indices, table)

    # `.to` is a no-op when the embedding output already has `out_dtype`.
    expected = reference_embedding(indices).to(out_dtype)
    torch.testing.assert_close(engine_result, expected)
    torch.testing.assert_close(eager_result, expected)
def test_redefinition():
    """Reusing an already-registered plugin name must raise ``AssertionError``.

    ``trtllm_plugin("Plugin")`` is first applied successfully to ``Plugin``.
    Two registrations under the same name are then expected to fail:

    * applying it to ``Plugin`` again;
    * applying it to a different class, ``PluginRedefine``.
    """
    class Plugin(PluginBase):
        def __init__(self):
            super().__init__()

        # Single output that mirrors the first input's symbolic tensor.
        def shape_dtype_inference(self,
                                  inputs: Sequence[SymTensor]) -> SymTensor:
            return inputs[0]

        # No-op: this test only exercises registration.
        def forward(self, inputs: Sequence[TensorWrapper],
                    outputs: Sequence[TensorWrapper]):
            pass

    # First registration under "Plugin" succeeds.
    trtllm_plugin("Plugin")(Plugin)

    # Registering the same class under the same name again.
    with pytest.raises(AssertionError):
        trtllm_plugin("Plugin")(Plugin)

    # Registering a different class (declaring two outputs) under the same,
    # already-taken name.
    # NOTE(review): presumably the name registry is process-global, so
    # "Plugin" stays taken after this test - confirm before reusing the name.
    with pytest.raises(AssertionError):

        @trtllm_plugin("Plugin")
        class PluginRedefine(PluginBase):
            def __init__(self):
                super().__init__()

            def shape_dtype_inference(
                    self,
                    inputs: Sequence[SymTensor]) -> tuple[SymTensor, SymTensor]:
                return inputs[0], inputs[0]

            def forward(self, inputs: Sequence[TensorWrapper],
                        outputs: Sequence[TensorWrapper]):
                pass
def test_no_register():
    """Misusing ``PluginBase`` must raise ``AssertionError``.

    Two cases are covered:

    * instantiating a subclass that was never decorated with
      ``trtllm_plugin``;
    * decorating a subclass whose ``shape_dtype_inference`` has no return
      annotation.
    """
    class NoRegisterNoOutputNumPlugin(PluginBase):
        def __init__(self):
            super().__init__()

        def shape_dtype_inference(self,
                                  inputs: Sequence[SymTensor]) -> SymTensor:
            return inputs[0]

        def forward(self, inputs: Sequence[TensorWrapper],
                    outputs: Sequence[TensorWrapper]):
            pass

    # Instantiating a plugin class that was never registered.
    with pytest.raises(AssertionError):
        NoRegisterNoOutputNumPlugin()

    # Registering a plugin whose shape_dtype_inference lacks a return type
    # annotation.
    # NOTE(review): judging by the class name, the decorator presumably needs
    # that annotation to infer the number of outputs - confirm.
    with pytest.raises(AssertionError):

        @trtllm_plugin("UtilsPlugin")
        class NoOutputNumPlugin(PluginBase):
            def __init__(self):
                super().__init__()

            # Return annotation intentionally omitted: its absence is what
            # this case checks. Do not add one.
            def shape_dtype_inference(self, inputs: Sequence[SymTensor]):
                return inputs[0]

            def forward(self, inputs: Sequence[TensorWrapper],
                        outputs: Sequence[TensorWrapper]):
                pass
def test_single_creator():
    """Two instances of the same plugin class share one ``_plugin_creator``."""
    first, second = DummyPlugin(), DummyPlugin()
    assert first._plugin_creator is second._plugin_creator