TensorRT-LLMs/tensorrt_llm/python_plugin.py
Dan Blanaru 16d2467ea8 Update TensorRT-LLM (#2755)
* Update TensorRT-LLM

---------

Co-authored-by: Denis Kayshev <topenkoff@gmail.com>
Co-authored-by: akhoroshev <arthoroshev@gmail.com>
Co-authored-by: Patrick Reiter Horn <patrick.horn@gmail.com>

Update
2025-02-11 03:01:00 +00:00

609 lines
22 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import pickle # nosec B403
import typing
from copy import deepcopy
from dataclasses import dataclass
from typing import Sequence, Type, Union
import numpy as np
import tensorrt as trt
import torch
from ._common import default_trtnet
from ._utils import (TensorWrapper, np_dtype_to_trt, str_dtype_to_trt,
torch_dtype_to_trt, trt_dtype_to_torch)
from .functional import Tensor, _create_tensor
from .plugin.plugin import TRT_LLM_PLUGIN_NAMESPACE
# Registry of every PluginInfo seen by `@trtllm_plugin`, used to reject a
# second registration with the same (name, namespace, version).
_plugin_registered = {}
@dataclass(slots=True, frozen=True)
class PluginInfo:
trt_plugin_version: int
plugin_namespace: str
plugin_name: str
plugin_version: str
plugin_num_outputs: int
def __hash__(self):
return hash(
(self.plugin_name, self.plugin_namespace, self.plugin_version))
def __eq__(self, obj):
if not isinstance(obj, PluginInfo):
return False
return (self.plugin_name == obj.plugin_name
and self.plugin_namespace == obj.plugin_namespace
and self.plugin_version == obj.plugin_version)
def make_expr(
    exprBuilder: Union[trt.IExprBuilder, Type[None]],
    dim: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]],
) -> Union[trt.IDimensionExpr, Type[None]]:
    """
    Normalize `dim` into a raw `trt.IDimensionExpr` (or None).

    Parameters:
        exprBuilder: Union[trt.IExprBuilder, Type[None]]
            The trt.exprBuilder object. Using it to check whether dim has the same exprBuilder or to create trt.IDimensionExpr if necessary.
        dim : Union["DimensionExpr", int, Type[None]]
            The input dim
    Returns:
        A trt.IDimensionExpr object, or None when `dim` is None (shape unknown).
    Raises:
        TypeError: if `dim` is none of the accepted types.
    """
    if isinstance(dim, DimensionExpr):
        # A wrapped expression must originate from the same builder.
        assert exprBuilder == dim.exprBuilder
        return dim.expr
    elif isinstance(dim, int):
        return exprBuilder.constant(dim)
    elif dim is None:
        # None propagates: used when only dtype information is available.
        return None
    elif isinstance(dim, trt.IDimensionExpr):
        return dim
    else:
        # Was a bare `raise Exception` with no context; TypeError (an
        # Exception subclass) keeps existing `except Exception` handlers
        # working while telling the caller what went wrong.
        raise TypeError(
            f"Cannot convert {type(dim).__name__} to trt.IDimensionExpr")
def expr_operation(
    a: Union[trt.IDimensionExpr, Type[None]],
    b: Union[trt.IDimensionExpr, Type[None]],
    operation: trt.DimensionOperation,
    exprBuilder: trt.IExprBuilder,
) -> "DimensionExpr":
    """
    The function to do expr operation with None support.

    If the builder or either operand is missing, the result is a
    `DimensionExpr` wrapping None instead of a computed expression, so
    unknown dimensions propagate through arithmetic.
    """
    # `is None` instead of `== None`: identity test, immune to any
    # `__eq__` overloads on the builder object.
    if exprBuilder is None or a is None or b is None:
        expr = None
    else:
        expr = exprBuilder.operation(operation, a, b)
    return DimensionExpr(expr, exprBuilder)
