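# Unit tests for the TensorRT-LLM Python plugin interface: a lookup plugin
# wired into a TensorRT-LLM network, plus the registration rules enforced by
# trtllm_plugin.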
from typing import Sequence

import pytest
import torch
from plugin_wrapper_utils import DummyPlugin
from python_plugin.plugin_lib import LookUpPlugin
from utils.util import create_session, run_session

import tensorrt_llm
from tensorrt_llm import PluginBase, Tensor
from tensorrt_llm._utils import (TensorWrapper, torch_dtype_to_str,
                                 torch_dtype_to_trt)
from tensorrt_llm.python_plugin import SymTensor, trtllm_plugin
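
# Every test runs with CUDA as the default torch device and verbose
# TensorRT-LLM logging; the previous settings are restored afterwards.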
@pytest.fixture(scope="function", autouse=True)
def use_cuda_as_default_device():
    old_level = tensorrt_llm.logger.level
    old_device = torch.get_default_device()
    tensorrt_llm.logger.set_level("verbose")
    torch.set_default_device("cuda")

    try:
        yield
    finally:
        torch.set_default_device(old_device)
        tensorrt_llm.logger.set_level(old_level)

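
# Build a small network around LookUpPlugin and check that the TensorRT engine
# output and a direct (eager) plugin call both match torch.nn.Embedding.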
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.bfloat16, torch.float16],
                         ids=torch_dtype_to_str)
@pytest.mark.parametrize("to_torch", [False, True])
@pytest.mark.parametrize("fp32_output", [False, True])
def test_triton_plugin(dtype, to_torch, fp32_output):
    # meta data
    batch_size = 10
    vocab_size = 1000
    n_embed = 1024

    # test data
    ## input index
    index_shape = (batch_size, )
    index_data = torch.randint(0, vocab_size, index_shape, dtype=torch.int32)
    weight_data = torch.rand(vocab_size, n_embed, dtype=dtype)
    embedding = torch.nn.Embedding.from_pretrained(weight_data)
    trt_plugin = LookUpPlugin(to_torch, fp32_output)

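    # Build a strongly-typed TensorRT-LLM network that wraps the lookup plugin.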
    builder = tensorrt_llm.Builder()
    builder.strongly_typed = True
    network = builder.create_network()
    with tensorrt_llm.net_guard(network):
        x = Tensor(
            name="x",
            shape=index_shape,
            dtype=tensorrt_llm.str_dtype_to_trt("int32"),
        )
        y = Tensor(name="y",
                   shape=(vocab_size, n_embed),
                   dtype=torch_dtype_to_trt(dtype))

        def lookup(x, y):
            lookup_plugin = LookUpPlugin(to_torch, fp32_output)
            return lookup_plugin(x, y)

        output = lookup(x, y)

        output.mark_output(
            "output",
            torch_dtype_to_str(dtype if not fp32_output else torch.float32))

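    # Compile the network and compare the engine output against both a direct
    # (eager) plugin call and the torch.nn.Embedding reference.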
    session = create_session(builder,
                             network,
                             precision=torch_dtype_to_str(dtype))

    input_dict = {"x": index_data, "y": weight_data}
    trt_out = run_session(session, input_dict)["output"]
    trt_plugin_out = trt_plugin(index_data, weight_data)
    torch_out = embedding(index_data)

    if fp32_output:
        torch_out = torch_out.to(torch.float32)

    torch.testing.assert_close(trt_out, torch_out)
    torch.testing.assert_close(trt_plugin_out, torch_out)

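
# Registering a second plugin under an already-used name must raise, whether it
# is the same class again or a different class.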
def test_redefinition():

    class Plugin(PluginBase):

        def __init__(self):
            super().__init__()

        def shape_dtype_inference(self,
                                  inputs: Sequence[SymTensor]) -> SymTensor:
            return inputs[0]

        def forward(self, inputs: Sequence[TensorWrapper],
                    outputs: Sequence[TensorWrapper]):
            pass

    trtllm_plugin("Plugin")(Plugin)

    with pytest.raises(AssertionError):
        trtllm_plugin("Plugin")(Plugin)

    with pytest.raises(AssertionError):

        @trtllm_plugin("Plugin")
        class PluginRedefine(PluginBase):

            def __init__(self):
                super().__init__()

            def shape_dtype_inference(
                self,
                inputs: Sequence[SymTensor]) -> tuple[SymTensor, SymTensor]:
                return inputs[0], inputs[0]

            def forward(self, inputs: Sequence[TensorWrapper],
                        outputs: Sequence[TensorWrapper]):
                pass

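
# Plugins must be registered with trtllm_plugin before instantiation;
# registering one whose shape_dtype_inference lacks a return annotation also
# fails.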
def test_no_register():

    class NoRegisterNoOutputNumPlugin(PluginBase):

        def __init__(self):
            super().__init__()

        def shape_dtype_inference(self,
                                  inputs: Sequence[SymTensor]) -> SymTensor:
            return inputs[0]

        def forward(self, inputs: Sequence[TensorWrapper],
                    outputs: Sequence[TensorWrapper]):
            pass

    with pytest.raises(AssertionError):
        NoRegisterNoOutputNumPlugin()

    with pytest.raises(AssertionError):

        @trtllm_plugin("UtilsPlugin")
        class NoOutputNumPlugin(PluginBase):

            def __init__(self):
                super().__init__()

            def shape_dtype_inference(self, inputs: Sequence[SymTensor]):
                return inputs[0]

            def forward(self, inputs: Sequence[TensorWrapper],
                        outputs: Sequence[TensorWrapper]):
                pass

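
# Every instance of a registered plugin class shares the same plugin creator.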
def test_single_creator():
    a = DummyPlugin()
    b = DummyPlugin()
    assert a._plugin_creator is b._plugin_creator