# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
# (synced 2026-01-14 06:27:45 +08:00; original file: 18 lines, 507 B, Python)
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
def lookup_kernel(X, Y, Z, vocab_size, hidden_size, token_num):
    """Gather embedding rows for a sequence of token ids.

    Each program processes one output element per iteration. It then advances
    by the total number of launched programs, so any launch size covers all
    ``token_num * hidden_size`` output elements.

    Args:
        X: Pointer to the token ids; ``X[r]`` selects the table row copied
            into output row ``r``.
        Y: Pointer to the embedding table, addressed as rows of
            ``hidden_size`` contiguous elements.
        Z: Pointer to the output, written as ``token_num`` rows of
            ``hidden_size`` contiguous elements.
        vocab_size: Number of rows addressable in ``Y``. Ids outside
            ``[0, vocab_size)`` are never dereferenced; the corresponding
            output elements are written as zero.
        hidden_size: Number of elements per embedding row.
        token_num: Number of token ids in ``X`` (rows of ``Z``).
    """
    pid = tl.program_id(axis=0)
    # Loop-invariant bounds/stride, computed once per program.
    num_elements = token_num * hidden_size
    stride = tl.num_programs(0)
    while pid < num_elements:
        row_idx = pid // hidden_size
        col_idx = pid % hidden_size
        word_idx = tl.load(X + row_idx)
        # Reject negative ids as well as ids past the table end, and give
        # masked lanes a defined value: without `other`, Triton leaves the
        # result of a masked-off load undefined and garbage would be stored.
        in_vocab = (word_idx >= 0) & (word_idx < vocab_size)
        embedding = tl.load(
            Y + word_idx * hidden_size + col_idx,
            mask=in_vocab,
            other=0.0,
        )
        tl.store(Z + pid, embedding)
        pid += stride