class DimensionExpr:
    '''
    Pythonic wrapper around `trt.IDimensionExpr`.

    Symbolic dimensions compose through ordinary Python operators
    (+, -, *, //, /, <, ==) plus the `min`/`max` methods. Every operator
    normalizes its operand with `make_expr` and combines the two sides via
    `expr_operation`, which propagates None: if either side (or the
    builder) is None, the result wraps None. Note that `==` and `<` build
    symbolic comparison expressions rather than returning booleans, and
    `/` maps to CEIL_DIV while `//` maps to FLOOR_DIV.
    '''

    def __init__(
        self,
        expr: Union[trt.IDimensionExpr, int, Type[None]],
        exprBuilder: Union[trt.IExprBuilder, Type[None]],
    ):
        self.exprBuilder = exprBuilder
        self.expr = expr

    @property
    def expr(self):
        return self._expr

    @expr.setter
    def expr(self, expr):
        # Whatever is assigned is normalized to a raw IDimensionExpr / None.
        self._expr = make_expr(self.exprBuilder, expr)

    def _binary(self, other, operation, reflected=False):
        # Shared implementation for every operator below: normalize the
        # operand, then delegate to expr_operation (None-propagating).
        rhs = make_expr(self.exprBuilder, other)
        if reflected:
            return expr_operation(rhs, self._expr, operation,
                                  self.exprBuilder)
        return expr_operation(self._expr, rhs, operation, self.exprBuilder)

    def __add__(self, expr):
        return self._binary(expr, trt.DimensionOperation.SUM)

    def __radd__(self, expr):
        # Addition is commutative; reuse __add__ so subclass overrides apply.
        return self.__add__(expr)

    def __mul__(self, expr):
        return self._binary(expr, trt.DimensionOperation.PROD)

    def __rmul__(self, expr):
        # Multiplication is commutative; reuse __mul__.
        return self.__mul__(expr)

    def __sub__(self, expr):
        return self._binary(expr, trt.DimensionOperation.SUB)

    def __rsub__(self, expr):
        return self._binary(expr, trt.DimensionOperation.SUB, reflected=True)

    def __eq__(self, expr):
        # Symbolic equality expression, NOT a Python bool.
        return self._binary(expr, trt.DimensionOperation.EQUAL)

    def __lt__(self, expr):
        return self._binary(expr, trt.DimensionOperation.LESS)

    def __floordiv__(self, expr):
        return self._binary(expr, trt.DimensionOperation.FLOOR_DIV)

    def __rfloordiv__(self, expr):
        return self._binary(expr, trt.DimensionOperation.FLOOR_DIV,
                            reflected=True)

    def __truediv__(self, expr):
        # `/` maps to CEIL_DIV, not exact division.
        return self._binary(expr, trt.DimensionOperation.CEIL_DIV)

    def __rtruediv__(self, expr):
        return self._binary(expr, trt.DimensionOperation.CEIL_DIV,
                            reflected=True)

    def max(self, expr):
        return self._binary(expr, trt.DimensionOperation.MAX)

    def min(self, expr):
        return self._binary(expr, trt.DimensionOperation.MIN)
class ShapeExpr:
    '''
    Wraps `trt.DimsExprs` as a list of `DimensionExpr` with pythonic
    indexing.

    A ShapeExpr built from `dims=None` models a fully unknown shape: it has
    length 0, every index reads as a None-wrapping DimensionExpr, and
    writes to it are silently ignored.
    '''

    def __init__(
        self,
        dims: Union[Sequence[trt.IDimensionExpr], Sequence[int],
                    Sequence[type[None]]],
        exprBuilder: Union[trt.IExprBuilder, type[None]],
    ):
        self.exprBuilder = exprBuilder
        self.dims = dims

    @property
    def dims(self):
        return self._dims

    @dims.setter
    def dims(self, dims):
        if dims is None:
            self._dims = None
            return
        # Normalize every entry (int / raw expr / DimensionExpr / None)
        # into a DimensionExpr bound to this builder.
        self._dims = [
            DimensionExpr(make_expr(self.exprBuilder, d), self.exprBuilder)
            for d in dims
        ]

    def __getitem__(self, index: int):
        if self._dims is None:
            # Unknown shape: every dimension reads as symbolic None.
            return DimensionExpr(None, self.exprBuilder)
        return self._dims[index]

    def __setitem__(self, index, value):
        if self._dims is None:
            # Writes to an unknown shape are no-ops.
            return
        assert index < len(self._dims)
        self._dims[index] = DimensionExpr(make_expr(self.exprBuilder, value),
                                          self.exprBuilder)

    def __len__(self):
        return 0 if self._dims is None else len(self._dims)

    def to_trt(self) -> trt.DimsExprs:
        # Unwrap back to the raw TRT structure; requires known dims.
        return trt.DimsExprs([d.expr for d in self._dims])
