# TensorRT-LLM/examples/openai_triton/plugin_autogen/kernel_config.py
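"""Kernel configuration for the TensorRT-LLM Triton plugin auto-generation example.

Describes the fused multi-head attention Triton kernel in
``../manual_plugin/fmha_triton.py`` so that the plugin_gen tooling can wrap it
as a TensorRT plugin.
"""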

import os

import torch
from tensorrt_llm.tools.plugin_gen.core import *

openai_triton_example_root = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "..", "manual_plugin")


def get_fmha_kernel_meta_data():
    # Triton block size along the sequence dimension; reused for the constexpr
    # arguments and the launch grid below.
    block_size = 128
    # Use two software-pipelining stages on Ampere (SM 8.0) and newer, one otherwise.
    num_stages = 2 if torch.cuda.get_device_capability() >= (8, 0) else 1
    return KernelMetaData(
        kernel_name='fused_attention_kernel',
        ios=[
            # outputs
            OutputArg('Out', Type('tensor[fp16]'), hints=['16', '16']),
            OutputArg('L', Type('tensor[fp32]'), hints=['16', '16']),
            OutputArg('M', Type('tensor[fp32]'), hints=['16', '16']),
            # inputs
            InputArg('Q', Type('tensor[fp16]'), hints=['16', '16']),
            InputArg('K', Type('tensor[fp16]'), hints=['16', '16']),
            InputArg('V', Type('tensor[fp16]'), hints=['16', '16']),
            ParamArg('sm_scale', Type('fp32')),
            DimSizeArg('batch_size'),
            ParamArg('num_heads', Type('i32')),
            DimSizeArg('seq_len', hints=['', '16']),
            # constexprs
            Constexpr(block_size),
            Constexpr(64),
            Constexpr(block_size),
        ],
        shape_infer_rules=[
            # The following rules help to deduce the shapes of the output tensors
            "Q[*] -> Out[*]",
            "Q[m,n,k,*] -> L[m,n,k]",
            "Q[m,n,k,*] -> M[m,n,k]",
            # The following rules help to deduce both DimSizeArgs: batch_size and seq_len
            "Q[m,n,k,*] : m -> batch_size",
            "Q[m,n,k,*] : k -> seq_len",
        ],
        version=0,
        kernel_file=f'{openai_triton_example_root}/fmha_triton.py',
        num_warps=4,
        num_stages=num_stages,
        # Launch grid: ceil(seq_len / block_size) blocks along x, one block per
        # (batch, head) pair along y.
        grid_dims=(f"(seq_len + {block_size - 1}) / {block_size}",
                   "batch_size * num_heads", "1"))


KERNELS = [
    get_fmha_kernel_meta_data(),
]
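

# A minimal inspection sketch; it assumes KernelMetaData exposes its constructor
# arguments (kernel_name, kernel_file, num_warps) as attributes. Actual plugin
# generation is driven by the plugin_gen tooling rather than by running this
# module directly.
if __name__ == '__main__':
    for kernel in KERNELS:
        print(f'{kernel.kernel_name}: file={kernel.kernel_file}, '
              f'num_warps={kernel.num_warps}')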