TensorRT-LLMs/examples/python_plugin/plugin_lib/lookup_kernel.py
Kaiyu Xie f14d1d433c
Update TensorRT-LLM (#2389)
* Update TensorRT-LLM

---------

Co-authored-by: Alessio Netti <netti.alessio@gmail.com>
2024-10-29 22:24:38 +08:00

18 lines
507 B
Python

import triton
import triton.language as tl
@triton.jit
def lookup_kernel(X, Y, Z, vocab_size, hidden_size, token_num):
    """Gather embedding rows from ``Y`` for the token ids in ``X`` into ``Z``.

    Each Triton program computes one output element per loop iteration and
    then advances by the number of programs launched along axis 0 (a
    grid-stride loop). Any 1-D launch grid therefore covers all
    ``token_num * hidden_size`` output elements.

    :param X: pointer to ``token_num`` token ids. They are used as integer row
        indices into ``Y``.
    :param Y: pointer to the embedding table, read as consecutive rows of
        ``hidden_size`` elements (row ``i`` starts at ``Y + i * hidden_size``).
    :param Z: pointer to the output, written as ``token_num`` consecutive rows
        of ``hidden_size`` elements.
    :param vocab_size: number of valid rows in ``Y``. Ids outside
        ``[0, vocab_size)`` are never dereferenced; their output rows are
        filled with zeros.
    :param hidden_size: number of elements per embedding row.
    :param token_num: number of token ids in ``X``.
    """
    # Loop invariants: total number of output elements and the grid stride.
    total = token_num * hidden_size
    stride = tl.num_programs(0)

    pid = tl.program_id(axis=0)
    while pid < total:
        # Flat output index -> (token row, hidden column).
        row_idx = pid // hidden_size
        col_idx = pid % hidden_size
        word_idx = tl.load(X + row_idx)

        # Guard both ends of the table: a negative id would otherwise read
        # before Y. Masked-out lanes of tl.load are undefined unless `other`
        # is given, so return 0 for them instead of storing garbage to Z.
        # NOTE(review): the offset is computed in the dtype of word_idx; if
        # that is int32 and vocab_size * hidden_size can exceed 2**31 - 1,
        # this multiplication may overflow. Confirm against the caller.
        in_vocab = (word_idx >= 0) & (word_idx < vocab_size)
        embedding = tl.load(
            Y + word_idx * hidden_size + col_idx,
            mask=in_vocab,
            other=0.0,
        )
        tl.store(Z + pid, embedding)
        pid += stride