class SymTensor:
    '''
    Symbolic tensor: carries only dtype and shape, for users to write their
    own shape/dtype inference functions.

    The dtype is normalized to `trt.DataType` (or None) on assignment; the
    shape is either a ShapeExpr (symbolic) or a plain list/tuple of ints.
    '''

    def __init__(
        self,
        dtype: Union[torch.dtype, np.dtype, str, trt.DataType, Type[None]],
        shape: Union[ShapeExpr, Sequence[int]],
    ):
        self.dtype = dtype
        self.shape = shape

    @property
    def shape(self) -> Union[ShapeExpr, Sequence[int]]:
        return self._shape

    @shape.setter
    def shape(self, shape: Union[ShapeExpr, Sequence[int]]):
        assert isinstance(shape, (ShapeExpr, list, tuple))
        if not isinstance(shape, ShapeExpr):
            # Concrete (non-symbolic) shapes must be all-int.
            for dim in shape:
                assert isinstance(dim, int)
        self._shape = shape

    @property
    def dtype(self) -> Union[trt.DataType, Type[None]]:
        return self._dtype

    @dtype.setter
    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType,
                                 Type[None]]):
        # Accept torch/str/numpy/trt dtypes (or None); the accepted types
        # are mutually exclusive, so dispatch order does not matter.
        if dtype is None or isinstance(dtype, trt.DataType):
            self._dtype = dtype
        elif isinstance(dtype, torch.dtype):
            self._dtype = torch_dtype_to_trt(dtype)
        elif isinstance(dtype, str):
            self._dtype = str_dtype_to_trt(dtype)
        elif isinstance(dtype, np.dtype):
            self._dtype = np_dtype_to_trt(dtype)
        else:
            raise TypeError(f"Unsupported dtype: {dtype}")
