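"""Kernel metadata describing the OpenAI Triton fused multi-head attention
(FMHA) example kernel, to be consumed by the plugin_gen tool when generating
a TensorRT plugin for it."""
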
import os

from tensorrt_llm.tools.plugin_gen.core import *

# Root of the manual-plugin Triton example that ships the kernel source
# (fmha_triton.py) referenced by kernel_file below.
openai_triton_example_root = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..",
    "examples", "openai_triton", "manual_plugin")

def get_fmha_kernel_meta_data():
    return KernelMetaData(
        kernel_name='fused_attention_kernel',
        ios=[
            # outputs
            OutputArg('Out', Type('tensor[fp16]'), hints=['16', '16']),
            OutputArg('L', Type('tensor[fp32]'), hints=['16', '16']),
            OutputArg('M', Type('tensor[fp16]'), hints=['16', '16']),
            # inputs
            InputArg('Q', Type('tensor[fp16]'), hints=['16', '16']),
            InputArg('K', Type('tensor[fp16]'), hints=['16', '16']),
            InputArg('V', Type('tensor[fp16]'), hints=['16', '16']),
            ParamArg('sm_scale', Type('fp32')),
            # batch_size and seq_len are DimSizeArgs: they are not passed in
            # explicitly, but deduced from the input shapes via the
            # shape_infer_rules below.
            DimSizeArg('batch_size'),
            ParamArg('num_heads', Type('i32')),
            DimSizeArg('seq_len', hints=['', '16']),
            # constexprs: compile-time (tl.constexpr) values baked into the
            # generated plugin.
            Constexpr(128),
            Constexpr(64),
            Constexpr(128),
        ],
        shape_infer_rules=[
            # The following rules help to deduce the shapes of the output
            # tensors (see the worked example after this function).
            "Q[*] -> Out[*]",
            "Q[m,n,k,*] -> L[m,n,k]",
            "Q[m,n,k,*] -> M[m,n,k]",

            # The following rules help to deduce both DimSizeArgs: batch_size
            # and seq_len.
            "Q[m,n,k,*] : m -> batch_size",
            "Q[m,n,k,*] : k -> seq_len",
        ],
        version=0,
        kernel_file=f'{openai_triton_example_root}/fmha_triton.py',
        num_warps=1,
        # Launch grid: ceil(seq_len / 128) blocks along the sequence
        # dimension, one block per (batch, head) pair.
        grid_dims=("(seq_len + 127) / 128", "batch_size * num_heads", "1"))
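

# A worked reading of the shape_infer_rules above, assuming Q is laid out as
# [batch_size, num_heads, seq_len, head_dim]:
#   "Q[*] -> Out[*]"              : Out gets the same shape as Q.
#   "Q[m,n,k,*] -> L[m,n,k]"      : L keeps Q's first three dims [m, n, k].
#   "Q[m,n,k,*] : m -> batch_size": batch_size is read off Q's first dim.
#   "Q[m,n,k,*] : k -> seq_len"   : seq_len is read off Q's third dim.
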
KERNELS = [
    get_fmha_kernel_meta_data(),
]
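

if __name__ == '__main__':
    # Minimal sanity check, not part of the original workflow: constructing
    # the metadata exercises this config without invoking the plugin
    # generator. Assumes KernelMetaData keeps its `kernel_name` argument as
    # an attribute of the same name.
    for meta in KERNELS:
        print(meta.kernel_name)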