TensorRT-LLMs/tensorrt_llm/tools/plugin_gen/templates/functional.py.tpl
2023-09-28 09:00:05 -07:00

71 lines
2.3 KiB
Smarty

{% if add_header %}
# //header begin//
import ctypes
from collections import OrderedDict
from pathlib import Path
from typing import List
import numpy as np
import tensorrt as trt
from tensorrt_llm._common import default_trtnet
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.functional import Tensor, _create_tensor
from tensorrt_llm.module import Module
TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'
def _load_triton_plugin_lib():
triton_plugin_dir = Path(__file__).parent.absolute()
plugin_lib = "[[ plugin_lib_path ]]"
handle = ctypes.CDLL(plugin_lib, mode=ctypes.RTLD_GLOBAL)
if handle is None:
raise ImportError('TensorRT-LLM Triton Plugin is unavailable')
handle.initLibNvInferPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
handle.initLibNvInferPlugins.restype = ctypes.c_bool
assert handle.initLibNvInferPlugins(
None, TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))
_load_triton_plugin_lib()
# //header end//
{% endif %}
def [[ kernel_name ]]([[ arg_list ]]):
'''
Inputs:
{% for arg in metadata.get_params() -%}
- [[arg.name]]: [[arg.dtype.dtype.to('np')]]
{% endfor %}
{% for arg in metadata.get_inputs() -%}
- [[arg.name]]: {% if arg.is_tensor %}tensor<{%endif%}[[arg.dtype.dtype.to('np')]]>
{% endfor %}
Outputs:
{% for arg in metadata.get_outputs() -%}
- [[arg.name]]: {% if arg.is_tensor %}tensor<{%endif%}[[arg.dtype.dtype.to('np')]]>
{% endfor -%}
'''
plg_creator = trt.get_plugin_registry().get_plugin_creator(
'[[ plugin_name ]]', '[[ kernel_version ]]', TRT_LLM_PLUGIN_NAMESPACE)
assert plg_creator is not None
pfc = trt.PluginFieldCollection([
{% for arg in params -%}
{# input is a dict[name, np_type, trt_type ] #}
trt.PluginField("[[arg.name]]", np.array([ [[ arg.name ]] ], np.[[ arg.dtype.dtype.to('np') ]]),
trt.PluginFieldType.[[ arg.dtype.dtype.to('trt_plugin_py') ]]),
{% endfor %}
])
plugin = plg_creator.create_plugin("[[ plugin_name ]]", pfc)
plug_inputs = [ [[ input_list ]] ]
layer = default_trtnet().add_plugin_v2(plug_inputs, plugin)
return [
{% for id in range(num_outputs) %}
_create_tensor(layer.get_output([[ id ]]), layer),
{% endfor %}
]