def _convert_return_value_to_list(ret):
if not isinstance(ret, (list, tuple)):
return [ret]
assert isinstance(ret, (list, tuple))
return ret
class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
                 trt.IPluginV3OneRuntime):
    '''
    The base class of TRT-LLM plugin.
    All TRT-LLM plugin should inherit this class and at least rewrite `forward` and `shape_dtype_inference` function. `forward` defines the plugin's compute flow while `shape_dtype_inference` defines how would the output tensor's shape and dtype be inferenced from the input tensor.
    '''
    # Installed by the `@trtllm_plugin` decorator; None means "not registered".
    _plugin_creator = None
    # Runtime-only attributes excluded from serialization and cloning.
    _no_serialize_attr = {"_current_stream", "_workspace"}

    def __init__(self):
        cls = type(self)
        # Runtime check for plugin decorator
        assert (
            cls._plugin_creator is not None
        ), "Please make sure the plugin is registered through `@trtllm_plugin`"
        assert cls != PluginBase
        trt.IPluginV3.__init__(self)
        trt.IPluginV3OneCore.__init__(self)
        trt.IPluginV3OneBuild.__init__(self)
        trt.IPluginV3OneRuntime.__init__(self)
        self.plugin_phase = trt.TensorRTPhase.BUILD
        # `_num_outputs` / `_plugin_*` class attributes are set on the
        # subclass by the `@trtllm_plugin` decorator.
        self.num_outputs = self._num_outputs
        self.plugin_namespace = self._plugin_namespace
        self.plugin_name = self._plugin_name
        self.plugin_version = self._plugin_version
        # -1 is a sentinel meaning "use torch's current CUDA stream".
        self.current_stream = -1
        self.workspace = 0  # nullptr

    @property
    def current_stream(self) -> int:
        # Resolve the sentinel lazily so eager-mode calls pick up whatever
        # stream torch considers current at access time.
        if self._current_stream == -1:
            return torch.cuda.current_stream().cuda_stream
        else:
            return self._current_stream

    @current_stream.setter
    def current_stream(self, stream: int):
        assert isinstance(stream, int)
        self._current_stream = stream

    @property
    def workspace(self) -> int:
        # Returns a raw device address: either the int TRT handed to
        # `enqueue`, or the data pointer of a torch-allocated buffer.
        buffer = self._workspace
        return buffer if isinstance(buffer, int) else buffer.data_ptr()

    @workspace.setter
    def workspace(self, workspace: Union[int, torch.Tensor]):
        assert isinstance(workspace, (int, torch.Tensor))
        self._workspace = workspace

    def clone(self):
        # Build a copy without invoking the subclass __init__ (which may
        # require user arguments): allocate bare, run the next __init__ in
        # the MRO (normally PluginBase.__init__), then copy over the
        # serializable attributes.
        cls = type(self)
        cloned_plugin = cls.__new__(cls)
        super(cls, cloned_plugin).__init__()
        cloned_plugin.__dict__.update(self._get_dict_to_serialize())
        return cloned_plugin

    def get_capability_interface(self, type):
        # This class implements all capability interfaces itself.
        return self

    def configure_plugin(self, input_desc, output_desc):
        pass

    def get_output_data_types(self, input_types):
        # Run user inference with dtype-only SymTensors (shape unknown:
        # ShapeExpr(None, None)); validate count and types of the result.
        ret = self.shape_dtype_inference(
            [SymTensor(i, ShapeExpr(None, None)) for i in input_types])
        ret = _convert_return_value_to_list(ret)
        assert len(ret) == self.num_outputs
        for i in ret:
            assert isinstance(i, SymTensor)
        return [i.dtype for i in ret]

    def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
        assert len(
            shape_inputs) == 0, "Currently we do not support shape inputs"
        # Run user inference with shape-only SymTensors (dtype=None).
        ret = self.shape_dtype_inference(
            [SymTensor(None, ShapeExpr(i, exprBuilder)) for i in inputs])
        ret = _convert_return_value_to_list(ret)
        assert len(ret) == self.num_outputs
        for i in ret:
            assert isinstance(i, SymTensor)
        return [i.shape.to_trt() for i in ret]

    def supports_format_combination(self, pos, in_out, num_inputs):
        """
        By default, TRT-LLM plugin supports all dtype and linear format. It is the users responsibility to check the dtype the plugin supported in `forward` function.
        """
        assert pos < len(in_out)
        desc = in_out[pos].desc
        # Only the tensor format is constrained; any dtype is accepted.
        if desc.format != trt.TensorFormat.LINEAR:
            return False
        return True

    def attach_to_context(self, context):
        # Each execution context gets its own plugin instance.
        return self.clone()

    def get_fields_to_serialize(self):
        # Serialize the whole attribute dict as one opaque pickle field;
        # PluginCreatorBase.create_plugin unpickles it on engine load.
        # NOTE(review): loading a pickle from an engine file executes that
        # pickle — only deserialize engines from trusted sources.
        buffer = pickle.dumps(self._get_dict_to_serialize())
        return trt.PluginFieldCollection([
            trt.PluginField("__plugin_pickle_obj__", buffer,
                            trt.PluginFieldType.UNKNOWN)
        ])

    def enqueue(self, input_desc, output_desc, inputs, outputs, workspace,
                stream):
        # Wrap TRT's raw CUDA stream so torch kernels launched inside
        # `forward` run on it; expose the raw pointers as TensorWrappers.
        torch_stream = torch.cuda.ExternalStream(stream_ptr=stream)
        self.workspace = workspace
        self.current_stream = stream
        with torch.cuda.stream(torch_stream):
            self.forward(
                tuple(
                    TensorWrapper.from_trt_desc(input_desc[i], inputs[i])
                    for i in range(len(input_desc))),
                tuple(
                    TensorWrapper.from_trt_desc(output_desc[i], outputs[i])
                    for i in range(len(output_desc))),
            )
        # Restore the "use torch's current stream" sentinel after the call.
        self.current_stream = -1

    def __call__(self, *args: Union[Sequence[TensorWrapper],
                                    Sequence[torch.Tensor]]):
        '''
        Dual-mode entry point: with `tensorrt_llm.Tensor` inputs the plugin
        is inserted into the network under construction; with
        `torch.Tensor` inputs it executes eagerly via `forward`.
        '''
        is_trtllm = True
        for i in args:
            is_trtllm &= isinstance(i, Tensor)
        if not is_trtllm:
            for i in args:
                assert isinstance(
                    i, torch.Tensor
                ), "Plugin inputs must be `tensorrt_llm.Tensor`s or `torch.Tensor`s"
            # Eager mode: infer output shapes/dtypes from concrete inputs,
            # allocate the outputs, then run forward on the current stream.
            sym_tensors = self.shape_dtype_inference(
                [SymTensor(i.dtype, [j for j in i.shape]) for i in args])
            sym_tensors = _convert_return_value_to_list(sym_tensors)
            ret = [
                torch.empty(sym_tensor.shape,
                            dtype=trt_dtype_to_torch(sym_tensor.dtype))
                for sym_tensor in sym_tensors
            ]
            self.current_stream = torch.cuda.current_stream().cuda_stream
            # NOTE(review): `self.workspace` here reads back the stored
            # pointer value (0 unless previously set), so this allocates a
            # zero-element buffer in eager mode — confirm the intended
            # workspace size for eager execution.
            self.workspace = torch.empty(self.workspace)
            self.forward(args, ret)
        else:
            # Build mode: add the plugin layer to the default network and
            # wrap its outputs as tensorrt_llm Tensors.
            args = [i.trt_tensor for i in args]
            layer_plugin = default_trtnet().add_plugin_v3(args, [], self)
            ret = [
                _create_tensor(layer_plugin.get_output(i), layer_plugin)
                for i in range(self.num_outputs)
            ]
        # Single-output plugins return the tensor directly, not a list.
        if len(ret) == 1:
            return ret[0]
        return ret

    def on_shape_change(self, input_desc, output_desc):
        pass

    def get_valid_tactics(self):
        # No custom tactics by default.
        return []

    def set_tactic(self, index):
        if index != 0:
            raise RuntimeError(
                "By default TRT should not set tactics since PluginBase do not provide custom tactic."
            )

    def forward(self, inputs: Sequence[TensorWrapper],
                outputs: Sequence[TensorWrapper]):
        '''
        Expect users to rewrite this function to define the compute flow. There are a few special attributes for users to get access to some resources.
        `self.workspace`: The workspace address of TRT managed workspace.
        `self.current_stream`: The CUDA stream this plugin is expected to execute on. By default `PluginBase` set the torch.cuda.current_stream() to this stream. This attribute is for the toolkit that doesn't work with torch's stream.
        '''
        raise NotImplementedError

    def shape_dtype_inference(self, inputs: Sequence[SymTensor]):
        '''
        Expect users to rewrite this function to define the shape dtype inference for output tensors.
        '''
        raise NotImplementedError

    def _get_dict_to_serialize(self):
        # Snapshot of __dict__ minus runtime-only attrs; `deepcopy_clone`
        # is a class attribute set by the `@trtllm_plugin` decorator and
        # controls whether values are deep-copied or shared.
        ret = {}
        for k, v in self.__dict__.items():
            if k not in self._no_serialize_attr:
                ret[k] = deepcopy(v) if self.deepcopy_clone else v
        return ret
