# TensorRT-LLMs/tests/python_plugin/plugin_wrapper_utils.py
import os
import sys
from typing import Sequence

import triton
import triton.language as tl

from tensorrt_llm import PluginBase
from tensorrt_llm._utils import TensorWrapper
from tensorrt_llm.python_plugin import SymTensor, trtllm_plugin

# Make the parent tests directory importable when this module is loaded directly.
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir))


@trtllm_plugin("DummyPlugin")
class DummyPlugin(PluginBase):

    def __init__(self):
        super().__init__()

    def shape_dtype_inference(self, inputs: Sequence[SymTensor]) -> SymTensor:
        # The output keeps the shape and dtype of the first input.
        return inputs[0]

    def forward(self, inputs: Sequence[TensorWrapper],
                outputs: Sequence[TensorWrapper]):
        # Intentionally a no-op plugin.
        pass


@triton.jit
def lookup_kernel(X, Y, Z, vocab_size, hidden_size, token_num):
    # Embedding lookup: Z[i, j] = Y[X[i], j] for i < token_num, j < hidden_size,
    # handling one flattened output element per loop iteration.
    pid = tl.program_id(axis=0)
    # Grid-stride loop over all token_num * hidden_size output elements.
    while pid < token_num * hidden_size:
        row_idx = pid // hidden_size
        col_idx = pid % hidden_size
        word_idx = tl.load(X + row_idx)
        # The mask guards the load against out-of-range token ids.
        embedding = tl.load(
            Y + word_idx * hidden_size + col_idx,
            mask=word_idx < vocab_size,
        )
        tl.store(Z + pid, embedding)
        pid += tl.num_programs(0)
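

# A minimal launch sketch, not part of the original test utility: it assumes a
# CUDA device and PyTorch tensors for the token ids (x), the embedding table (y),
# and the output buffer (z). The grid of 256 programs is an arbitrary choice; the
# grid-stride loop inside lookup_kernel covers any remaining elements.
if __name__ == "__main__":
    import torch

    vocab_size, hidden_size, token_num = 1024, 64, 8
    x = torch.randint(0, vocab_size, (token_num, ), dtype=torch.int32, device="cuda")
    y = torch.randn(vocab_size, hidden_size, device="cuda")
    z = torch.empty(token_num, hidden_size, device="cuda")

    lookup_kernel[(256, )](x, y, z, vocab_size, hidden_size, token_num)

    # Each row of z should match the embedding row selected by the token id.
    assert torch.allclose(z, y[x.long()])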