class PluginCreatorBase(trt.IPluginCreatorV3One):
    """Creator that rebuilds a plugin instance from its pickled attribute dict."""

    def __init__(self):
        super().__init__()

    def create_plugin(self, name, fc, phase):
        """Deserialize a plugin from the field collection TRT passes in."""
        # Guard clause: the only supported field layout is the single
        # pickle blob emitted by PluginBase.get_fields_to_serialize.
        if len(fc) != 1 or fc[0].name != "__plugin_pickle_obj__":
            raise RuntimeError("Expect to be called by TRT")
        data = fc[0].data
        plugin_dict = pickle.loads(data)  # nosec B301
        # Bypass the subclass __init__ and restore state directly, mirroring
        # PluginBase.clone.
        plugin = self.plugin_cls.__new__(self.plugin_cls)
        super(self.plugin_cls, plugin).__init__()
        plugin.__dict__.update(plugin_dict)
        plugin.plugin_phase = phase
        return plugin
def trtllm_plugin(
    plugin_name: str,
    *,
    plugin_version: str = "1",
    plugin_namespace: str = TRT_LLM_PLUGIN_NAMESPACE,
    plugin_num_outputs: Union[int, Type[None]] = None,
    deepcopy_clone: bool = True,
    no_serialize_attr: Sequence[str] = (),
):
    """Class decorator that registers a PluginBase subclass with TRT.

    Args:
        plugin_name: Unique plugin name within `plugin_namespace`.
        plugin_version: Plugin version string.
        plugin_namespace: Namespace used for TRT registry lookup.
        plugin_num_outputs: Number of outputs; may be omitted when it can be
            inferred from the return annotation of `shape_dtype_inference`
            (SymTensor -> 1, tuple of SymTensor -> its length).
        deepcopy_clone: Whether serialized/cloned attribute values are
            deep-copied (see PluginBase._get_dict_to_serialize).
        no_serialize_attr: Extra attribute names to exclude from
            serialization (read-only; default changed from `set()` to `()`
            to match the Sequence annotation and avoid a mutable default).

    Returns:
        The registration function that decorates the plugin class.
    """

    def plugin_registration(plugin_cls):
        assert issubclass(plugin_cls, PluginBase)
        assert hasattr(
            plugin_cls,
            "__dict__"), "Plugin wrapper uses `__dict__` to track plugin states"
        nonlocal plugin_num_outputs
        # Infer the output count from `shape_dtype_inference`'s return
        # annotation when possible, and cross-check it against an explicit
        # `plugin_num_outputs` if one was given.
        annotation = inspect.signature(
            plugin_cls.shape_dtype_inference).return_annotation
        origin_annotation = typing.get_origin(annotation)
        if origin_annotation is tuple or annotation is SymTensor:
            if origin_annotation is tuple:
                element_types = typing.get_args(annotation)
                for ty in element_types:
                    assert (
                        ty == SymTensor
                    ), f"Plugin {plugin_name}'s `shape_dtype_inference` return annotation must be SymTensor or a tuple of SymTensor"
                infered_num_outputs = len(element_types)
            else:
                infered_num_outputs = 1
            if plugin_num_outputs is not None:
                # BUG FIX: the message previously interpolated
                # `plugin_cls._num_outputs`, which is not set yet at this
                # point, so a failing assert raised AttributeError instead
                # of the intended message.
                assert (
                    plugin_num_outputs == infered_num_outputs
                ), f"Plugin {plugin_name}'s `_num_outputs` and return annotation mismatch, {plugin_num_outputs} != {infered_num_outputs}"
            plugin_num_outputs = infered_num_outputs
        else:
            assert (
                plugin_num_outputs is not None
            ), f"Must specify `num_outputs` or valid `shape_dtype_inference` return annotation for {plugin_name}. The valid types are SymTensor or a tuple of SymTensor, got {annotation}."
        plugin_info = PluginInfo(3, plugin_namespace, plugin_name,
                                 plugin_version, plugin_num_outputs)
        assert (
            plugin_info not in _plugin_registered
        ), f"Redefine plugin with info: {plugin_info} which is previously defined as {_plugin_registered[plugin_info]}"
        _plugin_registered[plugin_info] = plugin_info
        plugin_cls._plugin_name = plugin_name
        plugin_cls._plugin_version = plugin_version
        plugin_cls._plugin_namespace = plugin_namespace
        plugin_cls._num_outputs = plugin_num_outputs
        plugin_cls.deepcopy_clone = deepcopy_clone
        # BUG FIX: give the subclass its OWN set. The previous in-place
        # `plugin_cls._no_serialize_attr.update(...)` mutated the set
        # inherited from PluginBase, leaking every plugin's no-serialize
        # attributes into all other plugin classes.
        plugin_cls._no_serialize_attr = (set(plugin_cls._no_serialize_attr)
                                         | set(no_serialize_attr))
        plugin_registry = trt.get_plugin_registry()
        plugin_creator = PluginCreatorBase()
        plugin_creator.name = plugin_cls._plugin_name
        plugin_creator.plugin_namespace = plugin_cls._plugin_namespace
        plugin_creator.plugin_version = plugin_cls._plugin_version
        plugin_creator.field_names = trt.PluginFieldCollection([])
        plugin_creator.plugin_cls = plugin_cls
        plugin_cls._plugin_creator = plugin_creator
        ret = plugin_registry.register_creator(plugin_creator,
                                               plugin_cls._plugin_namespace)
        assert ret, f"Plugin: {plugin_cls} register failed, please check the error log."
        return plugin_cls

    return plugin_registration