# SPDX-FileCopyrightText: Copyright (c) 2020-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
from collections import namedtuple
from enum import IntEnum
from itertools import product

sm2name = {
    70: 'volta',
    72: 'volta',
    75: 'turing',
    80: 'ampere',
    86: 'ampere',
    87: 'ampere',
    89: 'ada',
    90: 'hopper',
    120: 'blackwell',
}

dtype2traits = {
    'int8': 'imma_int8_int32_traits',
    'fp16': 'hmma_fp16_traits',
    'fp16_fp32': 'hmma_fp32_traits',
    'bf16': 'hmma_bf16_traits',
    'e4m3': 'qmma_e4m3_fp32_traits',
    'e4m3_fp32': 'qmma_e4m3_fp32_traits',
    'e4m3_fp16': 'qmma_e4m3_fp16_traits'
}

dtype2OutputType = {
    'int8': 'int8_t',
    'fp16': 'fp16_t',
    'fp16_fp32': 'fp16_t',
    'bf16': 'bf16_t',
    'e4m3': 'e4m3_t',
    'e4m3_fp32': 'e4m3_t',
    'e4m3_fp16': 'e4m3_t',
}

dtype2bytes = {
    'int8': 1,
    'fp16': 2,
    'fp16_fp32': 2,
    'bf16': 2,
    'e4m3': 1,
    'e4m3_fp32': 1,
    'e4m3_fp16': 1
}

# TODO merge with above?
hopper_dtype2traits = {
    'int8': 'igmma_int8_int32_traits',
    'fp16': 'hgmma_fp16_traits',
    'fp16_fp32': 'hgmma_fp32_traits',
    'bf16': 'hgmma_bf16_traits',
    'e4m3': 'qgmma_e4m3_fp32_traits',
    'e4m3_fp32': 'qgmma_e4m3_fp32_traits',
}

# The minimal instruction shapes per warp group.
# TODO should this not be known to the trait itself?
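# (Assumed meaning, added for readability: each tuple is the (M, N, K) shape of the
#  smallest GMMA instruction issued per warp group, e.g. Hopper_hgmma_fp16_traits
#  corresponds to an m64n8k16 HGMMA.)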
hopper_traits2shape = {
    'Hopper_igmma_int8_int32_traits': (64, 8, 32),
    'Hopper_hgmma_fp16_traits': (64, 8, 16),
    'Hopper_hgmma_fp32_traits': (64, 8, 16),
    'Hopper_hgmma_bf16_traits': (64, 8, 16),
    'Hopper_qgmma_e4m3_fp32_traits': (64, 8, 32),
}

dtype2typename = {
    'int8': 'DATA_TYPE_INT8',
    'fp16': 'DATA_TYPE_FP16',
    'fp16_fp32': 'DATA_TYPE_FP16',
    'bf16': 'DATA_TYPE_BF16',
    'e4m3': 'DATA_TYPE_E4M3',
    'e4m3_fp16': 'DATA_TYPE_E4M3',
    'e4m3_fp32': 'DATA_TYPE_E4M3',
}

pythonBoolean2cpp = {True: 'true', False: 'false'}


# same definition as fused_multihead_attention.h.
class AttentionMaskType(IntEnum):
    PADDING = 0
    CAUSAL = 1
    SLIDING_OR_CHUNKED_CAUSAL = 2
    CUSTOM_MASK = 3


class InputLayout(IntEnum):
    PACKED_QKV = 0
    CONTIGUOUS_Q_KV = 1
    Q_PAGED_KV = 2
    SEPARATE_Q_K_V = 3
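# (Readability note: PACKED_QKV = single interleaved QKV tensor, CONTIGUOUS_Q_KV =
#  separate Q plus a packed KV tensor, Q_PAGED_KV = Q plus a paged KV cache,
#  SEPARATE_Q_K_V = three separate buffers.)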


spec_fields = (
    'sm',
    'dtype',
    'seq_len',
    'head_size',
    'warps_m',
    'warps_n',
    'version',
    'interleaved',
    'ldgsts_q',
    'ldgsts_k',
    'ldgsts_v',
    'share_smem_k_v',
    'loop_step',
    'has_noloop',
    'noloop_step',
    'unroll_threshold',
    'has_scale_max',
    'ctas_per_head',
    'sm_mma',
    'head_interleaved',
    # newly added fields (only used by the flash attention implementation)
    'flash_attention',
    'kv_loop_step',
    'flash_attention_bh_upper_threshold',  # to deprecate; not actively used
    'limit_qk_fragments',
    'limit_v_fragments',
    'tiled',
    # fields for warp specialized kernels
    'warp_specialization',
    'q_tile_buffers',
    'kv_tile_buffers',
    'scheduling_mode',
    # attention qkv input layout.
    'input_layout',
    # fused MHCA.
    'cross_mha',
    # other features
    'alibi',
    'enable_attn_logit_softcapping',
    'return_softmax_stats',
    'disabled_mask_types',
    'head_size_v',
    'sage_block_sizes',
    'output_dtype',
    'is_mtp',
    'enable_skip_softmax',
)
kernel_spec = namedtuple('kernel_spec', spec_fields)
kernel_spec.__new__.__defaults__ = (
    1,  # ctas_per_head
    1,  # sm_mma
    True,  # head_interleaved
    False,  # flash_attention
    64,  # kv_loop_step
    -1,  # flash_attention_bh_upper_threshold
    False,  # limit_qk_fragments
    False,  # limit_v_fragments
    0,  # tiled
    False,  # warp_specialization
    1,  # q_tile_buffers
    1,  # kv_tile_buffers
    0,  # scheduling_mode
    InputLayout.PACKED_QKV,
    0,  # cross_mha
    True,  # alibi
    False,  # enable_attn_logit_softcapping
    False,  # return_softmax_stats
    None,  # disabled_mask_types
    0,  # head size of V
    None,  # sage_block_sizes
    None,  # output_dtype, same as dtype by default.
    False,  # use MTP or not
    False,  # enable skip softmax
)
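# Illustrative construction (the values below are made up, not taken from the kernel
# lists further down in this file):
#   kernel_spec(sm=90, dtype='fp16', seq_len=0, head_size=64,
#               warps_m=4, warps_n=1, version=2, interleaved=False,
#               ldgsts_q=False, ldgsts_k=False, ldgsts_v=False,
#               share_smem_k_v=False, loop_step=64, has_noloop=0,
#               noloop_step=64, unroll_threshold=1, has_scale_max=False,
#               flash_attention=True, warp_specialization=True)
# The 17 leading fields are required; the remaining fields fall back to the defaults
# declared above.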

generate_cu_trtllm = os.environ.get('GENERATE_CU_TRTLLM',
                                    'False').lower() == 'true'
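# Illustrative usage (assumed invocation): GENERATE_CU_TRTLLM=true python3 setup.py
# wraps the generated kernels in namespace tensorrt_llm::kernels via the
# ns_open / ns_close snippets below.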

ns_open = r"""
namespace tensorrt_llm
{
namespace kernels
{
// clang-format off
""" if generate_cu_trtllm else ""

ns_close = r"""
// clang-format on
} // namespace kernels
} // namespace tensorrt_llm
""" if generate_cu_trtllm else ""

copyright = '''\
/***************************************************************************************************
 * SPDX-FileCopyrightText: Copyright (c) 2011-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 **************************************************************************************************/
'''

makefile_template = '''\

# The combination of supported gencodes.
GENCODES = $(GENCODE_SM70)
GENCODES += $(GENCODE_SM72)
GENCODES += $(GENCODE_SM75)
GENCODES += $(GENCODE_SM80)
GENCODES += $(GENCODE_SM86)
GENCODES += $(GENCODE_SM87)
GENCODES += $(GENCODE_SM89)
GENCODES += $(GENCODE_SM90)
GENCODES += $(GENCODE_SM100)
GENCODES += $(GENCODE_SM120)

OBJECTS_MHA = obj/fused_multihead_attention.cpp.o
OBJECTS_MHCA = obj/fused_multihead_cross_attention.cpp.o

{objects}

{cubins}

SOFTMAX_SRC = $(wildcard src/softmax*.cu)
SOFTMAX_OBJ = $(patsubst src/softmax%.cu, obj/softmax%.cu.o, $(SOFTMAX_SRC))
OBJECTS_MHA += $(SOFTMAX_OBJ)
OBJECTS_MHA += obj/convert.cu.o
OBJECTS_MHCA += $(SOFTMAX_OBJ)
OBJECTS_MHCA += obj/convert.cu.o
'''


def get_makefile_code(specs_names):
    objects = '\n'.join([
        'OBJECTS_MHA += obj/{}.o'.format(fname)
        for kspec, fname, lname, kname in specs_names
    ])
    objects = objects + '\n' + '\n'.join([
        'OBJECTS_MHCA += obj/{}.o'.format(fname)
        for kspec, fname, lname, kname in specs_names
    ])

    cubins = '\n'.join([
        'CUBINS += cubin/{}.cubin'.format(fname)
        for kspec, fname, lname, kname in specs_names
    ])
    return makefile_template.format(objects=objects,
                                    cubins=cubins,
                                    copyright=copyright)


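# Upper bound on Gmem_tile_o::STGS_PER_LOOP for the multi-CTA kernels; the generated
# code static_asserts against it (see kernel_template below) and asks for this value
# to be bumped here if the assert fires.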
MAX_STGS_PER_LOOP = 4
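
# Template for the classic (non-flash) fused MHA / cross-MHA kernels: it emits the
# looped, no-loop and multi-CTA kernel entry points plus their host-side launchers.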
kernel_template = '''\
|
|
{copyright}
|
|
|
|
//We can disable the FADD trick for archs with F2IP
|
|
#if {disable_fadd_trick} // disable_fadd_trick
|
|
#ifdef USE_I2F_EMULATION_TRICK
|
|
#undef USE_I2F_EMULATION_TRICK
|
|
#endif // USE_I2F_EMULATION_TRICK
|
|
|
|
#ifdef USE_F2I_EMULATION_TRICK
|
|
#undef USE_F2I_EMULATION_TRICK
|
|
#endif // USE_F2I_EMULATION_TRICK
|
|
#endif // disable_fadd_trick
|
|
|
|
#include <cuda.h>
|
|
#include <stdexcept>
|
|
|
|
#if CUDA_VERSION >= {min_cuda_version}
|
|
|
|
|
|
#if !{use_multi_cta} // !use_multi_cta
|
|
#include <fused_multihead_attention_kernel_{kernel_variant}.h>
|
|
#endif // !use_multi_cta
|
|
|
|
#if !{use_multi_cta} && {has_noloop} // !use_multi_cta && has_noloop
|
|
#include <fused_multihead_attention_kernel_1xN_noloop.h>
|
|
#endif // !use_multi_cta && has_noloop
|
|
|
|
#if {cross_mha} // cross_mha
|
|
#if {has_noloop} // has_noloop
|
|
#include <fused_multihead_cross_attention_kernel_1xN_noloop.h>
|
|
#endif // has_noloop
|
|
#include <fused_multihead_cross_attention_kernel_1xN.h>
|
|
#endif // cross_mha
|
|
|
|
#if {use_multi_cta} // use_multi_cta
|
|
#include <fused_multihead_attention_kernel_1xN_multi_cta.h>
|
|
#endif
|
|
|
|
using Attention_mask_type = fmha::Attention_mask_type;
|
|
using Launch_params = bert::Fused_multihead_attention_launch_params;
|
|
|
|
#if !{cross_mha} // !cross_mha
|
|
using Kernel_traits = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags}>;
|
|
|
|
using Kernel_traits_causal = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags},
|
|
/*causal mask*/ 3>;
|
|
#endif // Not cross attention
|
|
|
|
#if !{use_multi_cta} && !{cross_mha} // !use_multi_cta && !cross_mha
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}<Kernel_traits>(params);
|
|
}}
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}<Kernel_traits_causal>(params);
|
|
}}
|
|
|
|
void {launcher_name}(
|
|
const {params_type} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
|
|
constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
dim3 grid(params.h, params.b);
|
|
{causal_kernel_name}<<<grid, Kernel_traits_causal::THREADS, Kernel_traits_causal::BYTES_PER_SMEM, stream>>>(params);
|
|
}} else {{
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
dim3 grid(params.h, params.b);
|
|
{kernel_name}<<<grid, Kernel_traits::THREADS, Kernel_traits::BYTES_PER_SMEM, stream>>>(params);
|
|
}}
|
|
}}
|
|
|
|
#endif // !use_multi_cta && !cross_mha
|
|
|
|
#if !{use_multi_cta} && {has_noloop} && !{cross_mha} // !use_multi_cta && has_noloop && !cross_mha
|
|
|
|
using Kernel_traits_nl = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
1,
|
|
{warps_m} * {warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */ >;
|
|
|
|
static_assert(Kernel_traits_nl::CTAS_PER_HEAD == 1, "");
|
|
|
|
using Kernel_traits_nl_causal = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
1,
|
|
{warps_m} * {warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*causal mask*/ 3>;
|
|
|
|
static_assert(Kernel_traits_nl_causal::CTAS_PER_HEAD == 1, "");
|
|
static_assert(Kernel_traits_nl_causal::MASK_VERSION == 3, "");
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_1xN_nl<Kernel_traits_nl>(params);
|
|
}}
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_1xN_nl<Kernel_traits_nl_causal>(params);
|
|
}}
|
|
|
|
void {launcher_name}_nl(
|
|
const {params_type} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
|
|
|
|
constexpr int loop_iters = ({seq_len} + {noloop_step}-1) / {noloop_step};
|
|
static_assert(loop_iters * {noloop_step} >= {seq_len}, "");
|
|
dim3 grid(params.h, params.b, loop_iters);
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
constexpr int smem_size = Kernel_traits_nl_causal::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{causal_kernel_name}_nl<<<grid, Kernel_traits_nl_causal::THREADS, Kernel_traits_nl_causal::BYTES_PER_SMEM, stream>>>(params);
|
|
}} else {{
|
|
constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>(params);
|
|
}}
|
|
}}
|
|
|
|
#endif // !use_multi_cta && has_noloop && !cross_mha
|
|
|
|
#if {cross_mha} // cross_mha
|
|
#if !{use_multi_cta} && {has_noloop} // !use_multi_cta && has_noloop
|
|
|
|
using Kernel_traits_nl = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
1,
|
|
{warps_m} * {warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags}>;
|
|
|
|
static_assert(Kernel_traits_nl::CTAS_PER_HEAD == 1, "");
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_mhca_1xN_nl<Kernel_traits_nl>(params);
|
|
}}
|
|
|
|
void {launcher_name}_nl(
|
|
const {params_type} ¶ms,
|
|
// const Launch_params &launch_params, // TODO
|
|
cudaStream_t stream){{
|
|
|
|
constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
const int loop_iters = (params.s_q + {noloop_step}-1) / {noloop_step};
|
|
// if (loop_iters * {noloop_step} != params.s_q) {{
|
|
// throw std::runtime_error("Incorrect seq len -- loop_iters * noloop_step != params.s_q");
|
|
// }}
|
|
assert(loop_iters * {noloop_step} >= params.s_q);
|
|
dim3 grid(params.h, params.b, loop_iters);
|
|
{kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>(params);
|
|
}}
|
|
|
|
#endif // !use_multi_cta && has_noloop
|
|
|
|
#if !{use_multi_cta} // !use_multi_cta
|
|
|
|
using Kernel_traits = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags}>;
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}({params_type} params){{
|
|
fused_multihead_attention::device_mhca_1xN<Kernel_traits>(params);
|
|
}}
|
|
|
|
void {launcher_name}(
|
|
const {params_type} ¶ms,
|
|
// const Launch_params &launch_params, // TODO
|
|
cudaStream_t stream){{
|
|
|
|
constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
dim3 grid(params.h, params.b);
|
|
{kernel_name}<<<grid, Kernel_traits::THREADS, Kernel_traits::BYTES_PER_SMEM, stream>>>(params);
|
|
}}
|
|
|
|
#endif // !use_multi_cta
|
|
|
|
#endif // cross_mha
|
|
|
|
#if {use_multi_cta} // use_multi_cta
|
|
|
|
// If that assert gets triggered - increase the value of MAX_STGS_PER_LOOP in "setup.py".
|
|
static_assert(Kernel_traits::Gmem_tile_o::STGS_PER_LOOP <= {MAX_STGS_PER_LOOP}, "");
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_multi_cta<Kernel_traits>(params);
|
|
}}
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_multi_cta<Kernel_traits_causal>(params);
|
|
}}
|
|
|
|
void {launcher_name}(
|
|
const {params_type} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
|
|
assert(params.heads_per_wave != 0 && \"Heads per wave is not set, but multi cta is requested\");
|
|
|
|
// Clear the barriers and locks.
|
|
cudaMemsetAsync(params.counters, 0, 3*params.heads_per_wave*sizeof(int), stream);
|
|
|
|
// We may use more than 48kB of shared memory.
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
constexpr int smem_size = Kernel_traits_causal::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
// Launch one wave.
|
|
dim3 grid(Kernel_traits_causal::CTAS_PER_HEAD, params.heads_per_wave), block(Kernel_traits_causal::THREADS);
|
|
void *params_ = (void*) ¶ms;
|
|
FMHA_CHECK_CUDA(cudaLaunchCooperativeKernel((void*) &{causal_kernel_name}, grid, block, (void**) ¶ms_, smem_size, stream));
|
|
}} else {{
|
|
constexpr size_t smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
// Launch one wave.
|
|
dim3 grid(Kernel_traits::CTAS_PER_HEAD, params.heads_per_wave), block(Kernel_traits::THREADS);
|
|
void *params_ = (void*) ¶ms;
|
|
FMHA_CHECK_CUDA(cudaLaunchCooperativeKernel((void*) &{kernel_name}, grid, block, (void**) ¶ms_, smem_size, stream));
|
|
}}
|
|
}}
|
|
|
|
#endif // use_multi_cta
|
|
|
|
void {launcher_name}_get_max_heads_per_wave(int *heads_per_wave) {{
|
|
#if {use_multi_cta} // use_multi_cta
|
|
// Determine the number of SMs and CTAs.
|
|
int dev;
|
|
cudaGetDevice(&dev);
|
|
cudaDeviceProp props;
|
|
FMHA_CHECK_CUDA(cudaGetDeviceProperties(&props, dev));
|
|
|
|
// The number of CTAs per SM.
|
|
constexpr size_t smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
int ctas_per_sm;
|
|
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm,
|
|
&{kernel_name},
|
|
Kernel_traits::THREADS,
|
|
smem_size));
|
|
|
|
// The number of heads per wave.
|
|
*heads_per_wave = props.multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_HEAD;
|
|
#else // use_multi_cta
|
|
*heads_per_wave = 0;
|
|
#endif // use_multi_cta
|
|
}}
|
|
|
|
#else // CUDA_VERSION >= {min_cuda_version}
|
|
|
|
void {launcher_name}(
|
|
const {params_type} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
#if {has_noloop} // has_noloop
|
|
|
|
void {launcher_name}_nl(
|
|
const {params_type} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
#endif // has_noloop
|
|
|
|
#endif // CUDA_VERSION >= {min_cuda_version}
|
|
'''
|
|
|
|
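# Template for the flash-attention kernels: emits the no-loop and granular-tiled
# variants, each with separate entry points per mask type
# (padding / causal / sliding-or-chunked-causal / custom mask).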
flash_attention_kernel_template = '''\
|
|
{copyright}
|
|
|
|
//We can disable the FADD trick for archs with F2IP
|
|
#if {disable_fadd_trick} // disable_fadd_trick
|
|
#ifdef USE_I2F_EMULATION_TRICK
|
|
#undef USE_I2F_EMULATION_TRICK
|
|
#endif // USE_I2F_EMULATION_TRICK
|
|
|
|
#ifdef USE_F2I_EMULATION_TRICK
|
|
#undef USE_F2I_EMULATION_TRICK
|
|
#endif // USE_F2I_EMULATION_TRICK
|
|
#endif // disable_fadd_trick
|
|
|
|
#include <cuda.h>
|
|
|
|
#if CUDA_VERSION >= {min_cuda_version}
|
|
|
|
#include <fused_multihead_flash_attention_kernel_noloop.h>
|
|
#include <fused_multihead_flash_attention_kernel_noloop_tiled.h>
|
|
#include <fused_multihead_flash_attention_kernel.h>
|
|
|
|
{include_str}
|
|
|
|
{local_ns_open}
|
|
{bert_launch_params}
|
|
{attn_mask_type_str}
|
|
|
|
#if 0 // has_noloop (unconditionally disabled since not maintained & not actively used)
|
|
using Kernel_traits = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags}>;
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}<Kernel_traits>(params);
|
|
}}
|
|
|
|
void {launcher_name}(
|
|
const {params_type} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
|
|
constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
dim3 grid(params.h, params.b);
|
|
{kernel_name}<<<grid, Kernel_traits::THREADS, Kernel_traits::BYTES_PER_SMEM, stream>>>(params);
|
|
}}
|
|
|
|
#endif // has_noloop
|
|
|
|
#if {has_noloop} && !{tiled} // has_noloop && !tiled
|
|
using Kernel_traits_nl = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*dense mask*/ 2,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_},
|
|
{sage_block_size_q},
|
|
{sage_block_size_k},
|
|
{sage_block_size_v}>;
|
|
|
|
using Kernel_traits_nl_causal = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*causal mask*/ 3,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_}>;
|
|
|
|
using Kernel_traits_nl_sliding_or_chunked_causal = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*sliding window causal mask*/ 4,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_}>;
|
|
|
|
using Kernel_traits_nl_custom_mask = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*custom mask*/ 5,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_}>;
|
|
|
|
#if {padding_mask} // padding_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_nl>(params);
|
|
}}
|
|
|
|
#endif // padding mask
|
|
|
|
#if {causal_mask} // causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_nl_causal>(params);
|
|
}}
|
|
|
|
#endif // causal mask
|
|
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {sliding_or_chunked_causal_kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_nl_sliding_or_chunked_causal>(params);
|
|
}}
|
|
|
|
#endif // sliding_or_chunked_causal_mask
|
|
|
|
#if {custom_mask} // custom_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {custom_mask_kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_nl_custom_mask>(params);
|
|
}}
|
|
|
|
#endif // custom_mask
|
|
|
|
void {launcher_name}_nl(
|
|
{const_fused_multihead_attention_params_v2_str} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
|
|
// runtime q_loop_iters
|
|
int loop_iters = ( params.s + {noloop_step} - 1 ) / {noloop_step};
|
|
// dim3 grid(params.h, params.b, loop_iters);
|
|
dim3 grid(loop_iters, params.h, params.b); // better locality
|
|
constexpr int smem_size = Kernel_traits_nl::BYTES_PER_SMEM;
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
#if {causal_mask} // causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{causal_kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // causal mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) {{
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({sliding_or_chunked_causal_kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{sliding_or_chunked_causal_kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // sliding_or_chunked_causal_mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::PADDING ) {{
|
|
#if {padding_mask} // padding_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // padding_mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::CUSTOM_MASK ) {{
|
|
#if {custom_mask} // custom_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({custom_mask_kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{custom_mask_kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // custom mask
|
|
}}
|
|
}}
|
|
|
|
#endif // has_noloop && !tiled
|
|
|
|
#if {tiled} // tiled
|
|
|
|
using Kernel_traits_nl_tiled = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*dense mask*/ 2,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_},
|
|
{sage_block_size_q},
|
|
{sage_block_size_k},
|
|
{sage_block_size_v}>;
|
|
|
|
using Kernel_traits_nl_tiled_causal = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*causal mask*/ 3,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_}>;
|
|
|
|
using Kernel_traits_nl_tiled_sliding_or_chunked_causal = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*sliding window causal mask*/ 4,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_}>;
|
|
|
|
using Kernel_traits_nl_tiled_custom_mask = fmha::{kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */,
|
|
/*custom mask*/ 5,
|
|
/*bmm2_fp16_epilogue*/ true,
|
|
{output_dtype_}>;
|
|
|
|
#if {padding_mask} // padding_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}_nl_tiled({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl_tiled<Kernel_traits_nl_tiled>(params);
|
|
}}
|
|
|
|
#endif // padding_mask
|
|
|
|
#if {causal_mask} // causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}_nl_tiled({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl_tiled<Kernel_traits_nl_tiled_causal>(params);
|
|
}}
|
|
|
|
#endif // causal mask
|
|
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {sliding_or_chunked_causal_kernel_name}_nl_tiled({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl_tiled<Kernel_traits_nl_tiled_sliding_or_chunked_causal>(params);
|
|
}}
|
|
|
|
#endif // sliding_or_chunked_causal_mask
|
|
|
|
#if {custom_mask} // custom_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {custom_mask_kernel_name}_nl_tiled({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl_tiled<Kernel_traits_nl_tiled_custom_mask>(params);
|
|
}}
|
|
|
|
#endif // custom mask
|
|
|
|
// Granular tiling
|
|
void {launcher_name}_nl_tiled(
|
|
{const_fused_multihead_attention_params_v2_str} ¶ms,
|
|
const Launch_params &launch_params,
|
|
cudaStream_t stream){{
|
|
// runtime q_loop_iters
|
|
using Cta_tile_o = typename Kernel_traits_nl_tiled::Cta_tile_o;
|
|
int ctas_per_o_row = (params.d + Cta_tile_o::N - 1) / Cta_tile_o::N;
|
|
int loop_iters = ( params.s + {noloop_step} - 1 ) / {noloop_step};
|
|
dim3 grid(loop_iters * ctas_per_o_row, params.h, params.b);
|
|
constexpr int smem_size = Kernel_traits_nl_tiled::BYTES_PER_SMEM;
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
#if {causal_mask} // causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name}_nl_tiled,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{causal_kernel_name}_nl_tiled<<<grid, Kernel_traits_nl_tiled::THREADS, Kernel_traits_nl_tiled::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // causal mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) {{
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({sliding_or_chunked_causal_kernel_name}_nl_tiled,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{sliding_or_chunked_causal_kernel_name}_nl_tiled<<<grid, Kernel_traits_nl_tiled::THREADS, Kernel_traits_nl_tiled::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // sliding_or_chunked_causal_mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::PADDING ) {{
|
|
#if {padding_mask} // padding_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name}_nl_tiled,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{kernel_name}_nl_tiled<<<grid, Kernel_traits_nl_tiled::THREADS, Kernel_traits_nl_tiled::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // padding_mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::CUSTOM_MASK ) {{
|
|
#if {custom_mask} // custom_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({custom_mask_kernel_name}_nl_tiled,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{custom_mask_kernel_name}_nl_tiled<<<grid, Kernel_traits_nl_tiled::THREADS, Kernel_traits_nl_tiled::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // custom mask
|
|
}}
|
|
}}
|
|
|
|
#endif // tiled
|
|
|
|
#else // CUDA_VERSION >= {min_cuda_version}
|
|
|
|
void {launcher_name}(const {params_type} ¶ms, cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
void {launcher_name}_nl(const {params_type} ¶ms, cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
void {launcher_name}_nl_tiled(const {params_type} ¶ms, cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
#endif // CUDA_VERSION >= {min_cuda_version}
|
|
{local_ns_close}
|
|
'''
|
|
|
|
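# Template for the Hopper (GMMA) non-warp-specialized kernels: loop and no-loop
# variants per mask type, with optional TMA descriptor setup in the launcher.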
kernel_hopper_template = '''\
|
|
{copyright}
|
|
|
|
//We can disable the FADD trick for archs with F2IP
|
|
#if {disable_fadd_trick}
|
|
#ifdef USE_I2F_EMULATION_TRICK
|
|
#undef USE_I2F_EMULATION_TRICK
|
|
#endif
|
|
|
|
#ifdef USE_F2I_EMULATION_TRICK
|
|
#undef USE_F2I_EMULATION_TRICK
|
|
#endif
|
|
#endif
|
|
|
|
#include <cuda.h>
|
|
|
|
#if CUDA_VERSION >= {min_cuda_version}
|
|
|
|
#include <fused_multihead_attention_kernel_{kernel_variant}.h>
|
|
#if {has_noloop}
|
|
#include <fused_multihead_attention_kernel_{kernel_variant}_noloop.h>
|
|
#endif
|
|
|
|
#if {use_tma}
|
|
// only included if tma is used.
|
|
#include <fmha/hopper/tma_descriptor.h>
|
|
#endif //use_tma
|
|
|
|
{include_str}
|
|
{local_ns_open}
|
|
{bert_launch_params}
|
|
{attn_mask_type_str}
|
|
|
|
using Traits_p = fmha::{instruction_traits_p};
|
|
using Traits_o = fmha::{instruction_traits_o};
|
|
|
|
using Kernel_traits = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
2,
|
|
{kernel_flags}>;
|
|
|
|
using Kernel_traits_causal = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
3,
|
|
{kernel_flags}>;
|
|
|
|
using Kernel_traits_sliding_or_chunked_causal = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
4,
|
|
{kernel_flags}>;
|
|
|
|
#if {use_tma} // use_tma
|
|
|
|
#if {padding_mask} // padding_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}(const __grid_constant__ {params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_tma<Kernel_traits>(params);
|
|
}}
|
|
|
|
#endif // padding_mask
|
|
|
|
#if {causal_mask} // causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}(const __grid_constant__ {params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_tma<Kernel_traits_causal>(params);
|
|
}}
|
|
|
|
#endif // causal mask
|
|
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {sliding_or_chunked_causal_kernel_name}(const __grid_constant__ {params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_tma<Kernel_traits_sliding_or_chunked_causal>(params);
|
|
}}
|
|
|
|
#endif // sliding_or_chunked_causal_mask
|
|
|
|
#else
|
|
|
|
#if {padding_mask}
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}(const __grid_constant__ {params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}<Kernel_traits>(params);
|
|
}}
|
|
|
|
#endif // padding_mask
|
|
|
|
#if {causal_mask} // causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}(const __grid_constant__ {params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}<Kernel_traits_causal>(params);
|
|
}}
|
|
|
|
#endif // causal mask
|
|
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {sliding_or_chunked_causal_kernel_name}(const __grid_constant__ {params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}<Kernel_traits_sliding_or_chunked_causal>(params);
|
|
}}
|
|
#endif
|
|
|
|
#endif // sliding_or_chunked_causal_mask
|
|
|
|
void {launcher_name}({fused_multihead_attention_params_v2_str} ¶ms,
|
|
const Launch_params &launch_params, cudaStream_t stream){{
|
|
// setting TMA descriptors if needed.
|
|
// use_tma = {use_tma}
|
|
#if {use_tma}
|
|
// declare TMA desc for Q, K, V
|
|
typename fmha::Multiple_tma_descriptor<3> tma_desc_QKV;
|
|
|
|
// GMEM pointers, the offset between each batch is d*3*h*seqlen
|
|
// qkv pointer
|
|
char *qkv_ptr = reinterpret_cast<char*>(params.qkv_ptr);
|
|
|
|
// tensor size
|
|
uint32_t tensor_size_qkv[3];
|
|
tensor_size_qkv[2] = 1;
|
|
tensor_size_qkv[1] = params.is_s_padded ? params.s * params.b : launch_params.seqlens[params.b];
|
|
tensor_size_qkv[0] = (params.h + 2 * params.h_kv) * params.d;
|
|
|
|
// box size for Q
|
|
uint32_t box_size_q[3];
|
|
box_size_q[2] = 1;
|
|
box_size_q[1] = {loop_step}; // STEP size
|
|
box_size_q[0] = {head_size}; // head_size
|
|
|
|
// box size for k and v
|
|
uint32_t box_size_kv[3];
|
|
box_size_kv[2] = 1;
|
|
box_size_kv[1] = params.s; // S, should not be actual_s, OOB will be filled with zeros.
|
|
box_size_kv[0] = {head_size}; // head_size
|
|
|
|
// stride size
|
|
uint64_t tensor_stride_qkv[2];
|
|
tensor_stride_qkv[0] = tensor_size_qkv[0] * Traits_p::BITS_PER_ELEMENT_A / 8;
|
|
tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0];
|
|
|
|
// traversal stride
|
|
uint32_t traversal_stride_qkv[3] = {{1, 1, 1}};
|
|
|
|
// OOB fill zeros
|
|
uint32_t oob_fill = 0;
|
|
|
|
// FP32 to TF32 conversion disabled
|
|
uint32_t fp32_to_tf32 = 0;
|
|
|
|
//setup the descriptors
|
|
|
|
//setup the descriptor for Q
|
|
tma_desc_QKV.set_tma_desctriptor(reinterpret_cast<void*>(qkv_ptr),
|
|
fmha::cudaTmaDescFormat::F16_RN, // tma format (data type). For now hardcode to fp16
|
|
fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED,
|
|
fmha::cudaTmaDescSwizzle::SWIZZLE_128B,
|
|
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED,
|
|
tensor_size_qkv,
|
|
tensor_stride_qkv,
|
|
traversal_stride_qkv,
|
|
box_size_q,
|
|
oob_fill,
|
|
fp32_to_tf32,
|
|
¶ms.tma_desc_q);
|
|
|
|
// setup the descriptor for K
|
|
tma_desc_QKV.set_tma_desctriptor(reinterpret_cast<void*>(qkv_ptr),
|
|
fmha::cudaTmaDescFormat::F16_RN, // tma format (data type). For now hardcode to fp16
|
|
fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED,
|
|
fmha::cudaTmaDescSwizzle::SWIZZLE_128B,
|
|
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED,
|
|
tensor_size_qkv,
|
|
tensor_stride_qkv,
|
|
traversal_stride_qkv,
|
|
box_size_kv,
|
|
oob_fill,
|
|
fp32_to_tf32,
|
|
¶ms.tma_desc_k);
|
|
|
|
// setup the descriptor for V
|
|
tma_desc_QKV.set_tma_desctriptor(reinterpret_cast<void*>(qkv_ptr),
|
|
fmha::cudaTmaDescFormat::F16_RN, // tma format (data type). For now hardcode to fp16
|
|
fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED,
|
|
fmha::cudaTmaDescSwizzle::SWIZZLE_128B,
|
|
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED,
|
|
tensor_size_qkv,
|
|
tensor_stride_qkv,
|
|
traversal_stride_qkv,
|
|
box_size_kv,
|
|
oob_fill,
|
|
fp32_to_tf32,
|
|
¶ms.tma_desc_v);
|
|
|
|
|
|
#endif // use_tma
|
|
dim3 grid(params.h, params.b);
|
|
// Use the same smem_size for all traits.
|
|
constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
#if {causal_mask} // causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{causal_kernel_name}<<<grid, Kernel_traits::THREADS, Kernel_traits::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // causal mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) {{
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({sliding_or_chunked_causal_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{sliding_or_chunked_causal_kernel_name}<<<grid, Kernel_traits::THREADS, Kernel_traits::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // sliding_or_chunked_causal_mask
|
|
}} else {{
|
|
#if {padding_mask} // padding_mask
|
|
constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{kernel_name}<<<grid, Kernel_traits::THREADS, Kernel_traits::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // padding_mask
|
|
}}
|
|
}}
|
|
|
|
#if {has_noloop}
|
|
|
|
|
|
using Kernel_traits_nl = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
2,
|
|
{kernel_flags}>;
|
|
|
|
using Kernel_traits_causal_nl = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
3,
|
|
{kernel_flags}>;
|
|
|
|
using Kernel_traits_sliding_or_chunked_causal_nl = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
4,
|
|
{kernel_flags}>;
|
|
|
|
#if {padding_mask} // padding_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_nl>(params);
|
|
}}
|
|
|
|
#endif // padding_mask
|
|
|
|
#if {causal_mask} // causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {causal_kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_causal_nl>(params);
|
|
}}
|
|
|
|
#endif // causal mask
|
|
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
|
|
extern "C"
|
|
__global__
|
|
void {sliding_or_chunked_causal_kernel_name}_nl({params_type} params){{
|
|
fused_multihead_attention::device_{kernel_variant}_nl<Kernel_traits_sliding_or_chunked_causal_nl>(params);
|
|
}}
|
|
|
|
#endif // sliding_or_chunked_causal_mask
|
|
|
|
void {launcher_name}_nl({fused_multihead_attention_params_v2_str} ¶ms,
|
|
const Launch_params& launch_params, cudaStream_t stream){{
|
|
constexpr int loop_iters = {seq_len} / {noloop_step};
|
|
static_assert(loop_iters * {noloop_step} == {seq_len}, "");
|
|
dim3 grid(params.h, params.b, loop_iters);
|
|
|
|
// Use the same smem_size for all traits.
|
|
constexpr int smem_size = Kernel_traits::BYTES_PER_SMEM;
|
|
if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
#if {causal_mask} // causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{causal_kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // causal mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) {{
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({sliding_or_chunked_causal_kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{sliding_or_chunked_causal_kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // sliding_or_chunked_causal_mask
|
|
}} else {{
|
|
#if {padding_mask} // padding_mask
|
|
if( smem_size >= 48*1024 ) {{
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name}_nl,
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
smem_size));
|
|
}}
|
|
{kernel_name}_nl<<<grid, Kernel_traits_nl::THREADS, Kernel_traits_nl::BYTES_PER_SMEM, stream>>>({params_str});
|
|
#endif // padding_mask
|
|
}}
|
|
}}
|
|
|
|
#endif
|
|
|
|
#else
|
|
|
|
void {launcher_name}(const {params_type} ¶ms, cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
#if {has_noloop}
|
|
|
|
void {launcher_name}_nl(const {params_type} ¶ms, cudaStream_t stream){{
|
|
assert(false && "Unsupported CUDA version");
|
|
}}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
{local_ns_close}
|
|
'''
|
|
|
|
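# Template for the Hopper warp-specialized kernels: one DMA/scheduler warp group feeds
# NUM_COMPUTE_GROUPS compute warp groups; the launcher sizes the grid according to the
# selected scheduling mode.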
kernel_hopper_warp_specialization_template = '''\
|
|
{copyright}
|
|
|
|
#include <fused_multihead_attention_utils.h>
|
|
#include <fmha/hopper/gmma_descriptor.h>
|
|
#include <fmha/hopper/smem_tile.h>
|
|
#include <fmha/utils.h>
|
|
#include <fmha/hopper/compute_tile.h>
|
|
|
|
#include <fmha/warpspec/kernel_traits.h>
|
|
#include <fmha/warpspec/dma.h>
|
|
#include <fmha/warpspec/compute.h>
|
|
|
|
{include_str}
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
{local_ns_open}
|
|
#if CUDA_VERSION >= {min_cuda_version}
|
|
|
|
static constexpr int DMA2COMPUTE_DEPTH = 1;
|
|
{num_compute_groups_str}
|
|
static constexpr bool USE_TMA_STORE = {use_tma_store_flag};
|
|
|
|
{bert_launch_params}
|
|
{attn_mask_type_str}
|
|
|
|
using Ktraits = {kernel_traits_header}
|
|
{loop_step},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{q_tile_buffers},
|
|
{kv_tile_buffers},
|
|
NUM_COMPUTE_GROUPS,
|
|
DMA2COMPUTE_DEPTH,
|
|
0,
|
|
{heads_interleaved_flag},
|
|
false,
|
|
{enable_mutex_flag},
|
|
{scheduling_mode},
|
|
{input_layout_flag},
|
|
USE_TMA_STORE,
|
|
{enable_attn_logit_softcapping_flag},
|
|
{return_softmax_stats_flag},
|
|
{enable_skip_softmax_flag},
|
|
{output_dtype_},
|
|
{sage_block_size_q},
|
|
{sage_block_size_k},
|
|
{sage_block_size_v}>;
|
|
|
|
using Ktraits_causal = {kernel_traits_header}
|
|
{loop_step},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{q_tile_buffers},
|
|
{kv_tile_buffers},
|
|
NUM_COMPUTE_GROUPS,
|
|
DMA2COMPUTE_DEPTH,
|
|
1,
|
|
{heads_interleaved_flag},
|
|
{has_alibi},
|
|
{enable_mutex_flag},
|
|
{scheduling_mode},
|
|
{input_layout_flag},
|
|
USE_TMA_STORE,
|
|
{enable_attn_logit_softcapping_flag},
|
|
{return_softmax_stats_flag},
|
|
{enable_skip_softmax_flag},
|
|
{output_dtype_}>;
|
|
|
|
using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
|
|
{loop_step},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{q_tile_buffers},
|
|
{kv_tile_buffers},
|
|
NUM_COMPUTE_GROUPS,
|
|
DMA2COMPUTE_DEPTH,
|
|
2,
|
|
{heads_interleaved_flag},
|
|
{has_alibi},
|
|
{enable_mutex_flag},
|
|
{scheduling_mode},
|
|
{input_layout_flag},
|
|
USE_TMA_STORE && false,
|
|
{enable_attn_logit_softcapping_flag},
|
|
{return_softmax_stats_flag},
|
|
{enable_skip_softmax_flag},
|
|
{output_dtype_}>;
|
|
|
|
using Ktraits_custom_mask = {kernel_traits_header}
|
|
{loop_step},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{q_tile_buffers},
|
|
{kv_tile_buffers},
|
|
NUM_COMPUTE_GROUPS,
|
|
DMA2COMPUTE_DEPTH,
|
|
3,
|
|
{heads_interleaved_flag},
|
|
{has_alibi},
|
|
{enable_mutex_flag},
|
|
{scheduling_mode},
|
|
{input_layout_flag},
|
|
USE_TMA_STORE && false,
|
|
{enable_attn_logit_softcapping_flag},
|
|
{return_softmax_stats_flag},
|
|
{enable_skip_softmax_flag},
|
|
{output_dtype_}>;
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#if {padding_mask} // padding_mask
|
|
|
|
using Shared = typename Ktraits::Shared;
|
|
|
|
extern "C"
|
|
__global__ __launch_bounds__(Ktraits::THREADS, 1)
|
|
void {kernel_name}(
|
|
const __grid_constant__ {params_type} params){{
|
|
|
|
extern __shared__ char smem_[];
|
|
char *smem_aligned = fmha::align_1024(smem_);
|
|
|
|
Shared *shared = reinterpret_cast<Shared *>(&smem_aligned[0]);
|
|
shared->init(threadIdx.x == 0);
|
|
__syncthreads();
|
|
|
|
// special trick to avoid warp_sync (leads to an illegal instruction)
|
|
int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
|
int tidx = threadIdx.x % 128;
|
|
|
|
if( warp_group == NUM_COMPUTE_GROUPS ) {{ // dma + sched
|
|
|
|
{setmaxnreg_dma_str}
|
|
uint32_t elect_one = tidx == 0;
|
|
|
|
// Need all threads involved when the dma group needs to transpose the v tile explicitly.
|
|
if constexpr ( Ktraits::DMA_GROUP_TRANSPOSE_V ) {{
|
|
fmha::ws::DMA<Ktraits>::Device dma_device(elect_one);
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}} else {{
|
|
fmha::ws::DMA<Ktraits>::Device dma_device(elect_one);
|
|
if( tidx < 32 ) {{
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}}
|
|
}}
|
|
|
|
}} else {{ // math
|
|
|
|
{setmaxnreg_compute_str}
|
|
|
|
fmha::ws::Compute<fmha::{instruction_traits}, Ktraits> compute;
|
|
compute.run(warp_group, tidx, shared, params);
|
|
}}
|
|
}}
|
|
|
|
#endif // padding mask
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#if {causal_mask} // causal_mask
|
|
|
|
using Shared_causal = typename Ktraits_causal::Shared;
|
|
|
|
extern "C"
|
|
__global__ __launch_bounds__(Ktraits_causal::THREADS, 1)
|
|
void {causal_kernel_name}(
|
|
const __grid_constant__ {params_type} params){{
|
|
|
|
extern __shared__ char smem_[];
|
|
char *smem_aligned = fmha::align_1024(smem_);
|
|
|
|
Shared_causal *shared = reinterpret_cast<Shared_causal *>(&smem_aligned[0]);
|
|
shared->init(threadIdx.x == 0);
|
|
__syncthreads();
|
|
|
|
// special trick to avoid warp_sync (leads to an illegal instruction)
|
|
int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
|
int tidx = threadIdx.x % 128;
|
|
|
|
if( warp_group == NUM_COMPUTE_GROUPS ) {{ // dma + sched
|
|
|
|
{setmaxnreg_dma_str}
|
|
uint32_t elect_one = tidx == 0;
|
|
|
|
// Need all threads involved when the dma group needs to transpose the v tile explicitly.
|
|
if constexpr ( Ktraits_causal::DMA_GROUP_TRANSPOSE_V ) {{
|
|
fmha::ws::DMA<Ktraits_causal>::Device dma_device(elect_one);
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}} else {{
|
|
fmha::ws::DMA<Ktraits_causal>::Device dma_device(elect_one);
|
|
if( tidx < 32 ) {{
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}}
|
|
}}
|
|
|
|
}} else {{ // math
|
|
|
|
{setmaxnreg_compute_str}
|
|
|
|
fmha::ws::Compute<fmha::{instruction_traits}, Ktraits_causal> compute;
|
|
compute.run(warp_group, tidx, shared, params);
|
|
}}
|
|
}}
|
|
|
|
#endif // causal mask
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
|
|
using Shared_sliding_or_chunked_causal = typename Ktraits_sliding_or_chunked_causal::Shared;
|
|
|
|
extern "C"
|
|
__global__ __launch_bounds__(Ktraits_sliding_or_chunked_causal::THREADS, 1)
|
|
void {sliding_or_chunked_causal_kernel_name}(
|
|
const __grid_constant__ {params_type} params){{
|
|
|
|
extern __shared__ char smem_[];
|
|
char *smem_aligned = fmha::align_1024(smem_);
|
|
|
|
Shared_sliding_or_chunked_causal *shared =
|
|
reinterpret_cast<Shared_sliding_or_chunked_causal *>(&smem_aligned[0]);
|
|
shared->init(threadIdx.x == 0);
|
|
__syncthreads();
|
|
|
|
// special trick to avoid warp_sync (leads to an illegal instruction)
|
|
int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
|
int tidx = threadIdx.x % 128;
|
|
|
|
if( warp_group == NUM_COMPUTE_GROUPS ) {{ // dma + sched
|
|
|
|
{setmaxnreg_dma_str}
|
|
uint32_t elect_one = tidx == 0;
|
|
|
|
// Need all threads involved when the dma group needs to transpose the v tile explicitly.
|
|
if constexpr ( Ktraits_sliding_or_chunked_causal::DMA_GROUP_TRANSPOSE_V ) {{
|
|
fmha::ws::DMA<Ktraits_sliding_or_chunked_causal>::Device dma_device(elect_one);
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}} else {{
|
|
fmha::ws::DMA<Ktraits_sliding_or_chunked_causal>::Device dma_device(elect_one);
|
|
if( tidx < 32 ) {{
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}}
|
|
}}
|
|
|
|
}} else {{ // math
|
|
|
|
{setmaxnreg_compute_str}
|
|
|
|
fmha::ws::Compute<fmha::{instruction_traits}, Ktraits_sliding_or_chunked_causal> compute;
|
|
compute.run(warp_group, tidx, shared, params);
|
|
}}
|
|
}}
|
|
|
|
#endif // sliding_or_chunked_causal_mask
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#if {custom_mask} // custom_mask
|
|
|
|
using Shared_custom_mask = typename Ktraits_custom_mask::Shared;
|
|
|
|
extern "C"
|
|
__global__ __launch_bounds__(Ktraits_custom_mask::THREADS, 1)
|
|
void {custom_mask_kernel_name}(
|
|
const __grid_constant__ {params_type} params){{
|
|
|
|
extern __shared__ char smem_[];
|
|
char *smem_aligned = fmha::align_1024(smem_);
|
|
|
|
Shared_custom_mask *shared =
|
|
reinterpret_cast<Shared_custom_mask *>(&smem_aligned[0]);
|
|
shared->init(threadIdx.x == 0);
|
|
__syncthreads();
|
|
|
|
// special trick to avoid warp_sync (leads to an illegal instruction)
|
|
int warp_group = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
|
int tidx = threadIdx.x % 128;
|
|
|
|
if( warp_group == NUM_COMPUTE_GROUPS ) {{ // dma + sched
|
|
|
|
{setmaxnreg_dma_str}
|
|
uint32_t elect_one = tidx == 0;
|
|
|
|
// Need all threads involved when the dma group needs to transpose the v tile explicitly.
|
|
if constexpr ( Ktraits_custom_mask::DMA_GROUP_TRANSPOSE_V ) {{
|
|
fmha::ws::DMA<Ktraits_custom_mask>::Device dma_device(elect_one);
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}} else {{
|
|
fmha::ws::DMA<Ktraits_custom_mask>::Device dma_device(elect_one);
|
|
if( tidx < 32 ) {{
|
|
dma_device.{run_fct_name}(params, shared);
|
|
}}
|
|
}}
|
|
|
|
}} else {{ // math
|
|
|
|
{setmaxnreg_compute_str}
|
|
|
|
fmha::ws::Compute<fmha::{instruction_traits}, Ktraits_custom_mask> compute;
|
|
compute.run(warp_group, tidx, shared, params);
|
|
}}
|
|
}}
|
|
|
|
#endif // custom_mask
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
void {launcher_name}(
|
|
{fused_multihead_attention_params_v2_str} ¶ms,
|
|
const Launch_params &launch_params, cudaStream_t stream){{
|
|
|
|
{TMA_config}
|
|
if( Ktraits::SCHEDULING_MODE > 0 ) {{
|
|
FMHA_CHECK_CUDA(cudaMemsetAsync(params.tile_id_counter_ptr, 0, sizeof(uint32_t), stream));
|
|
}}
|
|
|
|
dim3 block_size;
|
|
|
|
if( Ktraits::SCHEDULING_MODE == 0 ) {{
|
|
block_size.y = std::min(params.b * params.h, launch_params.multi_processor_count);
|
|
// distribute m steps to multiple blocks (fully utilize SMs)
|
|
// block.x = blocks that handle single head, block.y = blocks that handle different heads
|
|
size_t sms_per_head = (launch_params.multi_processor_count) / block_size.y;
|
|
// Take multiple compute groups into consideration.
|
|
size_t m_steps = size_t((params.s + {loop_step} * NUM_COMPUTE_GROUPS - 1) / ({loop_step} * NUM_COMPUTE_GROUPS));
|
|
|
|
// 2 * {bytes_per_elt} stands for kv cache and {bytes_per_elt} bytes per element.
|
|
size_t size_in_bytes = block_size.y * params.s * params.d * 2 * {bytes_per_elt};
|
|
if( size_in_bytes <= launch_params.device_l2_cache_size ) {{
|
|
// strategy 1: limit to only 1 wave
|
|
block_size.x = std::min(m_steps, sms_per_head);
|
|
}} else {{
|
|
// strategy 2: fully unroll the q loops (contiguous blocks handle all q loops)
|
|
block_size.x = m_steps;
|
|
}}
|
|
params.num_tiles = params.b * params.h;
|
|
}} else if( Ktraits::SCHEDULING_MODE == 1 ) {{
|
|
// Get the max total M steps
|
|
// Take multiple compute groups into consideration.
|
|
size_t m_steps = size_t((params.s + {loop_step} * NUM_COMPUTE_GROUPS - 1) / ({loop_step} * NUM_COMPUTE_GROUPS));
|
|
params.num_tiles_per_head = static_cast<uint32_t>(m_steps);
|
|
params.num_tiles = static_cast<uint32_t>(m_steps * params.b * params.h);
|
|
if (launch_params.attention_mask_type == Attention_mask_type::CAUSAL) {{
|
|
// 2 * {bytes_per_elt} stands for kv cache and {bytes_per_elt} bytes per element.
|
|
size_t size_in_bytes = params.b * params.h * params.s * params.d * 2 * {bytes_per_elt};
|
|
params.use_balanced_scheduling = (size_in_bytes <= launch_params.device_l2_cache_size);
|
|
}}
|
|
|
|
block_size.x = 1;
|
|
block_size.y = std::min(static_cast<int>(params.num_tiles), launch_params.multi_processor_count);
|
|
}} else {{
|
|
assert(false && "Invalid SCHEDULING_MODE");
|
|
}}
|
|
|
|
// Reuse the same bytes_per_smem for launching kernels.
|
|
constexpr int SMEM_BYTES = Ktraits::BYTES_PER_SMEM;
|
|
if( launch_params.attention_mask_type == Attention_mask_type::PADDING ) {{
|
|
#if {padding_mask} // padding_mask
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
SMEM_BYTES));
|
|
|
|
{kernel_name}
|
|
<<<block_size, Ktraits::THREADS, SMEM_BYTES, stream>>>({params_str});
|
|
#endif // padding_mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::CAUSAL ) {{
|
|
#if {causal_mask} // causal_mask
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({causal_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
SMEM_BYTES));
|
|
|
|
{causal_kernel_name}
|
|
<<<block_size, Ktraits::THREADS, SMEM_BYTES, stream>>>({params_str});
|
|
#endif // causal mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL ) {{
|
|
#if {sliding_or_chunked_causal_mask} // sliding_or_chunked_causal_mask
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({sliding_or_chunked_causal_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
SMEM_BYTES));
|
|
|
|
{sliding_or_chunked_causal_kernel_name}
|
|
<<<block_size, Ktraits::THREADS, SMEM_BYTES, stream>>>({params_str});
|
|
#endif // sliding_or_chunked_causal_mask
|
|
}} else if( launch_params.attention_mask_type == Attention_mask_type::CUSTOM_MASK ) {{
|
|
#if {custom_mask} // custom_mask
|
|
FMHA_CHECK_CUDA(cudaFuncSetAttribute({custom_mask_kernel_name},
|
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
|
SMEM_BYTES));
|
|
|
|
{custom_mask_kernel_name}
|
|
<<<block_size, Ktraits::THREADS, SMEM_BYTES, stream>>>({params_str});
|
|
#endif // custom mask
|
|
}}
|
|
|
|
}}
|
|
|
|
#endif
|
|
{local_ns_close}
|
|
'''
|
|
|
|
|
|
def encode_name(kernel_spec):
|
|
effective_sm, sm_name = get_effective_sm_and_name(kernel_spec)
|
|
# Is it a kernel for the interleaved NC/32HW32 INT8 layout?
|
|
il_tag = '_il' if kernel_spec.interleaved else ''
|
|
# Is it using the quantization scaling factor as an approximation of the max in softmax?
|
|
scale_max_tag = '_scale_max' if kernel_spec.has_scale_max else ''
|
|
# Deal with multi-CTA kernels for which the sequence length is seq_len per CTA * # of CTAs.
|
|
seqlen = kernel_spec.seq_len * kernel_spec.ctas_per_head
|
|
# The qkv layout.
|
|
qkv_layout_tag = ''
|
|
if kernel_spec.input_layout == InputLayout.PACKED_QKV:
|
|
qkv_layout_tag = '_qkv'
|
|
elif kernel_spec.input_layout == InputLayout.Q_PAGED_KV:
|
|
qkv_layout_tag = '_q_paged_kv'
|
|
elif kernel_spec.input_layout == InputLayout.SEPARATE_Q_K_V:
|
|
qkv_layout_tag = '_q_k_v'
|
|
else:
|
|
qkv_layout_tag = '_q_kv'
|
|
# for SM90 kernels, let's also differentiate ldgsts and tma kernels
|
|
feature_tags = ''
|
|
if (effective_sm == 90):
|
|
        # TODO(Timmy): decide where to insert the tma/ldgsts tag in the kernel name before the MR.
|
|
if (kernel_spec.ldgsts_q == True):
|
|
tma_or_ldgsts = '_ldgsts'
|
|
else:
|
|
tma_or_ldgsts = '_tma'
|
|
if kernel_spec.warp_specialization:
|
|
warp_specialization_tag = '_ws'
|
|
            # hopper warp-specialized kernels have specialized optimizations for cases without alibi.
|
|
if kernel_spec.alibi:
|
|
feature_tags += '_alibi'
|
|
if kernel_spec.return_softmax_stats:
|
|
feature_tags += '_softmax'
|
|
else:
|
|
warp_specialization_tag = ''
|
|
else:
|
|
tma_or_ldgsts = ''
|
|
warp_specialization_tag = ''
|
|
|
|
if kernel_spec.enable_attn_logit_softcapping:
|
|
feature_tags += '_softcapping'
|
|
if kernel_spec.enable_skip_softmax:
|
|
feature_tags += '_skipSoftmax'
|
|
if kernel_spec.sage_block_sizes:
|
|
feature_tags += f"_sage_{'_'.join(map(str, kernel_spec.sage_block_sizes))}"
|
|
if kernel_spec.output_dtype:
|
|
feature_tags += f"_output_{kernel_spec.output_dtype}"
|
|
if kernel_spec.ctas_per_head > 1:
|
|
fmt = 'fmha_v{version}{il_tag}_{dtype}_' + str(
|
|
seqlen
|
|
) + '_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}'
|
|
elif kernel_spec.flash_attention:
|
|
fmt = 'fmha_v{version}{il_tag}_flash_attention_{dtype}_{loop_step}_{kv_loop_step}_S{qkv_layout_tag}_{head_size}{head_size_v_str}{attrib}{feature_tags}{scale_max_tag}{tma_or_ldgsts}{warp_specialization_tag}_sm{sm}'
|
|
elif kernel_spec.cross_mha:
|
|
fmt = 'fmha_mhca_{dtype}_{seq_len}_{head_size}{scale_max_tag}{tma_or_ldgsts}_sm{sm}'
|
|
else:
|
|
fmt = 'fmha_v{version}{il_tag}_{dtype}_{seq_len}_{head_size}{attrib}{scale_max_tag}{tma_or_ldgsts}_sm{sm}'
|
|
head_size_v_str = "" if kernel_spec.head_size_v == 0 else f"x{kernel_spec.head_size_v}"
|
|
# Assemble the name of the kernel.
|
|
name_base = fmt.format(**kernel_spec._asdict(),
|
|
head_size_v_str=head_size_v_str,
|
|
il_tag=il_tag,
|
|
qkv_layout_tag=qkv_layout_tag,
|
|
scale_max_tag=scale_max_tag,
|
|
tma_or_ldgsts=tma_or_ldgsts,
|
|
warp_specialization_tag=warp_specialization_tag,
|
|
feature_tags=feature_tags,
|
|
attrib='__placeholder__')
|
|
|
|
# Produce file, launch function and kernel names.
|
|
fname = name_base.replace('__placeholder__', '')
|
|
if seqlen >= 1024 and not kernel_spec.flash_attention:
|
|
fname += '.no_i2f_f2i'
|
|
fname += '.cu'
|
|
lname = ('run_' + name_base).replace('__placeholder__', '')
|
|
kname = name_base + '_kernel'
|
|
|
|
# remove causal
|
|
fname = fname.replace("causal_", "")
|
|
return fname, lname, kname
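# A rough illustration of the naming scheme produced by encode_name() above (the spec
# values below are hypothetical and chosen only to show the format): a warp-specialized
# Hopper flash-attention kernel with dtype 'fp16', loop_step 64, kv_loop_step 256,
# packed QKV input and head_size 128 would yield roughly
#   fname: 'fmha_v2_flash_attention_fp16_64_256_S_qkv_128_tma_ws_sm90.cu'
#   lname: 'run_fmha_v2_flash_attention_fp16_64_256_S_qkv_128_tma_ws_sm90'
#   kname: 'fmha_v2_flash_attention_fp16_64_256_S_qkv_128__placeholder___tma_ws_sm90_kernel'
# where '__placeholder__' is later replaced by the mask suffix (e.g. '_causal') or removed.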
|
|
|
|
|
|
def get_GMMA_shape(instruction_traits, m, n, k, warps_n):
|
|
gmma_k = hopper_traits2shape[instruction_traits][-1]
|
|
|
|
    # The GMMA shape is 64 x gmma_n x gmma_k; gmma_n should be as large as possible, but not larger than n.
|
|
    # gmma_n must also be at most 256.
|
|
gmma_m = 64
|
|
gmma_n = 0
|
|
# find the largest supported n
|
|
n_supported = [(i + 1) * 8 for i in range(32)][::-1]
|
|
n_target = n // warps_n
|
|
assert n_target * warps_n == n
|
|
assert n_supported[0] == 256 and n_supported[-1] == 8
|
|
for cand_n in n_supported:
|
|
if n_target % cand_n == 0:
|
|
gmma_n = cand_n
|
|
break
|
|
assert gmma_n > 0, "No supported GMMA_N found!"
|
|
|
|
return gmma_m, gmma_n, gmma_k
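# A minimal sketch of how get_GMMA_shape() picks GMMA_N (argument values are
# hypothetical; note that only the traits, n and warps_n influence the result):
#   >>> get_GMMA_shape('Hopper_hgmma_fp16_traits', 64, 256, 128, 1)
#   (64, 256, 16)
#   >>> get_GMMA_shape('Hopper_qgmma_e4m3_fp32_traits', 64, 64, 256, 1)
#   (64, 64, 32)
# i.e. GMMA_N is the largest multiple of 8 (capped at 256) that divides n / warps_n,
# and GMMA_K comes from the instruction traits.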
|
|
|
|
|
|
def enable_mutex(kspec):
|
|
fp32_accu_dtype = kspec.dtype in ['fp16_fp32', 'bf16']
|
|
enable_mutex = 'false' if (fp32_accu_dtype
|
|
or kspec.head_size <= 64) else 'true'
|
|
return enable_mutex
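# For example (hypothetical specs): a plain fp16 kernel with head_size 128 returns
# 'true', while any fp16_fp32/bf16 (fp32-accumulation) kernel or a kernel with
# head_size <= 64 returns 'false', i.e. the mutex is only enabled for the larger
# fp16-accumulation tiles.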
|
|
|
|
|
|
def enable_tma_store(kspec):
|
|
output_dtype = kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype
|
|
# TMA copies data in the 16B granularity.
|
|
return 'true' if (output_dtype in ['e4m3', 'e4m3_fp32']
|
|
and kspec.head_size % 16 == 0) else 'false'
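# For instance (hypothetical specs): an e4m3 output with head_size 128 writes 128 B rows,
# a multiple of the 16 B TMA granularity, so enable_tma_store() returns 'true'; any
# fp16/bf16 output returns 'false' here regardless of the head size.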
|
|
|
|
|
|
def get_reg_count(kspec):
|
|
# if kspec.paged_kv_input and kspec.dtype in ['fp16', 'fp16_fp32', 'bf16']:
|
|
# dma_reg_count = 72
|
|
# compute_reg_count = 216
|
|
if kspec.input_layout == InputLayout.Q_PAGED_KV:
|
|
dma_reg_count = 56
|
|
compute_reg_count = 224
|
|
else:
|
|
dma_reg_count = 40
|
|
compute_reg_count = 232
|
|
return dma_reg_count, compute_reg_count
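# Sanity note on the register budgets above (assuming one DMA warp group plus
# NUM_COMPUTE_GROUPS == 2 compute warp groups of 128 threads each): both splits sum to
# 40 + 2 * 232 = 56 + 2 * 224 = 504 registers per thread, so a 384-thread CTA stays
# within the 64K-register file of an SM (504 * 128 = 64512 <= 65536).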
|
|
|
|
|
|
def get_hopper_instruction_traits(instruction_traits, kernel_spec):
|
|
gmma_shape_p = get_GMMA_shape(instruction_traits, kernel_spec.loop_step,
|
|
kernel_spec.seq_len, kernel_spec.head_size,
|
|
kernel_spec.warps_n)
|
|
|
|
instruction_traits_p = f'{instruction_traits}<{", ".join([str(x) for x in gmma_shape_p])}, false, false>'
|
|
|
|
gmma_shape_o = get_GMMA_shape(instruction_traits, kernel_spec.loop_step,
|
|
kernel_spec.head_size, kernel_spec.seq_len, 1)
|
|
instruction_traits_o = f'{instruction_traits}<{", ".join([str(x) for x in gmma_shape_o])}, true, false>'
|
|
|
|
return instruction_traits_p, instruction_traits_o
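# Illustrative output (hypothetical fp16 spec with loop_step 64, seq_len 256,
# head_size 64, warps_n 1):
#   instruction_traits_p == 'Hopper_hgmma_fp16_traits<64, 256, 16, false, false>'
#   instruction_traits_o == 'Hopper_hgmma_fp16_traits<64, 64, 16, true, false>'
# i.e. the P (Q*K^T) GEMM tiles over the sequence length while the O (P*V) GEMM
# tiles over the head size.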
|
|
|
|
|
|
def get_effective_sm_and_name(kspec):
|
|
sm = kspec.sm
|
|
# Override the mma instruction with an older one.
|
|
if kspec.sm_mma in sm2name:
|
|
assert kspec.sm_mma <= kspec.sm, "Instruction version should be at most target arch"
|
|
sm = kspec.sm_mma
|
|
sm_name = sm2name[sm]
|
|
return sm, sm_name
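# Example (hypothetical spec): a kernel targeting sm 89 with sm_mma == 80 resolves to
# (80, 'ampere'), i.e. Ada reuses the Ampere MMA instruction; a spec without a valid
# sm_mma override keeps its own sm.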
|
|
|
|
|
|
def selected_mask_types(kspec):
|
|
# by default, we generate all combinations.
|
|
# '1' means true, '0' means false.
|
|
padding_mask = '1'
|
|
causal_mask = '1'
|
|
sliding_or_chunked_causal_mask = '1'
|
|
custom_mask = '1'
|
|
# only generate certain needed combinations of input_layout and mask types for trt-llm.
|
|
if "GENERATE_CUBIN" in os.environ:
|
|
if kspec.sage_block_sizes:
|
|
# SageAttention only needs padding mask now
|
|
causal_mask = '0'
|
|
sliding_or_chunked_causal_mask = '0'
|
|
custom_mask = '0'
|
|
elif (kspec.head_size, kspec.head_size_v) == (192, 128):
|
|
# MLA context phase only needs causal mask and padding mask (for chunked prefill) now
|
|
sliding_or_chunked_causal_mask = '0'
|
|
custom_mask = '0'
|
|
elif (kspec.head_size, kspec.head_size_v) == (576, 512):
|
|
# MLA generation phase only needs padding mask (MtpMask) now
|
|
causal_mask = '0'
|
|
sliding_or_chunked_causal_mask = '0'
|
|
custom_mask = '0'
|
|
# encoder models (head_size = 32 / 64 / 128) need packed_qkv input layout + padding mask.
|
|
elif kspec.input_layout == InputLayout.PACKED_QKV:
|
|
# NOTE: 72/80 are added for vision transformer
|
|
if kspec.head_size not in [32, 64, 72, 80, 128]:
|
|
padding_mask = '0'
|
|
# only cross attention (head_size = 32/64/128) needs contiguous_q_kv input layout + padding mask / custom_mask.
|
|
elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
|
|
causal_mask = '0'
|
|
sliding_or_chunked_causal_mask = '0'
|
|
if kspec.head_size not in [32, 64, 72, 128]:
|
|
padding_mask = '0'
|
|
custom_mask = '0'
|
|
# paged kv cache is always needed in gpt variants.
|
|
# cross-attention also needs paged kv cache.
|
|
elif kspec.input_layout == InputLayout.Q_PAGED_KV:
|
|
if kspec.head_size not in [32, 64, 128]:
|
|
padding_mask = '0'
|
|
|
|
# alibi specialized kernels only need causal mask.
|
|
if (kspec.alibi and kspec.warp_specialization):
|
|
padding_mask = '0'
|
|
sliding_or_chunked_causal_mask = '0'
|
|
custom_mask = '0'
|
|
|
|
# enable_attn_logit_softcapping kernels only need causal mask or sliding_or_chunked_causal_mask.
|
|
if kspec.enable_attn_logit_softcapping:
|
|
padding_mask = '0'
|
|
custom_mask = '0'
|
|
|
|
return padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask
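# Example (hypothetical warp-specialized alibi spec, softcapping disabled, no
# SageAttention/MLA special case): the function returns ('0', '1', '0', '0'), so only
# the causal variant of that kernel is compiled.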
|
|
|
|
|
|
def get_kernel_code(kspec, kname, lname):
|
|
min_cuda_version = 0 # no restriction
|
|
|
|
# The architecture that determines the instruction.
|
|
effective_sm, sm_name = get_effective_sm_and_name(kspec)
|
|
|
|
if effective_sm >= 80:
|
|
min_cuda_version = 11000
|
|
|
|
launcher_name = lname
|
|
causal_kernel_name = kname.replace('__placeholder__', '_causal')
|
|
custom_mask_kernel_name = kname.replace('__placeholder__', '_custom_mask')
|
|
sliding_or_chunked_causal_kernel_name = kname.replace(
|
|
'__placeholder__', '_sliding_or_chunked_causal')
|
|
kernel_name = kname.replace('__placeholder__', '')
|
|
|
|
# FIXME: use separate parameters when generating cubins for trtllm.
|
|
if not kspec.cross_mha:
|
|
params_type = 'bert::Fused_multihead_attention_params_v{}'.format(
|
|
kspec.version)
|
|
else:
|
|
params_type = 'bert::Fused_multihead_attention_params_mhca'
|
|
|
|
if (effective_sm < 90):
|
|
instruction_traits = sm_name.capitalize() + '_' + dtype2traits[
|
|
kspec.dtype]
|
|
elif (effective_sm == 90):
|
|
instruction_traits = sm_name.capitalize() + '_' + hopper_dtype2traits[
|
|
kspec.dtype]
|
|
# for hopper, we differentiate instruction_traits_o and instruction_traits_p
|
|
instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits(
|
|
instruction_traits, kspec)
|
|
#print(instruction_traits_p, instruction_traits_o)
|
|
|
|
if (effective_sm < 90):
|
|
if kspec.flash_attention:
|
|
kernel_variant = 'flash_attention'
|
|
else:
|
|
kernel_variant = '1xN' if kspec.warps_m == 1 else '2x2'
|
|
elif (effective_sm == 90):
|
|
if kspec.warps_n > 1:
|
|
# for hopper we slice the problem along the M dim.
|
|
kernel_variant = '4xN' + '_hopper'
|
|
else:
|
|
kernel_variant = '4x1' + '_hopper'
|
|
|
|
if (effective_sm < 90):
|
|
kernel_traits = 'Kernel_traits_'
|
|
elif (effective_sm == 90):
|
|
kernel_traits = 'FMHA_kernel_traits_hopper_'
|
|
|
|
if kspec.interleaved:
|
|
kernel_traits += 'interleaved_v2'
|
|
elif kspec.cross_mha:
|
|
kernel_traits += 'fmhca'
|
|
else:
|
|
kernel_traits += 'v{}'.format(kspec.version)
|
|
|
|
    # decide whether to use paged_kv kernel traits for ampere-style kernels.
|
|
if effective_sm < 90:
|
|
if kspec.input_layout == InputLayout.Q_PAGED_KV:
|
|
kernel_traits += '_paged_kv_cache'
|
|
elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
|
|
kernel_traits += '_contiguous_kv_cache'
|
|
elif kspec.input_layout == InputLayout.SEPARATE_Q_K_V:
|
|
kernel_traits += '_q_k_v'
|
|
|
|
flags = 0
|
|
if kspec.ldgsts_q:
|
|
flags |= 1
|
|
if kspec.ldgsts_k:
|
|
flags |= 2
|
|
if kspec.ldgsts_v:
|
|
flags |= 4
|
|
if kspec.share_smem_k_v and not kspec.limit_qk_fragments:
|
|
flags |= 8
|
|
if kspec.has_scale_max:
|
|
flags |= 16
|
|
if not kspec.head_interleaved:
|
|
flags |= 32
|
|
if kspec.limit_qk_fragments:
|
|
flags |= 128
|
|
if kspec.limit_v_fragments:
|
|
flags |= 256
|
|
if kspec.has_noloop:
|
|
# NOTE do not use flags 512 = 0x200 as it is reserved; do not add to flags because it
|
|
# will be selectively added to no-loop kernel trait upon generating .cu templates
|
|
pass
|
|
if kspec.enable_attn_logit_softcapping:
|
|
flags |= 2048
|
|
if kspec.tiled:
|
|
flags |= 4096
|
|
if kspec.is_mtp:
|
|
flags |= 8192
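    # Summary of the flag bits assembled above: 0x1 ldgsts_q, 0x2 ldgsts_k, 0x4 ldgsts_v,
    # 0x8 share_smem_k_v (only without limit_qk_fragments), 0x10 has_scale_max,
    # 0x20 non-head-interleaved, 0x80 limit_qk_fragments, 0x100 limit_v_fragments,
    # 0x200 reserved for the no-loop variant (added when emitting the no-loop traits),
    # 0x800 attn logit softcapping, 0x1000 tiled, 0x2000 MTP.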
|
|
|
|
# only generate certain needed combinations of input_layout and mask types for trt-llm.
|
|
    padding_mask, causal_mask, sliding_or_chunked_causal_mask, custom_mask = \
        selected_mask_types(kspec)

    # Skip this kernel entirely if no mask type is selected.
    if all(selected_mask_flag == '0'
           for selected_mask_flag in (padding_mask, causal_mask,
                                      sliding_or_chunked_causal_mask,
                                      custom_mask)):
        return None
|
|
|
|
kernel_flags = '0x{:02x}u'.format(flags)
|
|
|
|
heads_interleaved_flag = pythonBoolean2cpp[kspec.head_interleaved]
|
|
|
|
disable_fadd_trick = 1 if effective_sm >= 86 else 0 # this will force generating F2IP
|
|
|
|
enable_mutex_flag = enable_mutex(kspec)
|
|
|
|
has_alibi = pythonBoolean2cpp[kspec.alibi]
|
|
|
|
input_layout_flag = str(int(kspec.input_layout))
|
|
|
|
run_fct_name = 'run_packed_qkv' if kspec.input_layout == InputLayout.PACKED_QKV else \
|
|
'run_separate_q_and_kv'
|
|
|
|
dma_reg_count, compute_reg_count = get_reg_count(kspec)
|
|
|
|
use_tma_store_flag = enable_tma_store(kspec)
|
|
|
|
enable_attn_logit_softcapping_flag = pythonBoolean2cpp[
|
|
kspec.enable_attn_logit_softcapping]
|
|
|
|
return_softmax_stats_flag = pythonBoolean2cpp[kspec.return_softmax_stats]
|
|
|
|
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
|
|
|
|
# needed by warpspec kernels.
|
|
fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
|
|
kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
|
|
else f"fmha::ws::Kernel_traits<fmha::{instruction_traits},"
|
|
|
|
# output type.
|
|
output_dtype_ = f"fmha::{dtype2OutputType[kspec.output_dtype if kspec.output_dtype is not None else kspec.dtype]}"
|
|
|
|
# sage attention block sizes.
|
|
sage_block_size_q = 0
|
|
sage_block_size_k = 0
|
|
sage_block_size_v = 0
|
|
if fp8_kernel and kspec.sage_block_sizes:
|
|
assert kspec.output_dtype is not None, "output_dtype must be specified for fp8 sage attention kernels"
|
|
sage_block_size_q = kspec.sage_block_sizes[0]
|
|
sage_block_size_k = kspec.sage_block_sizes[1]
|
|
sage_block_size_v = kspec.sage_block_sizes[2]
|
|
|
|
TMA_config = r'''
|
|
// TMA configuration
|
|
    // Note that this may only need to be initialized once during inference (reused across layers).
|
|
// Reuse the same traits for initializing tma descriptors.
|
|
fmha::ws::DMA<Ktraits>::Host dma_host;
|
|
dma_host.init_params(params, launch_params, stream);
|
|
''' if not generate_cu_trtllm else ''
|
|
params_str = 'reinterpret_cast<bert::Fused_multihead_attention_params_v2 &>(params)' if generate_cu_trtllm else 'params'
|
|
attn_mask_type_str = 'using Attention_mask_type = ContextAttentionMaskType;' if generate_cu_trtllm else 'using Attention_mask_type = fmha::Attention_mask_type;'
|
|
bert_launch_params = '' if generate_cu_trtllm else 'using Launch_params = bert::Fused_multihead_attention_launch_params;'
|
|
include_str = '#include "../fused_multihead_attention_common.h"\n' if generate_cu_trtllm else ''
|
|
include_str += '#include "tensorrt_llm/common/config.h"' if generate_cu_trtllm else ''
|
|
num_compute_groups_str = '' if generate_cu_trtllm else 'static constexpr int NUM_COMPUTE_GROUPS = 2;'
|
|
fused_multihead_attention_params_v2_str = 'Fused_multihead_attention_params_v2' if generate_cu_trtllm else f'{params_type}'
|
|
const_fused_multihead_attention_params_v2_str = 'Fused_multihead_attention_params_v2' if generate_cu_trtllm else f'const {params_type}'
|
|
setmaxnreg_dma_str = r'''
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900
|
|
const int DMA_REG_COUNT = {dma_reg_count};
|
|
asm volatile("{{setmaxnreg.dec.sync.aligned.u32 %0; \n\t}}" ::"n"(DMA_REG_COUNT));
|
|
#else
|
|
asm volatile("trap;\n");
|
|
#endif
|
|
'''.format(dma_reg_count=dma_reg_count) if generate_cu_trtllm else r'''
|
|
const int DMA_REG_COUNT = {dma_reg_count};
|
|
asm volatile("{{setmaxnreg.dec.sync.aligned.u32 %0; \n\t}}" ::"n"(DMA_REG_COUNT));'''.format(
|
|
dma_reg_count=dma_reg_count)
|
|
setmaxnreg_compute_str = r'''
|
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900
|
|
const int COMPUTE_REG_COUNT = {compute_reg_count};
|
|
asm volatile("{{setmaxnreg.inc.sync.aligned.u32 %0; \n\t}}" ::"n"(COMPUTE_REG_COUNT));
|
|
#else
|
|
asm volatile("trap;\n");
|
|
#endif
|
|
'''.format(compute_reg_count=compute_reg_count) if generate_cu_trtllm else r'''
|
|
const int COMPUTE_REG_COUNT = {compute_reg_count};
|
|
asm volatile("{{setmaxnreg.inc.sync.aligned.u32 %0; \n\t}}" ::"n"(COMPUTE_REG_COUNT));'''.format(
|
|
compute_reg_count=compute_reg_count)
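    # Note on the PTX snippets above: setmaxnreg.dec releases registers from the producer
    # (DMA) warp group while setmaxnreg.inc grows the budget of the consumer (compute)
    # warp groups, using the per-thread counts chosen by get_reg_count(). The trtllm
    # variant additionally guards the instruction behind __CUDA_ARCH__ == 900 and traps
    # elsewhere, since setmaxnreg is only available on Hopper.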
|
|
abi_ns_open = r"""
|
|
TRTLLM_NAMESPACE_BEGIN
|
|
namespace kernels
|
|
{
|
|
// clang-format off
|
|
"""
|
|
abi_ns_close = r"""
|
|
// clang-format on
|
|
} // namespace kernels
|
|
TRTLLM_NAMESPACE_END
|
|
"""
|
|
local_ns_open = abi_ns_open if generate_cu_trtllm else ''
|
|
local_ns_close = abi_ns_close if generate_cu_trtllm else ''
|
|
|
|
tmp = dict(locals(), **kspec._asdict())
|
|
|
|
if (effective_sm < 90):
|
|
if kspec.flash_attention:
|
|
code = flash_attention_kernel_template.format(
|
|
**tmp,
|
|
copyright=copyright,
|
|
use_multi_cta=False,
|
|
MAX_STGS_PER_LOOP=MAX_STGS_PER_LOOP)
|
|
else:
|
|
use_multi_cta = 1 if kspec.ctas_per_head > 1 else 0
|
|
code = kernel_template.format(**tmp,
|
|
copyright=copyright,
|
|
use_multi_cta=use_multi_cta,
|
|
MAX_STGS_PER_LOOP=MAX_STGS_PER_LOOP)
|
|
elif (effective_sm == 90):
|
|
use_tma = 1
|
|
if (kspec.ldgsts_q == True):
|
|
use_tma = 0
|
|
if kspec.warp_specialization:
|
|
code = kernel_hopper_warp_specialization_template.format(
|
|
**tmp,
|
|
copyright=copyright,
|
|
use_tma=use_tma,
|
|
bytes_per_elt=dtype2bytes[kspec.dtype])
|
|
else:
|
|
code = kernel_hopper_template.format(**tmp,
|
|
copyright=copyright,
|
|
use_tma=use_tma)
|
|
return code
|
|
|
|
|
|
def get_api_code(specs_names):
|
|
|
|
def get_signature(lname, version, cross_mha, use_tma):
|
|
# The architecture that determines the instruction.
|
|
effective_sm, sm_name = get_effective_sm_and_name(kspec)
|
|
if cross_mha:
|
|
            return 'void {}(const Params_mhca &params, cudaStream_t stream);'.format(
|
|
lname)
|
|
elif effective_sm >= 90:
|
|
# need to set tma desc in params
|
|
            return 'void {}(Params_v{} &params, const Launch_params &launch_params, cudaStream_t stream);'.format(
|
|
lname, version)
|
|
else:
|
|
            return 'void {}(const Params_v{} &params, const Launch_params &launch_params, cudaStream_t stream);'.format(
|
|
lname, version)
|
|
|
|
signatures = []
|
|
for kspec, fname, lname, kname in specs_names:
|
|
effective_sm, _ = get_effective_sm_and_name(kspec)
|
|
use_tma = effective_sm == 90 and not kspec.ldgsts_q
|
|
signatures.append(
|
|
get_signature(lname, kspec.version, kspec.cross_mha, use_tma))
|
|
if kspec.has_noloop and not kspec.tiled:
|
|
signatures.append(
|
|
get_signature(lname + '_nl', kspec.version, kspec.cross_mha,
|
|
use_tma))
|
|
elif kspec.tiled:
|
|
signatures.append(
|
|
get_signature(lname + '_nl_tiled', kspec.version,
|
|
kspec.cross_mha, use_tma))
|
|
if not kspec.warp_specialization:
|
|
signatures.append(
|
|
'void {}_get_max_heads_per_wave(int*);'.format(lname))
|
|
signatures = '\n'.join(signatures)
|
|
|
|
#v1
|
|
# - normal
|
|
# - no loop
|
|
#v2
|
|
# - normal
|
|
# - no loop
|
|
# - normal interleaved
|
|
# - no loop interleaved
|
|
# - flash attention no loop
|
|
# - flash attention no loop tiled
|
|
# - flash attention warp_specialized (on Hopper)
|
|
|
|
def gen_unroll_check(kspec):
|
|
code = 'if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b > {unroll_threshold})))'.format(
|
|
**kspec._asdict())
|
|
if kspec.flash_attention:
|
|
code = 'if (!{has_noloop} || (!force_unroll && (ignore_b1opt || b * h > {unroll_threshold})))'.format(
|
|
**kspec._asdict())
|
|
return code
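    # For illustration (hypothetical spec with has_noloop == 1 and unroll_threshold == 1),
    # the non-flash form renders in the generated API as:
    #   if (!1 || (!force_unroll && (ignore_b1opt || b > 1)))
    # while flash-attention kernels compare b * h against the threshold instead.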
|
|
|
|
def gen_call(kspec, lname):
|
|
effective_sm, _ = get_effective_sm_and_name(kspec)
|
|
data_type = dtype2typename[kspec.dtype]
|
|
output_data_type = data_type
|
|
if kspec.output_dtype:
|
|
output_data_type = dtype2typename[kspec.output_dtype]
|
|
il_check = ''
|
|
if kspec.version == 2 and kspec.dtype in ["fp16", "bf16"]:
|
|
il_check += "&& use_flash_attention " if kspec.flash_attention else '&& !use_flash_attention '
|
|
if kspec.version == 2:
|
|
# attention input layout.
|
|
il_check += f'&& attention_input_layout == {kspec.input_layout.value} '
|
|
# interleaved layout or not.
|
|
il_check += '&& interleaved ' if kspec.interleaved else '&& !interleaved '
|
|
if effective_sm == 90:
|
|
il_check += "&& !use_tma " if kspec.ldgsts_q else '&& use_tma '
|
|
il_check += "&& warp_specialization " if kspec.warp_specialization else '&& !warp_specialization '
|
|
else:
|
|
il_check += "&& !warp_specialization && !use_tma "
|
|
# Different accumulation types.
|
|
if '_fp32' in kspec.dtype or 'bf16' in kspec.dtype or kspec.dtype == 'e4m3':
|
|
il_check += '&& force_fp32_acc '
|
|
else:
|
|
il_check += '&& !force_fp32_acc '
|
|
# whether support alibi or not.
|
|
if kspec.warp_specialization:
|
|
il_check += '&& params.has_alibi ' if kspec.alibi else '&& !params.has_alibi '
|
|
il_check += '&& params.softmax_stats_ptr != nullptr ' if kspec.return_softmax_stats else '&& params.softmax_stats_ptr == nullptr '
|
|
# use enable_attn_logit_softcapping or not.
|
|
il_check += '&& enable_attn_logit_softcapping ' if kspec.enable_attn_logit_softcapping else '&& !enable_attn_logit_softcapping '
|
|
# check sage block sizes
|
|
sage_block_size_q = 0
|
|
sage_block_size_k = 0
|
|
sage_block_size_v = 0
|
|
if kspec.sage_block_sizes:
|
|
# override the data_type to output type, otherwise it is always E4M3
|
|
data_type = output_data_type
|
|
sage_block_size_q = kspec.sage_block_sizes[0]
|
|
sage_block_size_k = kspec.sage_block_sizes[1]
|
|
sage_block_size_v = kspec.sage_block_sizes[2]
|
|
il_check += f'&& sage_block_size_q == {sage_block_size_q} ' \
|
|
f'&& sage_block_size_k == {sage_block_size_k} ' \
|
|
f'&& sage_block_size_v == {sage_block_size_v} '
|
|
|
|
il_check += '&& enable_skip_softmax ' if kspec.enable_skip_softmax else '&& !enable_skip_softmax '
|
|
|
|
il_check += '&& params.use_int8_scale_max ' if kspec.has_scale_max else '&& !params.use_int8_scale_max '
|
|
|
|
slen = kspec.seq_len * kspec.ctas_per_head if not kspec.flash_attention else 0
|
|
|
|
## NOTE: need to tune here
|
|
if kspec.has_noloop and not kspec.flash_attention:
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && output_data_type == {output_data_type} && s == {slen} && d == {head_size} && sm == {sm}
|
|
{il_check}) {{
|
|
|
|
{unroll_check} {{
|
|
{lname}(params, launch_params, stream);
|
|
}} else {{
|
|
{lname}_nl(params, launch_params, stream);
|
|
}}
|
|
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
output_data_type=output_data_type,
|
|
slen=slen,
|
|
lname=lname,
|
|
il_check=il_check,
|
|
unroll_check=gen_unroll_check(kspec))
|
|
|
|
elif kspec.flash_attention: #NOTE: flash attention uses no_loop as default
|
|
            # Passing 'head_size_v' as a keyword raises "TypeError: got multiple values for
            # keyword argument" (it already comes from kspec._asdict()), so use 'dv' instead.
|
|
dv = kspec.head_size_v or kspec.head_size
|
|
if kspec.tiled: # higher precedence; does not require bh_upper_thres
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm}
|
|
{il_check} && use_tiled) {{
|
|
|
|
{lname}_nl_tiled(params, launch_params, stream);
|
|
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
output_data_type=output_data_type,
|
|
slen=slen,
|
|
lname=lname,
|
|
il_check=il_check,
|
|
dv=dv)
|
|
# warp specialization kernels need launch_params
|
|
elif kspec.warp_specialization:
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm}
|
|
{il_check}) {{
|
|
|
|
{lname}(params, launch_params, stream);
|
|
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
output_data_type=output_data_type,
|
|
slen=slen,
|
|
lname=lname,
|
|
il_check=il_check,
|
|
dv=dv)
|
|
else:
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && output_data_type == {output_data_type} && d == {head_size} && dv == {dv} && sm == {sm}
|
|
&& !use_tiled {il_check}) {{
|
|
|
|
{lname}_nl(params, launch_params, stream);
|
|
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
output_data_type=output_data_type,
|
|
slen=slen,
|
|
lname=lname,
|
|
il_check=il_check,
|
|
dv=dv)
|
|
else:
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && output_data_type == {output_data_type} && s == {slen} && d == {head_size} && sm == {sm}
|
|
{il_check}) {{
|
|
|
|
{lname}(params, launch_params, stream);
|
|
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
output_data_type=output_data_type,
|
|
slen=slen,
|
|
lname=lname,
|
|
il_check=il_check)
|
|
return call_stmt
|
|
|
|
def gen_call_fmhca(kspec, lname):
|
|
effective_sm, _ = get_effective_sm_and_name(kspec)
|
|
data_type = dtype2typename[kspec.dtype]
|
|
il_check = ''
|
|
if kspec.version == 2:
|
|
il_check = '&& interleaved ' if kspec.interleaved else '&& !interleaved '
|
|
if effective_sm == 90:
|
|
il_check += "&& !use_tma " if kspec.ldgsts_q else '&& use_tma '
|
|
il_check += '&& params.use_int8_scale_max ' if kspec.has_scale_max else '&& !params.use_int8_scale_max '
|
|
|
|
s_kv_len = kspec.seq_len
|
|
if kspec.has_noloop:
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && s_kv == {s_kv_len} && d == {head_size} && sm == {sm} {il_check}) {{
|
|
|
|
{unroll_check} {{
|
|
{lname}(params, stream);
|
|
}} else {{
|
|
{lname}_nl(params, stream);
|
|
}}
|
|
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
s_kv_len=s_kv_len,
|
|
lname=lname,
|
|
il_check=il_check,
|
|
unroll_check=gen_unroll_check(kspec))
|
|
|
|
else:
|
|
call_stmt = '''\
|
|
if( data_type == {data_type} && s_kv == {s_kv_len} && d == {head_size} && sm == {sm} {il_check}) {{
|
|
{lname}(params, stream);
|
|
}} '''.format(**kspec._asdict(),
|
|
data_type=data_type,
|
|
s_kv_len=s_kv_len,
|
|
lname=lname,
|
|
il_check=il_check)
|
|
return call_stmt
|
|
|
|
calls_v2 = [
|
|
gen_call(kspec, lname)
|
|
for kspec, fname, lname, kname in specs_names \
|
|
if kspec.version == 2 and kspec.cross_mha == 0
|
|
]
|
|
|
|
calls_v2 = 'else '.join(calls_v2) if len(calls_v2) > 0 else 'if( false ) {}'
|
|
|
|
calls_v1 = [
|
|
gen_call(kspec, lname) for kspec, fname, lname, kname in specs_names
|
|
if kspec.version == 1 and kspec.cross_mha == 0
|
|
]
|
|
|
|
calls_v1 = 'else '.join(calls_v1) if len(calls_v1) > 0 else 'if( false ) {}'
|
|
|
|
calls_mhca = [
|
|
gen_call_fmhca(kspec, lname)
|
|
for kspec, fname, lname, kname in specs_names if kspec.cross_mha == 1
|
|
]
|
|
|
|
calls_mhca = 'else '.join(calls_mhca) if len(
|
|
calls_mhca) > 0 else 'if( false ) {}'
|
|
|
|
def gen_warp_spec(kspec):
|
|
data_type = dtype2typename[kspec.dtype]
|
|
if kspec.sage_block_sizes is not None:
|
|
assert kspec.output_dtype is not None
|
|
# override the data_type to output type, otherwise it is always E4M3
|
|
data_type = dtype2typename[kspec.output_dtype]
|
|
slen = kspec.seq_len * kspec.ctas_per_head
|
|
effective_sm, _ = get_effective_sm_and_name(kspec)
|
|
warp_spec_check = ''
|
|
nl_warps_m = kspec.warps_m if effective_sm == 90 else 1
|
|
nl_warps_n = kspec.warps_n if effective_sm == 90 else kspec.warps_m * kspec.warps_n
|
|
if kspec.version == 2 and kspec.dtype in ["fp16", "bf16"]:
|
|
warp_spec_check += "&& use_flash_attention " if kspec.flash_attention else '&& !use_flash_attention '
|
|
if kspec.version == 2:
|
|
if effective_sm == 90:
|
|
warp_spec_check += "&& !use_tma " if kspec.ldgsts_q else '&& use_tma '
|
|
warp_spec_check += "&& warp_specialization " if kspec.warp_specialization else '&& !warp_specialization '
|
|
else:
|
|
warp_spec_check += '&& !use_tma && !warp_specialization '
|
|
|
|
if kspec.flash_attention: # NOTE support any sequence
|
|
return '''\
|
|
if( data_type == {data_type} && d == {head_size} && sm == {sm} {warp_spec_check}
|
|
&& version == {version} ) {{
|
|
warps_m = {warps_m};
|
|
warps_n = {warps_n};
|
|
}} '''.format(**locals(),
|
|
**kspec._asdict(),
|
|
unroll_check=gen_unroll_check(kspec))
|
|
return '''\
|
|
if( data_type == {data_type} && s == {slen} && d == {head_size} && sm == {sm} {warp_spec_check}
|
|
&& version == {version} ) {{
|
|
{unroll_check} {{
|
|
warps_m = {warps_m};
|
|
warps_n = {warps_n};
|
|
}} else {{
|
|
warps_m = {nl_warps_m};
|
|
warps_n = {nl_warps_n};
|
|
}}
|
|
}} '''.format(**locals(),
|
|
**kspec._asdict(),
|
|
unroll_check=gen_unroll_check(kspec))
|
|
|
|
warp_specs = 'else '.join([gen_warp_spec(spec[0]) for spec in specs_names])
|
|
if len(warp_specs) > 0:
|
|
warp_specs += 'else {\n\tassert(false && "Unsupported config");\n}'
|
|
|
|
# Generate the cta spec.
|
|
def gen_cta_spec(spec):
|
|
kspec, _, lname, _ = spec
|
|
slen = kspec.seq_len * kspec.ctas_per_head
|
|
return '''\
|
|
if( data_type == {data_type} && s == {slen} && d == {head_size} && use_multi_ctas
|
|
&& version == {version} ) {{
|
|
|
|
ctas_per_head = {ctas_per_head};
|
|
{lname}_get_max_heads_per_wave(&max_heads_per_wave);
|
|
|
|
}} '''.format(**locals(),
|
|
**kspec._asdict(),
|
|
data_type=dtype2typename[kspec.dtype])
|
|
|
|
cta_specs = 'else '.join([
|
|
gen_cta_spec(spec) for spec in specs_names if spec[0].ctas_per_head > 1
|
|
])
|
|
|
|
api_code = '''\
|
|
{copyright}
|
|
#pragma once
|
|
|
|
#include <cuda.h>
|
|
#include <fused_multihead_attention.h>
|
|
#include <fused_multihead_cross_attention.h>
|
|
#include <tuple>
|
|
|
|
using Params_v1 = bert::Fused_multihead_attention_params_v1;
|
|
using Params_v2 = bert::Fused_multihead_attention_params_v2;
|
|
using Params_mhca = bert::Fused_multihead_attention_params_mhca;
|
|
using Launch_params = bert::Fused_multihead_attention_launch_params;
|
|
|
|
{signatures}
|
|
|
|
inline void run_fmha_v1(Params_v1 &params,
|
|
const Launch_params &launch_params,
|
|
Data_type data_type,
|
|
Data_type output_data_type,
|
|
int sm,
|
|
cudaStream_t stream=0){{
|
|
const size_t s = params.s;
|
|
const size_t b = params.b;
|
|
const size_t d = params.d;
|
|
const bool force_unroll = launch_params.force_unroll;
|
|
const bool ignore_b1opt = launch_params.ignore_b1opt;
|
|
|
|
const bool use_flash_attention = false;
|
|
|
|
{calls_v1}
|
|
else {{
|
|
assert(false && "Unsupported config.");
|
|
}}
|
|
|
|
}}
|
|
|
|
// Note: transitioning to passing kernel launch parameters through launch_params to reduce
|
|
// how often the interface needs to be modified.
|
|
inline void run_fmha_v2(Params_v2 &params,
|
|
const Launch_params &launch_params,
|
|
Data_type data_type,
|
|
Data_type output_data_type,
|
|
int sm,
|
|
cudaStream_t stream=0) {{
|
|
|
|
const size_t s = params.s;
|
|
const size_t b = params.b;
|
|
const size_t h = params.h;
|
|
const size_t d = params.d;
|
|
const size_t dv = params.dv;
|
|
const size_t sage_block_size_q = params.sage.q.block_size;
|
|
const size_t sage_block_size_k = params.sage.k.block_size;
|
|
const size_t sage_block_size_v = params.sage.v.block_size;
|
|
|
|
const bool interleaved = launch_params.interleaved;
|
|
const bool force_unroll = launch_params.force_unroll;
|
|
const bool ignore_b1opt = launch_params.ignore_b1opt;
|
|
const bool force_fp32_acc = launch_params.force_fp32_acc;
|
|
const bool warp_specialization = launch_params.warp_specialization;
|
|
const bool use_tma = launch_params.use_tma;
|
|
const bool use_flash_attention = launch_params.flash_attention;
|
|
const bool enable_attn_logit_softcapping = launch_params.enable_attn_logit_softcapping;
|
|
const bool enable_skip_softmax = launch_params.enable_skip_softmax;
|
|
const int attention_input_layout = static_cast<int>(launch_params.attention_input_layout);
|
|
// tiled variant uses ldgsts
|
|
const bool use_tiled = launch_params.use_granular_tiling;
|
|
|
|
{calls_v2}
|
|
else {{
|
|
assert(false && "Unsupported config.");
|
|
}}
|
|
|
|
}}
|
|
|
|
#if __guard_fmhca_placeholder__ // fmhca api header
|
|
|
|
inline void run_fmhca(Params_mhca &params,
|
|
const Launch_params &launch_params,
|
|
Data_type data_type,
|
|
int sm,
|
|
cudaStream_t stream=0) {{
|
|
|
|
const size_t s_kv = params.s;
|
|
const size_t b = params.b;
|
|
const size_t d = params.d_padded;
|
|
|
|
const bool interleaved = launch_params.interleaved;
|
|
const bool force_unroll = launch_params.force_unroll;
|
|
const bool ignore_b1opt = launch_params.ignore_b1opt;
|
|
|
|
{calls_mhca}
|
|
else {{
|
|
assert(false && "Unsupported config");
|
|
}}
|
|
|
|
}}
|
|
|
|
#endif // fmhca api header
|
|
|
|
inline std::tuple<size_t, size_t, size_t> get_warps(Launch_params& launch_params,
|
|
int sm,
|
|
Data_type data_type,
|
|
size_t s,
|
|
size_t b,
|
|
size_t d,
|
|
int version) {{
|
|
size_t warps_m, warps_n, warps_k = 1;
|
|
const bool interleaved = launch_params.interleaved;
|
|
const bool use_tma = launch_params.use_tma;
|
|
const bool force_unroll = launch_params.force_unroll;
|
|
const bool ignore_b1opt = launch_params.ignore_b1opt;
|
|
const bool use_flash_attention = launch_params.flash_attention;
|
|
// tiled variant uses ldgsts
|
|
const bool use_tiled = launch_params.use_granular_tiling;
|
|
const bool warp_specialization = launch_params.warp_specialization;
|
|
|
|
{warp_specs}
|
|
|
|
return std::make_tuple(warps_m, warps_n, warps_k);
|
|
}}
|
|
|
|
// The constant is defined in "setup.py".
|
|
constexpr int MAX_STGS_PER_LOOP = {MAX_STGS_PER_LOOP};
|
|
|
|
// The number of CTAs and threads per CTA to launch the kernel.
|
|
inline void get_grid_size(int &heads_per_wave,
|
|
int &ctas_per_head,
|
|
int sm,
|
|
Data_type data_type,
|
|
size_t b,
|
|
size_t s,
|
|
size_t h,
|
|
size_t d,
|
|
bool use_multi_ctas,
|
|
int version) {{
|
|
|
|
// Determine the number of CTAs per head (kernel constant).
|
|
int max_heads_per_wave = 0;
|
|
ctas_per_head = 1;
|
|
heads_per_wave = b*h;
|
|
{cta_specs}
|
|
|
|
// Adjust the number of heads per wave.
|
|
if( heads_per_wave > max_heads_per_wave ) {{
|
|
heads_per_wave = max_heads_per_wave;
|
|
}}
|
|
}}
|
|
|
|
'''.format(**locals(), copyright=copyright, MAX_STGS_PER_LOOP=MAX_STGS_PER_LOOP)
|
|
return api_code
|
|
|
|
|
|
ktraits_code_template = '''
|
|
#include "fused_multihead_attention_kernel.h"
|
|
#include "fmha/kernel_traits.h"
|
|
#include "fmha/hopper/kernel_traits.h"
|
|
#include <fmha/warpspec/kernel_traits.h>
|
|
|
|
using namespace fmha;
|
|
|
|
int main(){{
|
|
{print_kernel_specs}
|
|
}}
|
|
'''
|
|
|
|
|
|
def get_kernel_traits_code(specs_names):
|
|
print_kernel_specs = []
|
|
|
|
for kspec, fname, lname, kname in specs_names:
|
|
effective_sm, sm_name = get_effective_sm_and_name(kspec)
|
|
if (effective_sm < 90):
|
|
instruction_traits = sm_name.capitalize() + '_' + dtype2traits[
|
|
kspec.dtype]
|
|
elif (effective_sm == 90):
|
|
instruction_traits = sm_name.capitalize(
|
|
) + '_' + hopper_dtype2traits[kspec.dtype]
|
|
instruction_traits_p, instruction_traits_o = get_hopper_instruction_traits(
|
|
instruction_traits, kspec)
|
|
|
|
if (effective_sm < 90):
|
|
kernel_traits = 'Kernel_traits_'
|
|
elif (effective_sm == 90):
|
|
kernel_traits = 'FMHA_kernel_traits_hopper_'
|
|
|
|
if kspec.interleaved:
|
|
kernel_traits += 'interleaved_v2'
|
|
elif kspec.cross_mha:
|
|
kernel_traits += 'fmhca'
|
|
else:
|
|
kernel_traits += 'v{}'.format(kspec.version)
|
|
|
|
# needed by warpspec kernels.
|
|
fp8_kernel = kspec.dtype in ["e4m3", "e4m3_fp32"]
|
|
kernel_traits_header = "fmha::ws::Kernel_traits_Hopper_qgmma_e4m3_fp32<" if fp8_kernel \
|
|
else f"fmha::ws::Kernel_traits<fmha::{instruction_traits},"
|
|
|
|
flags = 0
|
|
if kspec.ldgsts_q:
|
|
flags |= 1
|
|
if kspec.ldgsts_k:
|
|
flags |= 2
|
|
if kspec.ldgsts_v:
|
|
flags |= 4
|
|
if kspec.share_smem_k_v:
|
|
flags |= 8
|
|
if kspec.has_scale_max:
|
|
flags |= 16
|
|
if not kspec.head_interleaved:
|
|
flags |= 32
|
|
if kspec.limit_qk_fragments:
|
|
flags |= 128
|
|
        if kspec.limit_v_fragments:
|
|
flags |= 256
|
|
if kspec.has_noloop:
|
|
# NOTE do not use flags 512 = 0x200 as it is reserved; do not add to flags because it
|
|
# will be selectively added to no-loop kernel trait upon generating .cu templates
|
|
pass
|
|
if kspec.enable_attn_logit_softcapping:
|
|
flags |= 2048
|
|
if kspec.tiled:
|
|
flags |= 4096
|
|
if kspec.is_mtp:
|
|
flags |= 8192
|
|
|
|
kernel_flags = '0x{:02x}u'.format(flags)
|
|
|
|
heads_interleaved_flag = pythonBoolean2cpp[kspec.head_interleaved]
|
|
|
|
enable_mutex_flag = enable_mutex(kspec)
|
|
|
|
has_alibi = pythonBoolean2cpp[kspec.alibi]
|
|
|
|
return_softmax_stats_flag = pythonBoolean2cpp[
|
|
kspec.return_softmax_stats]
|
|
|
|
input_layout_flag = str(int(kspec.input_layout))
|
|
|
|
enable_attn_logit_softcapping_flag = pythonBoolean2cpp[
|
|
kspec.enable_attn_logit_softcapping]
|
|
|
|
enable_skip_softmax_flag = pythonBoolean2cpp[kspec.enable_skip_softmax]
|
|
|
|
tmp = dict(locals(), **kspec._asdict())
|
|
|
|
if effective_sm < 90:
|
|
snippet = ''' {{
|
|
using Kernel_traits = {kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{head_size_v},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags}>;
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kernel_name}\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{loop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp, kernel_name=kname.replace('__placeholder__', ''))
|
|
snippet_nl = ''' {{
|
|
using Kernel_traits = {kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{seq_len},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
1,
|
|
{warps_m} * {warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */>;
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kernel_name}_nl\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{noloop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp, kernel_name=kname.replace('__placeholder__', ''))
|
|
snippet_flash = ''' {{
|
|
using Kernel_traits = {kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags}>;
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kernel_name}\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{loop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp, kernel_name=kname.replace('__placeholder__', ''))
|
|
snippet_flash_nl_template = ''' {{
|
|
using Kernel_traits = {kernel_traits}<
|
|
fmha::{instruction_traits},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
{ctas_per_head},
|
|
{kernel_flags} | 0x200 /* no_loop flag */>;
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kname}_nl\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{noloop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp)
|
|
snippet_flash_nl = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '')
|
|
snippet_flash_nl_tiled = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '').replace('_nl', '_nl_tiled')
|
|
snippet_flash_nl_causal = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '_causal')
|
|
snippet_flash_nl_tiled_causal = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '_causal').replace('_nl', '_nl_tiled')
|
|
snippet_flash_nl_sliding_or_chunked_causal = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '_sliding_or_chunked_causal')
|
|
snippet_flash_nl_tiled_sliding_or_chunked_causal = snippet_flash_nl_template.replace(
|
|
'__placeholder__',
|
|
'_sliding_or_chunked_causal').replace('_nl', '_nl_tiled')
|
|
snippet_flash_nl_custom_mask = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '_custom_mask')
|
|
snippet_flash_nl_tiled_custom_mask = snippet_flash_nl_template.replace(
|
|
'__placeholder__', '_custom_mask').replace('_nl', '_nl_tiled')
|
|
elif effective_sm >= 90 and kspec.warp_specialization: #GMMA warpspec flash
|
|
snippet_ws_template = ''' {{
|
|
static constexpr int DMA2COMPUTE_DEPTH = 1;
|
|
static constexpr int NUM_COMPUTE_GROUPS = 2;
|
|
|
|
using Kernel_traits = {kernel_traits_header}
|
|
{loop_step},
|
|
{kv_loop_step},
|
|
{head_size},
|
|
{head_size_v},
|
|
{q_tile_buffers},
|
|
{kv_tile_buffers},
|
|
NUM_COMPUTE_GROUPS,
|
|
DMA2COMPUTE_DEPTH,
|
|
mask_type,
|
|
{heads_interleaved_flag},
|
|
{has_alibi},
|
|
{enable_mutex_flag},
|
|
{scheduling_mode},
|
|
{input_layout_flag},
|
|
__use_tma_store__ /* USE_TMA_STORE */,
|
|
{enable_attn_logit_softcapping_flag},
|
|
{return_softmax_stats_flag},
|
|
{enable_skip_softmax_flag}>;
|
|
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kname}\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{loop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp)
|
|
snippet_ws = snippet_ws_template.replace('__placeholder__', '').\
|
|
replace('mask_type', '0').\
|
|
replace('__use_tma_store__', 'true')
|
|
snippet_ws_causal = snippet_ws_template.replace('__placeholder__', '_causal').\
|
|
replace('mask_type', '1').\
|
|
replace('__use_tma_store__', 'true')
|
|
snippet_ws_sliding_or_chunked_causal = \
|
|
snippet_ws_template.replace('__placeholder__', '_sliding_or_chunked_causal').\
|
|
replace('mask_type', '2').\
|
|
replace('__use_tma_store__', 'false')
|
|
snippet_ws_custom_mask = \
|
|
snippet_ws_template.replace('__placeholder__', '_custom_mask').\
|
|
replace('mask_type', '2').\
|
|
replace('__use_tma_store__', 'true')
|
|
elif effective_sm >= 90: #GMMA no flash yet
|
|
snippet_template = ''' {{
|
|
using Traits_p = fmha::{instruction_traits_p};
|
|
using Traits_o = fmha::{instruction_traits_o};
|
|
|
|
using Kernel_traits = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{loop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
2,
|
|
{kernel_flags}>;
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kname}\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{loop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp)
|
|
snippet_nl_template = ''' {{
|
|
using Traits_p = fmha::{instruction_traits_p};
|
|
using Traits_o = fmha::{instruction_traits_o};
|
|
|
|
using Kernel_traits = {kernel_traits}<
|
|
Traits_p,
|
|
Traits_o,
|
|
{seq_len},
|
|
{head_size},
|
|
{noloop_step},
|
|
{warps_m},
|
|
{warps_n},
|
|
2,
|
|
{kernel_flags}>;
|
|
printf("%s %d %d %s %d %d\\n",
|
|
\"{kname}_nl\",
|
|
Kernel_traits::BYTES_PER_SMEM,
|
|
Kernel_traits::THREADS,
|
|
\"{fname}\",
|
|
{noloop_step},
|
|
{unroll_threshold});
|
|
}}'''.format(**tmp)
|
|
snippet = snippet_template.replace('__placeholder__', '')
|
|
            snippet_causal = snippet_template.replace(
                '__placeholder__', '_causal')
            snippet_sliding_or_chunked_causal = snippet_template.replace(
                '__placeholder__', '_sliding_or_chunked_causal')
|
|
snippet_nl = snippet_nl_template.replace('__placeholder__', '')
|
|
snippet_nl_causal = snippet_nl_template.replace(
|
|
'__placeholder__', '_causal')
|
|
snippet_nl_sliding_or_chunked_causal = snippet_nl_template.replace(
|
|
'__placeholder__', '_sliding_or_chunked_causal')
|
|
|
|
# only generate certain needed combinations of input_layout and mask types for trt-llm.
|
|
selected_types = selected_mask_types(kspec)
|
|
|
|
padding_mask = int(selected_types[0])
|
|
causal_mask = int(selected_types[1])
|
|
sliding_or_chunked_causal_mask = int(selected_types[2])
|
|
custom_mask = int(selected_types[3])
|
|
|
|
if not padding_mask:
|
|
snippet = None
|
|
snippet_nl = None
|
|
snippet_ws = None
|
|
snippet_flash_nl = None
|
|
snippet_flash_nl_tiled = None
|
|
if not causal_mask:
|
|
snippet_causal = None
|
|
snippet_nl_causal = None
|
|
snippet_ws_causal = None
|
|
snippet_flash_nl_causal = None
|
|
snippet_flash_nl_tiled_causal = None
|
|
if not sliding_or_chunked_causal_mask:
|
|
snippet_sliding_or_chunked_causal = None
|
|
snippet_nl_sliding_or_chunked_causal = None
|
|
snippet_ws_sliding_or_chunked_causal = None
|
|
snippet_flash_nl_sliding_or_chunked_causal = None
|
|
snippet_flash_nl_tiled_sliding_or_chunked_causal = None
|
|
if not custom_mask:
|
|
snippet_ws_custom_mask = None
|
|
snippet_flash_nl_custom_mask = None
|
|
snippet_flash_nl_tiled_custom_mask = None
|
|
|
|
if kspec.flash_attention:
|
|
pass
|
|
# print_kernel_specs.append(snippet_flash) # disabled as looped flash performs worse
|
|
else:
|
|
print_kernel_specs.append(snippet)
|
|
if 'snippet_causal' in locals():
|
|
print_kernel_specs.append(snippet_causal)
|
|
if 'snippet_sliding_or_chunked_causal' in locals():
|
|
print_kernel_specs.append(snippet_sliding_or_chunked_causal)
|
|
if kspec.has_noloop:
|
|
if kspec.flash_attention and kspec.tiled == 1:
|
|
print_kernel_specs.append(snippet_flash_nl_tiled)
|
|
print_kernel_specs.append(snippet_flash_nl_tiled_causal)
|
|
print_kernel_specs.append(
|
|
snippet_flash_nl_tiled_sliding_or_chunked_causal)
|
|
print_kernel_specs.append(snippet_flash_nl_tiled_custom_mask)
|
|
elif kspec.flash_attention and kspec.tiled == 0:
|
|
print_kernel_specs.append(snippet_flash_nl)
|
|
print_kernel_specs.append(snippet_flash_nl_causal)
|
|
print_kernel_specs.append(
|
|
snippet_flash_nl_sliding_or_chunked_causal)
|
|
print_kernel_specs.append(snippet_flash_nl_custom_mask)
|
|
else:
|
|
print_kernel_specs.append(snippet_nl)
|
|
if 'snippet_nl_causal' in locals():
|
|
print_kernel_specs.append(snippet_nl_causal)
|
|
if 'snippet_nl_sliding_or_chunked_causal' in locals():
|
|
print_kernel_specs.append(
|
|
snippet_nl_sliding_or_chunked_causal)
|
|
|
|
if kspec.warp_specialization:
|
|
print_kernel_specs.append(snippet_ws)
|
|
print_kernel_specs.append(snippet_ws_causal)
|
|
print_kernel_specs.append(snippet_ws_sliding_or_chunked_causal)
|
|
print_kernel_specs.append(snippet_ws_custom_mask)
|
|
# remove none.
|
|
print_kernel_specs = [
|
|
spec for spec in print_kernel_specs if spec is not None
|
|
]
|
|
print_kernel_specs = '\n'.join(print_kernel_specs)
|
|
|
|
code = ktraits_code_template.format(print_kernel_specs=print_kernel_specs)
|
|
return code
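# The program emitted by get_kernel_traits_code() prints one line per kernel variant in
# the form "<kernel name> <BYTES_PER_SMEM> <THREADS> <file name> <unroll step>
# <unroll threshold>"; get_cubin_header() below consumes these tuples as
# (kname, smem, threads, fname, unroll_step, unroll_threshold).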
|
|
|
|
|
|
# For now:
|
|
# 1. The Hopper head_size 128 kernels use cubins to avoid performance regressions.
# 2. sm89 (Ada) with the e4m3/e4m3_fp32 dtype uses cubins to avoid accuracy regressions (will be fixed).
# 3. The skip-softmax attention feature is forced not to use cubins.
# You should make `use_cubin_header` return false if you have modified the source code of the kernels that use cubins.
|
|
# This ensures that the kernels will be recompiled using the updated source code rather than relying on precompiled cubins.
|
|
def use_cubin_header(sm,
|
|
head_size,
|
|
dtype,
|
|
output_dtype=None,
|
|
enable_skip_softmax=False):
|
|
if enable_skip_softmax:
|
|
return False
|
|
if 'e4m3' in dtype and output_dtype in ['bf16', 'fp16']:
|
|
return False
|
|
return (sm == 90 and head_size == 128) or (sm == 89 and 'e4m3' in dtype)
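# Illustrative behaviour (values are examples, not an exhaustive list):
#   >>> use_cubin_header(90, 128, 'bf16')
#   True
#   >>> use_cubin_header(89, 64, 'e4m3')
#   True
#   >>> use_cubin_header(89, 64, 'e4m3', output_dtype='bf16')
#   False
#   >>> use_cubin_header(90, 64, 'fp16', enable_skip_softmax=True)
#   False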
|
|
|
|
|
|
def get_cubin_header(kernel_traits, specs_names):
|
|
cubins = []
|
|
cubin_lens = []
|
|
launchers = []
|
|
cubins_dict = {}
|
|
cubin_lens_dict = {}
|
|
launchers_dict = {}
|
|
for kspec, fname, lname, kname in specs_names:
|
|
if generate_cu_trtllm and not use_cubin_header(
|
|
kspec.sm, kspec.head_size, kspec.dtype, kspec.output_dtype,
|
|
kspec.enable_skip_softmax):
|
|
continue
|
|
name = fname.replace('.', '_')
|
|
data = 'extern unsigned char cubin_{name}_cubin[];'.format(name=name)
|
|
size = 'extern uint32_t cubin_{name}_cubin_len;'.format(name=name)
|
|
if kspec.sm in cubins_dict:
|
|
cubins_dict[kspec.sm].append(data)
|
|
cubin_lens_dict[kspec.sm].append(size)
|
|
else:
|
|
cubins_dict[kspec.sm] = [data]
|
|
cubin_lens_dict[kspec.sm] = [size]
|
|
|
|
metadata_v1 = []
|
|
# Only metadata_v2 is used by TRT-LLM.
|
|
metadata_v2 = []
|
|
metadata_v2_dict = {}
|
|
unroll_config_v1 = []
|
|
unroll_config_v2 = []
|
|
for kname, smem, threads, fname, unroll_step, unroll_threshold in kernel_traits:
|
|
name = fname.replace('.', '_')
|
|
cubin_name = 'cubin_{name}_cubin'.format(name=name)
|
|
kname_remove_causal = kname.replace("_causal", "")
|
|
tname = (kname.replace('flash_attention_', '').replace(
|
|
'_scale_max', '').replace('_nl', '').replace('_tiled', '').replace(
|
|
'tma_',
|
|
'').replace('ldgsts_', '').replace('causal_', '').replace(
|
|
'alibi_', '').replace('softmax_', '').replace(
|
|
'sliding_or_chunked_', '').replace(
|
|
'custom_mask_', '').replace('qkv_', '').replace(
|
|
'q_kv_', '').replace('q_paged_kv_', '').replace(
|
|
'q_k_v_', '').replace('ws_', '').replace(
|
|
'softcapping_',
|
|
'').replace('sage_', '').replace(
|
|
'skipSoftmax_',
|
|
'').replace('output_', ''))
|
|
flash_attention = 'flash_attention' in kname
|
|
warp_specialization = 'tma_ws' in kname
|
|
toks = tname.split('_')
|
|
# 0 1 x -7 -6 -5-4-3 -2 -1 x
|
|
#fmha_v2(_flash_attention)_fp16(_fp32)_64_64_S_16_sm80_kernel(_nl)
|
|
# 0 1 2 -5 -4 -3 -2 -1
|
|
#fmha_v2_il_fp16(_fp32)_64_64_sm80_kernel
|
|
# print(kname)
|
|
version = toks[1][1]
|
|
sm = toks[-2][2:]
|
|
if '_output' in kname:
|
|
output_prec = toks[-3].upper()
|
|
toks.pop(-3)
|
|
else:
|
|
output_prec = None
|
|
if '_sage_' in kname:
|
|
# example:
|
|
# kname: fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_128_bf16_tma_ws_sm90_kernel
|
|
# tname: fmha_v2_e4m3_64_256_S_128_sage_64_64_128_bf16_sm90_kernel
|
|
sage_block_sizes = toks[-5:-2]
|
|
toks.pop(-5)
|
|
toks.pop(-4)
|
|
toks.pop(-3)
|
|
else:
|
|
sage_block_sizes = (0, 0, 0)
|
|
head_size = toks[-3]
|
|
if 'x' in head_size:
|
|
(head_size, head_size_v) = head_size.split('x')
|
|
else:
|
|
head_size_v = head_size
|
|
        # flash attention kernels encode the variable seqlen as 'S'; use 0 here since only a number fits in the metadata struct.
|
|
seq_len = 0 if toks[-4] == 'S' else toks[-4]
|
|
q_step = unroll_step
|
|
kv_step = seq_len
|
|
if flash_attention:
|
|
kv_step = toks[-5]
|
|
q_step = toks[-6]
|
|
prec = toks[-5].upper()
|
|
is_fp32_accu = 'false'
|
|
if flash_attention:
|
|
prec = toks[-7].upper()
|
|
# fp16_fp32 --> HMMA with FP32 accumulation
|
|
if toks[-8].upper() in ['E4M3', 'E5M2', 'FP16', 'BF16']:
|
|
if prec == 'FP32':
|
|
is_fp32_accu = 'true'
|
|
prec = toks[-8].upper()
|
|
|
|
elif toks[-6].upper() in ['E4M3', 'E5M2', 'FP16', 'BF16', 'FP32']:
|
|
# in this case, toks[-6] = data type, toks[-5] = acc data type
|
|
prec = toks[-5].upper()
|
|
# fp16_fp32 --> HMMA with FP32 accumulation
|
|
if toks[-6].upper() in ['E4M3', 'E5M2', 'FP16', 'BF16']:
|
|
if prec == 'FP32':
|
|
is_fp32_accu = 'true'
|
|
prec = toks[-6].upper()
|
|
|
|
# fp8 or bf16 always accumulates on fp32
|
|
if prec in ['E4M3', 'E5M2', 'BF16']:
|
|
is_fp32_accu = 'true'
|
|
if output_prec is None:
|
|
output_prec = prec
|
|
|
|
is_il = pythonBoolean2cpp['_il' in kname]
|
|
attention_mask_type = AttentionMaskType.PADDING
|
|
is_tiled = pythonBoolean2cpp['_tiled' in kname]
|
|
|
|
# Attention mask type:
|
|
# padding (0), causal_mask (1), sliding_or_chunked_causal_mask (2), custom_mask (3).
|
|
if '_custom_mask' in kname:
|
|
attention_mask_type = AttentionMaskType.CUSTOM_MASK
|
|
elif '_sliding_or_chunked_causal' in kname:
|
|
attention_mask_type = AttentionMaskType.SLIDING_OR_CHUNKED_CAUSAL
|
|
elif '_causal' in kname:
|
|
attention_mask_type = AttentionMaskType.CAUSAL
|
|
|
|
attention_mask_type_value = attention_mask_type.value
|
|
|
|
# Attention input layout:
|
|
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2), separate_q_k_v (3).
|
|
attention_input_layout = InputLayout.PACKED_QKV
|
|
if '_q_kv' in kname:
|
|
attention_input_layout = InputLayout.CONTIGUOUS_Q_KV
|
|
elif '_q_paged_kv' in kname:
|
|
attention_input_layout = InputLayout.Q_PAGED_KV
|
|
elif '_q_k_v' in kname:
|
|
attention_input_layout = InputLayout.SEPARATE_Q_K_V
|
|
|
|
attention_input_layout_value = attention_input_layout.value
|
|
|
|
        # hopper warp-specialized kernels have specialized variants for cases without alibi.
|
|
is_alibi_supported = pythonBoolean2cpp['_ws' not in kname
|
|
or '_alibi' in kname]
|
|
|
|
return_softmax_stats_flag = pythonBoolean2cpp[sm != '90' or (
|
|
sm == '90' and '_softmax' in kname)]
|
|
|
|
enable_skip_softmax_flag = pythonBoolean2cpp['_skipSoftmax' in kname]
|
|
|
|
# meta_unroll_step
|
|
meta_unroll_step = unroll_step if ('_nl' in kname
|
|
or '_ws' in kname) else '0'
|
|
|
|
is_flash_atten = pythonBoolean2cpp[flash_attention]
|
|
|
|
is_warp_specialization = pythonBoolean2cpp[warp_specialization]
|
|
|
|
has_softcapping_scale = 'true' if 'softcapping' in kname else 'false'
|
|
|
|
unroll_spec = '''\
|
|
{{ kSM_{sm}, DATA_TYPE_{prec}, {seq_len}, {head_size}, {unroll_threshold} }}\
|
|
'''.format(**locals())
|
|
|
|
if 'v1' in kname:
|
|
code = '''\
|
|
{{ DATA_TYPE_{prec}, {seq_len}, {head_size}, kSM_{sm}, {cubin_name}, {cubin_name}_len, \"{kname}\", {smem}, {threads} }}\
|
|
'''.format(**locals())
|
|
metadata_v1.append(code)
|
|
if '_nl' in kname:
|
|
unroll_config_v1.append(unroll_spec)
|
|
elif 'v2' in kname:
|
|
if generate_cu_trtllm:
|
|
|
|
def get_lname_from_kname(kname: str) -> str:
|
|
if use_cubin_header(int(sm), int(head_size), prec.lower(),
|
|
output_prec.lower(),
|
|
                                    enable_skip_softmax_flag == 'true'):  # flag is a 'true'/'false' string
|
|
return 'nullptr'
|
|
lname = kname.replace('_kernel', '')
|
|
mask_types = [
|
|
'_sliding_or_chunked_causal', '_custom_mask', '_causal'
|
|
]
|
|
for mask_type in mask_types:
|
|
lname = lname.replace(mask_type, '')
|
|
lname = 'run_' + lname
|
|
|
|
return lname
|
|
|
|
lname = get_lname_from_kname(kname)
|
|
code = '''\
|
|
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
|
|
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
|
|
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
|
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
|
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
|
|
'''.format(**locals()) if use_cubin_header(int(sm), int(head_size),
|
|
prec.lower(), output_prec.lower(),
|
|
                               enable_skip_softmax_flag == 'true') else '''\
|
|
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
|
|
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
|
|
0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
|
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
|
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}, {lname}}}\
|
|
'''.format(**locals())
|
|
else:
|
|
code = '''\
|
|
{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
|
|
{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, {cubin_name}, \
|
|
{cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
|
|
{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
|
|
{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {enable_skip_softmax_flag}}}\
|
|
'''.format(**locals())
|
|
if sm in metadata_v2_dict:
|
|
metadata_v2_dict[sm].append(code)
|
|
else:
|
|
metadata_v2_dict[sm] = [code]
|
|
if '_nl' in kname:
|
|
unroll_config_v2.append(unroll_spec)
|
|
if generate_cu_trtllm and lname != 'nullptr':
|
|
launcher = 'extern void {lname}(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);'.format(
|
|
lname=lname)
|
|
if int(sm) in launchers_dict:
|
|
if launcher not in launchers_dict[int(sm)]:
|
|
launchers_dict[int(sm)].append(launcher)
|
|
else:
|
|
launchers_dict[int(sm)] = [launcher]
|
|
elif 'mhca' in kname:
|
|
code = '''\
|
|
{{ DATA_TYPE_{prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, kSM_{sm}, {cubin_name}, {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {is_il} }}\
|
|
'''.format(**locals())
|
|
metadata_v2.append(code)
|
|
else:
|
|
assert False, 'Something terrible happened'
|
|
|
|
metadata_v1 = ',\n'.join(metadata_v1)
|
|
# Add macros to only include needed cubins during compilation.
|
|
if bool(metadata_v2_dict):
|
|
metadata_v2 = ''
|
|
for sm in metadata_v2_dict.keys():
|
|
macro_begin = f"#ifndef EXCLUDE_SM_{sm}"
|
|
macro_end = f"#endif\n\n"
|
|
metadata_v2 += macro_begin + '\n' + (',\n'.join(
|
|
metadata_v2_dict[sm]))
|
|
last_key = list(metadata_v2_dict.keys())[-1]
|
|
metadata_v2 += ('' if sm == last_key else ',') + '\n' + macro_end
|
|
else:
|
|
metadata_v2 = ',\n'.join(metadata_v2)
|
|
# Add macros to only include needed cubins during compilation.
|
|
# Collect all SM versions from all dictionaries
|
|
all_sms = sorted(
|
|
set(
|
|
list(cubins_dict.keys()) + list(cubin_lens_dict.keys()) +
|
|
list(launchers_dict.keys())))
|
|
|
|
for sm in all_sms:
|
|
macro_begin = f"#ifndef EXCLUDE_SM_{sm}"
|
|
macro_end = f"#endif\n"
|
|
|
|
# Add cubin array declarations
|
|
if sm in cubins_dict:
|
|
cubins.extend([macro_begin] + cubins_dict[sm] + [macro_end])
|
|
|
|
# Add cubin length declarations
|
|
if sm in cubin_lens_dict:
|
|
cubin_lens.extend([macro_begin] + cubin_lens_dict[sm] + [macro_end])
|
|
|
|
# Add launcher declarations
|
|
if sm in launchers_dict:
|
|
launchers.extend([macro_begin] + launchers_dict[sm] + [macro_end])
|
|
|
|
unroll_config_v1 = ',\n'.join(unroll_config_v1)
|
|
unroll_config_v2 = ',\n'.join(unroll_config_v2)
|
|
cubins = '\n'.join(cubins)
|
|
cubin_lens = '\n'.join(cubin_lens)
|
|
launchers = '\n'.join(launchers)
|
|
local_ns_open = ns_open
|
|
local_ns_close = ns_close if generate_cu_trtllm else '}'
|
|
launcher_line = '''
|
|
void (*launcher)(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);''' if generate_cu_trtllm else ''
|
|
if "GENERATE_CUBIN" in os.environ:
|
|
code = '''\
|
|
{copyright}
|
|
#pragma once
|
|
|
|
{local_ns_open}
|
|
|
|
{cubins}
|
|
|
|
{cubin_lens}
|
|
|
|
static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
|
{{
|
|
Data_type mDataTypeIn;
|
|
Data_type mDataTypeOut;
|
|
unsigned int mS;
|
|
unsigned int mStepQ;
|
|
unsigned int mStepKV;
|
|
unsigned int mD;
|
|
unsigned int mDV;
|
|
unsigned int mSageBlockSizeQ;
|
|
unsigned int mSageBlockSizeK;
|
|
unsigned int mSageBlockSizeV;
|
|
unsigned int mSM;
|
|
const unsigned char* mCubin;
|
|
unsigned int mCubinSize;
|
|
const char* mFuncName;
|
|
unsigned int mSharedMemBytes;
|
|
unsigned int mThreadsPerCTA;
|
|
unsigned int mUnrollStep;
|
|
int mAttentionMaskType;
|
|
int mAttentionInputLayout;
|
|
bool mInterleaved;
|
|
bool mFlashAttention;
|
|
bool mWarpSpecialization;
|
|
bool mFP32Accumulation;
|
|
bool mAlibiSupported;
|
|
bool mTiled;
|
|
bool mEnableAttnLogitSoftcapping;
|
|
bool mReturnSoftmaxStats;
|
|
bool mEnableSkipSoftmax;{launcher_line}
|
|
}} sMhaKernelMetaInfosV2[] = {{
|
|
{metadata_v2}
|
|
}};
|
|
{local_ns_close}
|
|
|
|
'''.format(**locals(), copyright=copyright)
|
|
|
|
else:
|
|
code = '''\
|
|
{copyright}
|
|
#pragma once
|
|
|
|
{cubins}
|
|
|
|
{cubin_lens}
|
|
|
|
static const struct TestMetaV1
|
|
{{
|
|
Data_type mDataType;
|
|
unsigned int mS;
|
|
unsigned int mD;
|
|
unsigned int mSM;
|
|
const unsigned char* mCubin;
|
|
unsigned int mCubinSize;
|
|
const char* mFuncName;
|
|
unsigned int mSharedMemBytes;
|
|
unsigned int mThreadsPerCTA;
|
|
}} metaV1[] = {{
|
|
{metadata_v1}
|
|
}};
|
|
|
|
static const struct TestMetaV2
|
|
{{
|
|
Data_type mDataTypeIn;
|
|
Data_type mDataTypeOut;
|
|
unsigned int mS;
|
|
unsigned int mStepQ;
|
|
unsigned int mStepKV;
|
|
unsigned int mD;
|
|
unsigned int mDV;
|
|
unsigned int mSageBlockSizeQ;
|
|
unsigned int mSageBlockSizeK;
|
|
unsigned int mSageBlockSizeV;
|
|
unsigned int mSM;
|
|
const unsigned char* mCubin;
|
|
unsigned int mCubinSize;
|
|
const char* mFuncName;
|
|
unsigned int mSharedMemBytes;
|
|
unsigned int mThreadsPerCTA;
|
|
unsigned int mUnrollStep;
|
|
int mAttentionMaskType;
|
|
int mAttentionInputLayout;
|
|
bool mInterleaved;
|
|
bool mFlashAttention;
|
|
bool mWarpSpecialization;
|
|
bool mFP32Accumulation;
|
|
bool mAlibiSupported;
|
|
bool mTiled;
|
|
bool mEnableAttnLogitSoftcapping;
|
|
bool mReturnSoftmaxStats;
|
|
bool mEnableSkipSoftmax;
|
|
}} metaV2[] = {{
|
|
{metadata_v2}
|
|
}};
|
|
}}
|
|
|
|
'''.format(**locals(), copyright=copyright)
|
|
|
|
# Generate header content (.h file)
|
|
if "GENERATE_CUBIN" in os.environ:
|
|
header_content = '''\
|
|
{copyright}
|
|
#pragma once
|
|
|
|
#include "tensorrt_llm/common/config.h"
|
|
|
|
TRTLLM_NAMESPACE_BEGIN
|
|
namespace kernels{{
|
|
|
|
struct FusedMultiHeadAttentionKernelMetaInfoV2
|
|
{{
|
|
Data_type mDataTypeIn;
|
|
Data_type mDataTypeOut;
|
|
unsigned int mS;
|
|
unsigned int mStepQ;
|
|
unsigned int mStepKV;
|
|
unsigned int mD;
|
|
unsigned int mDV;
|
|
unsigned int mSageBlockSizeQ;
|
|
unsigned int mSageBlockSizeK;
|
|
unsigned int mSageBlockSizeV;
|
|
unsigned int mSM;
|
|
const unsigned char* mCubin;
|
|
unsigned int mCubinSize;
|
|
const char* mFuncName;
|
|
unsigned int mSharedMemBytes;
|
|
unsigned int mThreadsPerCTA;
|
|
unsigned int mUnrollStep;
|
|
int mAttentionMaskType;
|
|
int mAttentionInputLayout;
|
|
bool mInterleaved;
|
|
bool mFlashAttention;
|
|
bool mWarpSpecialization;
|
|
bool mFP32Accumulation;
|
|
bool mAlibiSupported;
|
|
bool mTiled;
|
|
bool mEnableAttnLogitSoftcapping;
|
|
bool mReturnSoftmaxStats;
|
|
bool mEnableSkipSoftmax;{launcher_line}
|
|
}};
|
|
|
|
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[];
|
|
extern const int sMhaKernelMetaInfosV2Size;
|
|
|
|
}} // namespace kernels
|
|
TRTLLM_NAMESPACE_END
|
|
'''.format(**locals(), copyright=copyright)
|
|
# Generate source content (.cpp file)
|
|
source_content = '''\
|
|
{copyright}
|
|
|
|
#include "tensorrt_llm/common/config.h"
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cuda_runtime_api.h>
|
|
|
|
{local_ns_open}
|
|
|
|
//--- Cubin Arrays
|
|
{cubins}
|
|
|
|
//--- Cubin Lengths
|
|
{cubin_lens}
|
|
|
|
{local_ns_close}
|
|
|
|
using namespace tensorrt_llm::kernels;
|
|
|
|
namespace tensorrt_llm::TRTLLM_ABI_NAMESPACE::kernels {{
|
|
|
|
class Fused_multihead_attention_params_v2;
|
|
class Launch_params;
|
|
|
|
//--- Kernel Launchers
|
|
{launchers}
|
|
|
|
// FIXME: These are duplicated declarations, we should remove them in the future.
|
|
constexpr int32_t kSM_70 = 70;
|
|
constexpr int32_t kSM_72 = 72;
|
|
constexpr int32_t kSM_75 = 75;
|
|
constexpr int32_t kSM_80 = 80;
|
|
constexpr int32_t kSM_86 = 86;
|
|
constexpr int32_t kSM_89 = 89;
|
|
constexpr int32_t kSM_90 = 90;
|
|
constexpr int32_t kSM_100 = 100;
|
|
constexpr int32_t kSM_100f = 10100;
|
|
constexpr int32_t kSM_103 = 103;
|
|
constexpr int32_t kSM_120 = 120;
|
|
constexpr int32_t kSM_121 = 121;
|
|
|
|
// FIXME: These are duplicated declarations, we should remove them in the future.
|
|
enum Data_type
|
|
{{
|
|
DATA_TYPE_BOOL,
|
|
DATA_TYPE_FP16,
|
|
DATA_TYPE_FP32,
|
|
DATA_TYPE_INT4,
|
|
DATA_TYPE_INT8,
|
|
DATA_TYPE_INT32,
|
|
DATA_TYPE_BF16,
|
|
DATA_TYPE_E2M1,
|
|
DATA_TYPE_E4M3,
|
|
DATA_TYPE_E5M2
|
|
}};
|
|
|
|
struct FusedMultiHeadAttentionKernelMetaInfoV2
|
|
{{
|
|
Data_type mDataTypeIn;
|
|
Data_type mDataTypeOut;
|
|
unsigned int mS;
|
|
unsigned int mStepQ;
|
|
unsigned int mStepKV;
|
|
unsigned int mD;
|
|
unsigned int mDV;
|
|
unsigned int mSageBlockSizeQ;
|
|
unsigned int mSageBlockSizeK;
|
|
unsigned int mSageBlockSizeV;
|
|
unsigned int mSM;
|
|
const unsigned char* mCubin;
|
|
unsigned int mCubinSize;
|
|
const char* mFuncName;
|
|
unsigned int mSharedMemBytes;
|
|
unsigned int mThreadsPerCTA;
|
|
unsigned int mUnrollStep;
|
|
int mAttentionMaskType;
|
|
int mAttentionInputLayout;
|
|
bool mInterleaved;
|
|
bool mFlashAttention;
|
|
bool mWarpSpecialization;
|
|
bool mFP32Accumulation;
|
|
bool mAlibiSupported;
|
|
bool mTiled;
|
|
bool mEnableAttnLogitSoftcapping;
|
|
bool mReturnSoftmaxStats;
|
|
bool mEnableSkipSoftmax;{launcher_line}
|
|
}};
|
|
|
|
extern const FusedMultiHeadAttentionKernelMetaInfoV2 sMhaKernelMetaInfosV2[] = {{
|
|
{metadata_v2}
|
|
}};
|
|
|
|
extern const int sMhaKernelMetaInfosV2Size = sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]);
|
|
}} // namespace tensorrt_llm::TRTLLM_ABI_NAMESPACE::kernels
|
|
'''.format(**locals(), copyright=copyright)
|
|
else:
|
|
# Non-GENERATE_CUBIN mode: use old behavior
|
|
header_content = code
|
|
source_content = None
|
|
|
|
return header_content, source_content
|
|
|
|
|
|
# Patch the generated source so that a few kernels are run from cubins (needed to pass CI cases).
def modify_cubin_header(cubin_header):
|
|
result = cubin_header
|
|
|
|
# for CI cases
|
|
def add_kernel_line(result, target, addition):
|
|
pos = result.find(target)
|
|
if pos != -1:
|
|
end_pos = result.find('\n', pos)
|
|
if end_pos == -1:
|
|
end_pos = len(result)
|
|
result = result[:end_pos + 1] + addition + result[end_pos:]
|
|
return result
|
|
|
|
target = "#ifndef EXCLUDE_SM_80"
|
|
addition_cubin_array = """
|
|
#ifndef EXCLUDE_SM_80
|
|
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin[];
|
|
#endif
|
|
"""
|
|
addition_cubin_length = """
|
|
#ifndef EXCLUDE_SM_80
|
|
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len;
|
|
#endif
|
|
"""
|
|
# Add the cubin array and length declarations into their corresponding sections.
result = add_kernel_line(result, "//--- Cubin Arrays", addition_cubin_array)
|
|
result = add_kernel_line(result, "//--- Cubin Lengths",
|
|
addition_cubin_length)
|
|
|
|
def modify_kernel_line(result, target, new_line):
|
|
lines = result.split('\n')
|
|
for i, line in enumerate(lines):
|
|
if target in line:
|
|
lines[i] = new_line
|
|
break
|
|
return '\n'.join(lines)
|
|
|
|
target = "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled"
|
|
new_line = '{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, false, true, true, false, true, false, nullptr},'
|
|
result = modify_kernel_line(result, target, new_line)
|
|
|
|
# Make sure there is exactly one empty line at the end.
lines = result.split('\n')
|
|
while lines and not lines[-1].strip():
|
|
lines.pop()
|
|
lines.append('')
|
|
|
|
return '\n'.join(lines)
|
|
|
|
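# A minimal, self-contained sketch (not used by the generator) of the text-splicing
# idiom used by modify_cubin_header() above: locate a marker line in the generated
# source and insert an extra declaration right after it. The function name below is
# hypothetical; the real logic lives in the nested add_kernel_line() helper.
def _splice_after_marker(text, marker, addition):
    pos = text.find(marker)
    if pos == -1:
        return text
    end_pos = text.find('\n', pos)
    if end_pos == -1:
        end_pos = len(text)
    return text[:end_pos + 1] + addition + text[end_pos + 1:]


# e.g. _splice_after_marker(source, '//--- Cubin Arrays',
#                           'extern unsigned char some_cubin[];\n')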
|
|
def generate_files(specs_names):
|
|
|
|
kfiles = []
|
|
valid_specs_names = []
|
|
|
|
for kspec, fname, lname, kname in specs_names:
|
|
code = get_kernel_code(kspec, kname, lname)
|
|
# some kernels are skipped when generating cubins for trt-llm.
|
|
if code is None:
|
|
continue
|
|
# add valid specs names
|
|
valid_specs_names.append((kspec, fname, lname, kname))
|
|
path = os.path.join('./generated', fname)
|
|
# HACK: do not overwrite kernel file in case of collision; kernel selection logic can still be flaky
|
|
# TODO: allow profiling multiple kernel implementations satisfying the given problem size
|
|
if path not in kfiles:
|
|
with open(path, 'w') as f:
|
|
f.write(code)
|
|
kfiles.append(path)
|
|
|
|
api_code = get_api_code(valid_specs_names).replace(
|
|
'__guard_fmhca_placeholder__', 'false')
|
|
with open('./generated/fused_multihead_attention_api.h', 'w') as f:
|
|
f.write(api_code)
|
|
|
|
api_code = get_api_code(valid_specs_names).replace(
|
|
'__guard_fmhca_placeholder__', 'true')
|
|
with open('./generated/fused_multihead_cross_attention_api.h', 'w') as f:
|
|
f.write(api_code)
|
|
|
|
mk_code = get_makefile_code(valid_specs_names)
|
|
|
|
with open('./generated/makefile', 'w') as f:
|
|
f.write(mk_code)
|
|
|
|
print_kernel_traits_code = get_kernel_traits_code(valid_specs_names)
|
|
with open('./generated/print_kernel_traits.cu', 'w') as f:
|
|
f.write(print_kernel_traits_code)
|
|
|
|
# Make sure we have a bin directory.
|
|
if not os.path.exists('bin'):
|
|
os.mkdir('bin')
|
|
cmd = 'nvcc -I src -Xcompiler -Wno-enum-compare --std=c++17 -o bin/print_traits.exe generated/print_kernel_traits.cu'.split(
|
|
)
|
|
if 'CUDA_PATH' in os.environ:
|
|
cmd[0] = os.environ['CUDA_PATH'] + '/bin/' + cmd[0]
|
|
print('Running command "{}" to build "bin/print_traits.exe":'.format(
|
|
' '.join(cmd)))
|
|
process = subprocess.Popen(cmd,
|
|
stdin=subprocess.PIPE,
|
|
stdout=subprocess.PIPE)
|
|
output, error = process.communicate()
|
|
print('Running "bin/print_traits.exe":')
|
|
process = subprocess.Popen('bin/print_traits.exe',
|
|
stdin=subprocess.PIPE,
|
|
stdout=subprocess.PIPE)
|
|
output, error = process.communicate()
|
|
output = output.decode('utf-8').strip()
|
|
# this gives: kname, smem bytes, threads_per_cta, loop_step
|
|
kernel_traits = [traits.split() for traits in output.splitlines()]
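# Illustrative only: one parsed row could look like
#   ['fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_causal_sm80_kernel_nl_tiled',
#    '81920', '128', '64']
# (kernel name, smem bytes, threads per CTA, loop step; the numbers are taken from the
# hard-coded SM80 entry in modify_cubin_header(), not from a fresh measurement).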
|
|
# get_cubin_header generates the contents of both fmha_cubin.h and fmha_cubin.cpp.
# When GENERATE_CUBIN is not set, the source content is None and only the header is written.
cubin_header, cubin_source = get_cubin_header(kernel_traits,
|
|
valid_specs_names)
|
|
if generate_cu_trtllm:
|
|
cubin_source = modify_cubin_header(cubin_source)
|
|
|
|
# Write fmha_cubin.h file
|
|
with open('./generated/fmha_cubin.h', 'w') as f:
|
|
f.write(cubin_header)
|
|
|
|
# Write fmha_cubin.cpp file (same directory as fmha_cubin.h file)
|
|
if cubin_source is not None:
|
|
with open('./generated/fmha_cubin.cpp', 'w') as f:
|
|
f.write(cubin_source)
|
|
|
|
|
|
def enumerate_hgmma_tma_kernels(specs, sm=90):
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype='fp16',
|
|
seq_len=[64, 128, 256],
|
|
head_size=64,
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
|
|
# Note this will be used in TRT-LLM.
|
|
def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):
|
|
|
|
for enable_attn_logit_softcapping in [False, True]:
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=[64, 128, 256],
|
|
head_size=[32, 64],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
True, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping))
|
|
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=[384, 512],
|
|
head_size=[32, 64],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
True, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping))
|
|
|
|
|
|
# Note this will be used in TRT-LLM.
|
|
def enumerate_hgmma_flash_warpspec_kernels(specs,
|
|
sm=90,
|
|
dtype='fp16',
|
|
enable_skip_softmax=False):
|
|
|
|
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
|
|
|
|
# Use specialized kernels for the cases without alibi scales:
# there is a numeric issue when applying the exp2f scale optimization and the alibi scale at the same time.
combinations = product([False, True], [False, True], \
|
|
[InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
|
|
InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V], [False, True])
|
|
for (alibi, return_softmax, input_layout,
|
|
enable_attn_logit_softcapping) in combinations:
|
|
# alibi and enable_attn_logit_softcapping shouldn't be used together.
|
|
if alibi and enable_attn_logit_softcapping:
|
|
continue
|
|
# For normal attention, we only need the contiguous Q + KV input layout when returning softmax stats.
skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
|
|
# For context MLA, we need the separate Q/K/V input layout when returning softmax stats.
skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
|
|
if not skip_combination:
|
|
# only specify
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=[32, 40, 48, 64],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=256,
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
enable_skip_softmax=enable_skip_softmax))
|
|
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=[72, 80, 96, 104, 128],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=128,
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
enable_skip_softmax=enable_skip_softmax))
|
|
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=[160, 192, 256],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=64,
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
enable_skip_softmax=enable_skip_softmax))
|
|
# smem size = (q_step * d * q_buffers * NUM_COMPUTE_GROUPS
#              + (kv_step * d + kv_step * dv) * kv_buffers) * ele_size
# Originally, the head sizes were padded to next_power_of_2<d> and next_power_of_2<dv>.
# For fp16/bf16 context MLA (d=192 / dv=128), d was padded to 256 while dv stayed at 128:
#   kv_step=64  -> smem_size = 160 KB, which fits but wastes a lot of smem;
#   kv_step=128 -> smem_size = 256 KB, which is too big for Hopper (228 KB smem per SM).
# In fact, only padding to the next multiple of 128 bytes is required, because of the
# TMA 128B swizzle mode. So for fp16/bf16 context MLA, d stays at 192
# (192 elements * 2 bytes = 3 * 128 bytes) and dv stays at 128:
#   kv_step=128 -> smem_size = 208 KB, and smem is fully utilized.
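# Worked check of the numbers quoted above (assuming q_step=64, q_buffers=1,
# kv_buffers=2, NUM_COMPUTE_GROUPS=2 and ele_size=2 bytes for fp16/bf16):
#   d=256, dv=128, kv_step=64 : (64*256*1*2 + (64*256 + 64*128)*2) * 2 = 163840 B = 160 KB
#   d=256, dv=128, kv_step=128: (64*256*1*2 + (128*256 + 128*128)*2) * 2 = 262144 B = 256 KB
#   d=192, dv=128, kv_step=128: (64*192*1*2 + (128*192 + 128*128)*2) * 2 = 212992 B = 208 KB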
|
|
if not skip_mla_combination:
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=192,
|
|
head_size_v=128,
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=128,
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout))
|
|
|
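# Illustrative only: a minimal sketch of how the enumeration helpers above can be
# combined into a spec list. The driver function name below is hypothetical; the
# actual entry point of this generator lives elsewhere in this file.
def _example_sm90_warpspec_specs():
    specs = []
    # LDGSTS-based fixed-sequence-length kernels.
    enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16')
    # Warp-specialized flash attention kernels for any sequence length.
    enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16')
    return specs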
|
|
|
# Note this will be used in TRT-LLM.
|
|
def enumerate_qgmma_flash_warpspec_kernels(specs,
|
|
sm=90,
|
|
dtype='e4m3',
|
|
sage_block_sizes=None,
|
|
output_dtype=None,
|
|
enable_skip_softmax=False):
|
|
|
|
scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
|
|
|
|
# Use specialized kernels for the cases without alibi scales:
# there is a numeric issue when applying the exp2f scale optimization and the alibi scale at the same time.
combinations = product([False, True], \
|
|
[InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
|
|
InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V],
|
|
[False, True], [False, True])
|
|
for (alibi, input_layout, enable_attn_logit_softcapping,
|
|
return_softmax) in combinations:
|
|
# alibi and enable_attn_logit_softcapping (the bmm1 tanh softcapping scale) shouldn't be used together.
if alibi and enable_attn_logit_softcapping:
|
|
continue
|
|
# For normal attention, we currently do not need to return softmax stats for warp-specialized fp8 kernels.
# Also, fp8 input with bf16 output is only needed for the MLA kernel.
skip_combination = return_softmax
|
|
# For context MLA, we need the separate Q/K/V input layout when returning softmax stats.
skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
|
|
if not skip_combination:
|
|
# D <= 64: KV_STEP = 256
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=[32, 40, 48, 64],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=256,
|
|
kv_tile_buffers=4, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
sage_block_sizes=sage_block_sizes,
|
|
output_dtype=output_dtype,
|
|
enable_skip_softmax=enable_skip_softmax))
|
|
|
|
# 64 < D <=128: KV_STEP = 128
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=[80, 96, 104, 128],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=256,
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
sage_block_sizes=sage_block_sizes,
|
|
output_dtype=output_dtype,
|
|
enable_skip_softmax=enable_skip_softmax))
|
|
|
|
# 128 < D <=256: KV_STEP = 128
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=[160, 192, 256],
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=
|
|
128, # use 128 kv step size to avoid register spilling
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
sage_block_sizes=sage_block_sizes,
|
|
output_dtype=output_dtype,
|
|
enable_skip_softmax=enable_skip_softmax))
|
|
|
|
if not skip_mla_combination:
|
|
# context MLA (192x128)
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype=dtype,
|
|
seq_len=0, # support any sequence length
|
|
head_size=192,
|
|
head_size_v=128,
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
False, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
q_tile_buffers=1, # only used by warp specialized kernels
|
|
has_noloop=0,
|
|
noloop_step=64,
|
|
kv_loop_step=128,
|
|
kv_tile_buffers=2, # only used by warp specialized kernels
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
warp_specialization=True,
|
|
alibi=alibi,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
return_softmax_stats=return_softmax,
|
|
scheduling_mode=scheduling_mode,
|
|
input_layout=input_layout,
|
|
sage_block_sizes=sage_block_sizes,
|
|
output_dtype=output_dtype))
|
|
|
|
|
|
def enumerate_igmma_kernels(specs, sm=90):
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype='int8',
|
|
seq_len=[64, 128, 256, 384],
|
|
head_size=64,
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
True, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype='int8',
|
|
seq_len=[512],
|
|
head_size=64,
|
|
warps_m=4, #4x2 warpgroups
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
True, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
|
|
def enumerate_hmma_kernels(specs, sm=80, dtype='fp16'):
|
|
# The following kernels are hmma-based kernels tuned for sm90
|
|
if sm == 90:
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=[64, 128, 256],
|
|
head_size=[64, 72],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=[384, 512],
|
|
head_size=[64, 72],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=86,
|
|
dtype=dtype,
|
|
seq_len=384,
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=86,
|
|
dtype='fp16',
|
|
seq_len=384,
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# S=1024 split over 4 CTAs.
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=256,
|
|
head_size=[32, 64],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=0,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
ctas_per_head=4))
|
|
|
|
#- S=512: STEP=32, STEP NL=-- FLAGS=0x9 (0x9 for SM86!)
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=512,
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=512,
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=512,
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
#- S=384: STEP=48, STEP NL=-- FLAGS=0x9 (0x9 for SM86!)
|
|
# TODO warps_n=4 leads to 2 pred regs, which is not supported
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=384,
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=48,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=384,
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=48,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
#- S=256: STEP=32, STEP NL=32 FLAGS=0x1
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=256,
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=256,
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# #- S=128: STEP=NA, STEP NL=32 FLAGS=0x1
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=128,
|
|
head_size=[16, 32, 64],
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=128,
|
|
head_size=[16, 32, 64],
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
#- S=96: STEP=32, STEP NL=-- FLAGS=0x1 TODO noloop does not work - illegal memory access: we run LDSM.T x4 which is oob.
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=96,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=96,
|
|
has_noloop=0,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=96,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=96,
|
|
has_noloop=0,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
#- S=64: STEP=32, STEP NL=-- FLAGS=0x1
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='fp16',
|
|
seq_len=64,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype=dtype,
|
|
seq_len=64,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
if sm == 75:
|
|
#- FP16
|
|
#- S=512: STEP=32, STEP NL=-- FLAGS=0x9 (0x9 for SM86!)
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=[384, 512],
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=[384, 512],
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=256,
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=256,
|
|
head_size=[16, 32, 64],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=128,
|
|
head_size=[16, 32, 64],
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=128,
|
|
head_size=[16, 32, 64],
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=75,
|
|
dtype='fp16',
|
|
seq_len=64,
|
|
head_size=[16, 32, 64],
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
#- S=384: STEP=32, STEP NL=32 FLAGS=0x8
|
|
#- S=256: STEP=32, STEP NL=32 FLAGS=0x8
|
|
#- S=128: STEP=32, STEP NL=32 FLAGS=0x8
|
|
#- S=128: STEP=NA, STEP NL=32 FLAGS=0x8
|
|
#- S=96: STEP=32, STEP NL=-- FLAGS=0x8
|
|
#- S=64: STEP=32, STEP NL=-- FLAGS=0x8
|
|
|
|
#SM 72
|
|
#- Int8 (same for interleaved)
|
|
#- S=384: STEP=32, STEP NL=-- FLAGS=0x0
|
|
#- S=256: STEP=64, STEP NL=-- FLAGS=0x0
|
|
#- S=192: STEP=64, STEP NL=-- FLAGS=0x0
|
|
#- S=128: STEP=NA, STEP NL=-- FLAGS=0x8
|
|
#- S=96
|
|
#- S=64
|
|
|
|
|
|
def enumerate_hmma884_kernels(specs, sm=70):
|
|
#- FP16
|
|
#- S=512: STEP=32, STEP NL=-- FLAGS=0x9 (0x9 for SM86!)
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=[384, 512],
|
|
head_size=[64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=16,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=[384, 512],
|
|
head_size=[64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=16,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=[384, 512],
|
|
head_size=[32],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=[384, 512],
|
|
head_size=[32],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
#- S=256: STEP=32, STEP NL=32 FLAGS=0x8
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=[128, 256],
|
|
head_size=[32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=[128, 256],
|
|
head_size=[32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQLEN 96
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=96,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQLEN 64
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=64,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQLEN 32
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=70,
|
|
dtype='fp16',
|
|
seq_len=32,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=False,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
|
|
def enumerate_hmma_paged_kv_flash_kernels(specs, sm=80, dtype='fp16'):
|
|
for enable_attn_logit_softcapping in [False, True]:
|
|
enumerate_hmma_flash_kernels_base(specs, sm, dtype,
|
|
InputLayout.PACKED_QKV,
|
|
enable_attn_logit_softcapping)
|
|
|
|
|
|
def enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16', head_size_v=0):
|
|
input_layouts = [
|
|
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
|
|
InputLayout.Q_PAGED_KV
|
|
]
|
|
# Deepseek MLA (context 192/128 separate-q-k-v)
|
|
if head_size_v == 128:
|
|
input_layouts.append(InputLayout.SEPARATE_Q_K_V)
|
|
for (input_layout,
|
|
enable_attn_logit_softcapping) in product(input_layouts,
|
|
[False, True]):
|
|
enumerate_hmma_flash_kernels_base(specs, sm, dtype, input_layout,
|
|
enable_attn_logit_softcapping,
|
|
head_size_v)
|
|
|
|
|
|
# Note this will be used in TRT-LLM.
|
|
def enumerate_hmma_flash_kernels_base(specs,
|
|
sm=80,
|
|
dtype='fp16',
|
|
input_layout=InputLayout.PACKED_QKV,
|
|
enable_attn_logit_softcapping=False,
|
|
head_size_v=0):
|
|
#- FP16 Flash Attention (use nl as default)
|
|
# Any Sequence Length H = 16/32/40/48/64/80/128/160/256/512 flash attention
|
|
|
|
# Note: sm70, sm72 are based on hmma8x8x4, while sm75+ is based on hmma16x8x16
|
|
# sm75 and sm80+ use the same underlying trait class; but for historical reasons we prefer not
|
|
# to change the appearance of the trait class. So:
|
|
# - Volta uses Volta_hmma_fp16_traits
|
|
# - Turing uses Turing_hmma_fp16_traits
|
|
# - Ampere uses Ampere_hmma_fp16_traits but is effectively an alias of Turing_hmma_fp16_traits
|
|
# - Ada and Hopper use Ampere_hmma_fp16_traits
|
|
sm_mma = 0
|
|
if sm in [70, 72]:
|
|
sm_mma = 70
|
|
elif sm in [75]:
|
|
sm_mma = 75
|
|
elif sm in [80, 86, 87, 89, 90, 100, 120]:
|
|
sm_mma = 80
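# e.g. sm=89 (Ada) takes the last branch and reuses the Ampere (sm_mma=80) HMMA
# traits, consistent with the note above.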
|
|
|
|
# _nl_tiled kernels; higher precedence than _nl kernels
|
|
# params[head_size] = [q_step, kv_step]
|
|
tiled_params_q_kv_step = {
|
|
16: [128, 128],
|
|
32: [128, 128],
|
|
40: [128, 128],
|
|
48: [128, 128],
|
|
64: [128, 128],
|
|
72: [64, 128],
|
|
80: [64, 128],
|
|
96: [64, 128],
|
|
104: [64, 128],
|
|
128: [64, 128],
|
|
160: [64, 128],
|
|
192: [64, 128],
|
|
256: [64, 128],
|
|
512: [64, 64],
|
|
576: [64, 64]
|
|
}
|
|
for head_size, [q_loop_step,
|
|
kv_loop_step] in tiled_params_q_kv_step.items():
|
|
if sm_mma == 80:
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=sm_mma,
|
|
dtype=dtype,
|
|
flash_attention=True,
|
|
tiled=1,
|
|
seq_len=0, # means any sequence here
|
|
kv_loop_step=kv_loop_step,
|
|
limit_qk_fragments=False,
|
|
limit_v_fragments=False,
|
|
head_size=head_size,
|
|
head_size_v=head_size_v,
|
|
warps_m=4,
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=q_loop_step,
|
|
has_noloop=1,
|
|
noloop_step=q_loop_step,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
ctas_per_head=1,
|
|
input_layout=input_layout,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
|
|
is_mtp=(head_size == 576 and head_size_v == 512)))
|
|
|
|
for head_size in [
|
|
16, 32, 40, 48, 64, 72, 80, 96, 104, 128, 160, 192, 256, 512
|
|
]:
|
|
if sm == 70 and (head_size > 256 or head_size == 16):
|
|
continue
|
|
# TODO: test head_size=512 on sm75
|
|
if sm == 75 and head_size > 256:
|
|
continue
|
|
|
|
# tune ldgsts
|
|
ldgsts_q = True
|
|
ldgsts_k = True
|
|
ldgsts_v = True
|
|
if head_size >= 256:
|
|
ldgsts_k = False
|
|
ldgsts_v = False
|
|
if head_size > 256:
|
|
ldgsts_q = False
|
|
if sm < 80:
|
|
ldgsts_q = False
|
|
ldgsts_k = False
|
|
ldgsts_v = False
|
|
|
|
# tune kv fragment double buffer
|
|
limit_qk_fragments = False
|
|
limit_v_fragments = False
|
|
if head_size >= 256:
|
|
limit_qk_fragments = True
|
|
limit_v_fragments = True
|
|
elif head_size >= 128 and sm == 70:
|
|
limit_qk_fragments = True
|
|
limit_v_fragments = True
|
|
|
|
# tune kv_loop step
|
|
q_loop_step = 64
|
|
kv_loop_step = 64
|
|
if head_size > 128:
|
|
kv_loop_step = 16
|
|
elif (head_size > 64 and sm == 70):
|
|
kv_loop_step = 16
|
|
elif head_size > 32:
|
|
kv_loop_step = 32
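# Example of the tuning above: head_size=256 on sm80 keeps ldgsts_q=True, turns off
# ldgsts_k/ldgsts_v, limits the QK/V fragments, and uses kv_loop_step=16.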
|
|
|
|
if sm < 80 or head_size > 128:
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=sm_mma,
|
|
dtype=dtype,
|
|
flash_attention=True,
|
|
seq_len=0, # means any sequence here
|
|
kv_loop_step=kv_loop_step,
|
|
limit_qk_fragments=limit_qk_fragments,
|
|
limit_v_fragments=limit_v_fragments,
|
|
head_size=head_size,
|
|
head_size_v=head_size_v,
|
|
warps_m=4,
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=ldgsts_q,
|
|
ldgsts_k=ldgsts_k,
|
|
ldgsts_v=ldgsts_v,
|
|
share_smem_k_v=False,
|
|
loop_step=q_loop_step,
|
|
has_noloop=1,
|
|
noloop_step=q_loop_step,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
ctas_per_head=1,
|
|
input_layout=input_layout,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping)
|
|
)
|
|
elif head_size <= 128:
|
|
# q_step = 64, kv_step = 32
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=sm_mma,
|
|
dtype=dtype,
|
|
flash_attention=True,
|
|
seq_len=0, # means any sequence here
|
|
kv_loop_step=kv_loop_step,
|
|
limit_qk_fragments=limit_qk_fragments,
|
|
limit_v_fragments=limit_v_fragments,
|
|
head_size=head_size,
|
|
head_size_v=head_size_v,
|
|
warps_m=4,
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=ldgsts_q,
|
|
ldgsts_k=ldgsts_k,
|
|
ldgsts_v=ldgsts_v,
|
|
share_smem_k_v=False,
|
|
loop_step=q_loop_step,
|
|
has_noloop=1,
|
|
noloop_step=q_loop_step,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
ctas_per_head=1,
|
|
input_layout=input_layout,
|
|
enable_attn_logit_softcapping=enable_attn_logit_softcapping)
|
|
)
|
|
|
|
|
|
def enumerate_qgmma_kernels(specs, sm=90):
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype='e4m3',
|
|
seq_len=[64, 128, 192, 256, 384],
|
|
head_size=64,
|
|
warps_m=4, #4x1 warpgroups
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
True, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(
|
|
sm=sm,
|
|
sm_mma=90,
|
|
dtype='e4m3',
|
|
seq_len=[512],
|
|
head_size=64,
|
|
warps_m=4, #4x2 warpgroups
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=
|
|
True, # for Hopper kernels, ldgsts = False signals TMA usage.
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=1,
|
|
noloop_step=64,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
|
|
def enumerate_qmma_kernels(specs, sm=89):
|
|
# SM89 (Ada) fp8
|
|
# Head Size 64
|
|
|
|
# Generate the fp16-accumulation variant first.
# NOTE: generate only one accumulation type if these kernels are loaded from cubins,
# or extend TestMetaV2 to carry the accumulation type.
for dtype in ['e4m3_fp16', 'e4m3_fp32']:
|
|
# SEQ 64
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=89,
|
|
dtype=dtype,
|
|
seq_len=64,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=64,
|
|
has_noloop=0,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQ 96
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=89,
|
|
dtype=dtype,
|
|
seq_len=96,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=96,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQ 128
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=89,
|
|
dtype=dtype,
|
|
seq_len=128,
|
|
head_size=64,
|
|
warps_m=2,
|
|
warps_n=2,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=128,
|
|
has_noloop=1,
|
|
noloop_step=16,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQ 192/256/384
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=89,
|
|
dtype=dtype,
|
|
seq_len=[192, 256, 384],
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# SEQ 512
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=89,
|
|
dtype=dtype,
|
|
seq_len=512,
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
|
|
def enumerate_qmma_flash_kernels(specs,
|
|
sm=89,
|
|
dtype='e4m3_fp32',
|
|
head_sizes=None,
|
|
sage_block_sizes=None,
|
|
output_dtype=None):
|
|
# ((head_size, head_size_v), (q_loop_step, kv_loop_step), tiled).
|
|
params_q_kv_step = [
|
|
(32, (128, 128), 0),
|
|
(40, (128, 128), 0),
|
|
(48, (128, 128), 0),
|
|
(64, (128, 128), 0),
|
|
(72, (64, 32), 0),
|
|
(80, (64, 32), 0),
|
|
(96, (64, 32), 0),
|
|
(104, (64, 32), 0),
|
|
(128, (64, 32), 0),
|
|
(160, (64, 32), 0),
|
|
(192, (64, 32), 0),
|
|
(256, (64, 32), 0),
|
|
# MLA kernels.
|
|
((192, 128), (64, 64), 1),
|
|
((576, 512), (64, 64), 1),
|
|
]
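# e.g. the entry ((192, 128), (64, 64), 1) enumerates the MLA kernel with
# head_size=192, head_size_v=128, q_step=kv_step=64, using the tiled code path.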
|
|
input_layouts = [
|
|
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
|
|
InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V
|
|
]
|
|
for (head_size_params, (q_loop_step, kv_loop_step), tiled), input_layout in \
|
|
product(params_q_kv_step, input_layouts):
|
|
# head_size_v = 0 means head_size_v is the same as head_size
|
|
if isinstance(head_size_params, tuple):
|
|
head_size = head_size_params[0]
|
|
head_size_v = head_size_params[1]
|
|
else:
|
|
head_size = head_size_params
|
|
head_size_v = 0
|
|
# skip if head_size is not in head_sizes
|
|
if head_sizes is not None and head_size not in head_sizes:
|
|
continue
|
|
# skip if head_size_v is not 128 for separate-q-k-v
|
|
if input_layout == InputLayout.SEPARATE_Q_K_V and head_size_v != 128:
|
|
continue
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=89,
|
|
dtype=dtype,
|
|
seq_len=0,
|
|
head_size=head_size,
|
|
head_size_v=head_size_v,
|
|
warps_m=4,
|
|
warps_n=1,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=True,
|
|
ldgsts_v=True,
|
|
share_smem_k_v=False,
|
|
loop_step=q_loop_step,
|
|
has_noloop=1,
|
|
noloop_step=q_loop_step,
|
|
kv_loop_step=kv_loop_step,
|
|
tiled=tiled,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
flash_attention=True,
|
|
limit_qk_fragments=False,
|
|
limit_v_fragments=False,
|
|
ctas_per_head=1,
|
|
input_layout=input_layout,
|
|
sage_block_sizes=sage_block_sizes,
|
|
output_dtype=output_dtype,
|
|
is_mtp=(head_size == 576 and head_size_v == 512)))
|
|
|
|
|
|
def enumerate_imma_kernels(specs, sm=80):
|
|
if sm == 90:
|
|
# The following kernels are imma-based kernels tuned for sm90
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='int8',
|
|
seq_len=[64, 128, 256],
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='int8',
|
|
seq_len=[384, 512],
|
|
head_size=64,
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# # SM 80 / 86
|
|
# #- Int8 (same for interleaved)
|
|
|
|
#- S=1024 split over 4 CTAs.
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='int8',
|
|
seq_len=256,
|
|
head_size=[32, 64],
|
|
warps_m=1,
|
|
warps_n=4,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=0,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False,
|
|
ctas_per_head=4))
|
|
|
|
#- S=512: STEP=32, STEP NL=32 FLAGS=0x1
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='int8',
|
|
seq_len=512,
|
|
head_size=[32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=1,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=True,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
specs.append(
|
|
kernel_spec(sm=sm,
|
|
sm_mma=80,
|
|
dtype='int8',
|
|
seq_len=512,
|
|
head_size=[32, 64],
|
|
warps_m=1,
|
|
warps_n=8,
|
|
version=2,
|
|
interleaved=False,
|
|
ldgsts_q=True,
|
|
ldgsts_k=False,
|
|
ldgsts_v=False,
|
|
share_smem_k_v=False,
|
|
loop_step=32,
|
|
has_noloop=1,
|
|
noloop_step=32,
|
|
unroll_threshold=1,
|
|
has_scale_max=False))
|
|
|
|
# D=16: currently needs to run with Turing traits due to K=16 for BMM1.
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=512,
                    head_size=16,
                    warps_m=1,
                    warps_n=8,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True if sm >= 80 else False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    # D=16: currently needs to run with Turing traits due to K=16 for BMM1.
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=512,
                    head_size=16,
                    warps_m=1,
                    warps_n=8,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True if sm >= 80 else False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    #- S=384: STEP=32, STEP NL=32 FLAGS=0x1
    specs.append(
        kernel_spec(
            sm=sm,
            sm_mma=80,
            dtype='int8',
            seq_len=384,
            head_size=64,
            warps_m=1,
            warps_n=8,  # required by pred packing.
            version=1,
            interleaved=False,
            ldgsts_q=True,
            ldgsts_k=False,
            ldgsts_v=False,
            share_smem_k_v=True,
            loop_step=32,
            has_noloop=1,
            noloop_step=32,
            unroll_threshold=1,
            has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=[192, 256],
                    head_size=64,
                    warps_m=1,
                    warps_n=4,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=[192, 256, 384],
                    head_size=64,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=[192, 256, 384],
                    head_size=64,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=True,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    #- S=256: STEP=32, STEP NL=32 FLAGS=0x1
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=256,
                    head_size=32,
                    warps_m=1,
                    warps_n=4,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=256,
                    head_size=32,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=256,
                    head_size=16,
                    warps_m=1,
                    warps_n=4,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True if sm >= 80 else False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=256,
                    head_size=16,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True if sm >= 80 else False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    # S=192: STEP=64, STEP NL=32 FLAGS=0x1
    #- S=128: STEP=NA, STEP NL=16 FLAGS=0x1
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=128,
                    head_size=[32, 64],
                    warps_m=1,
                    warps_n=4,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=128,
                    head_size=16,
                    warps_m=2,
                    warps_n=2,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True if sm >= 80 else False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=128,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=128,
                    head_size=[32, 64],
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=128,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=128,
                    head_size=16,
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True if sm >= 80 else False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=128,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=128,
                    head_size=[32, 64],
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=True,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=128,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    #- S=96
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=96,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=96,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=96,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=96,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=96,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=True,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=96,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    # #- S=64:
    # TODO noloop doesn't work - need to adjust packing into registers for
    # Mma_tile_p::MMAS_N == 1 => Mma_tile_o::MMAS_K == 1 (at least on SM8x)
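    # Hence the S=64 specs below set has_noloop=0: the no-loop variant is not
    # generated until that packing issue is resolved.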
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=64,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=1,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=0,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=64,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=0,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=80,
                    dtype='int8',
                    seq_len=64,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=True,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=0,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    # This config compiles IMMA 1x4 kernels for SM90
    #specs.append(kernel_spec(sm=90,
    #                         sm_mma=80,
    #                         dtype='int8',
    #                         seq_len=[128,192,256, 384],
    #                         head_size=64,
    #                         warps_m=1,
    #                         warps_n=4,
    #                         version=2,
    #                         interleaved=False,
    #                         ldgsts_q=True,
    #                         ldgsts_k=False,
    #                         ldgsts_v=False,
    #                         share_smem_k_v=False,
    #                         loop_step=32,
    #                         has_noloop=0,
    #                         noloop_step=32,
    #                         unroll_threshold=1,
    #                         has_scale_max=False))

    #- Int8 (same for interleaved)
    #- S=512: STEP=32, STEP NL=32 FLAGS=0x1
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=[384, 512],
                    head_size=[32, 64],
                    warps_m=1,
                    warps_n=8,
                    version=1,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=16,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=256,
                    head_size=[32, 64],
                    warps_m=1,
                    warps_n=4,
                    version=1,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=512,
                    head_size=[32, 64],
                    warps_m=1,
                    warps_n=8,
                    version=2,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=[192, 256, 384],
                    head_size=[32, 64],
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=16,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False))

    #- S=384: STEP=32, STEP NL=32 FLAGS=0x0
    #- S=256: STEP=32, STEP NL=32 FLAGS=0x0
    #- S=128: STEP=32, STEP NL=32 FLAGS=0x0
    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=128,
                    head_size=[32, 64],
                    warps_m=2,
                    warps_n=2,
                    version=1,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=128,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    specs.append(
        kernel_spec(sm=sm,
                    sm_mma=75,
                    dtype='int8',
                    seq_len=128,
                    head_size=[32, 64],
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=128,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False))

    #- S=192: STEP=64, STEP NL=64 FLAGS=0x0
    #- S=128: STEP=NA, STEP NL=16 FLAGS=0x8
    #- S=96
    #- S=64


def enumerate_cross_mha_kernels(specs):
    # TODO: combine cross_mha and mha kernel enumeration
    #- S_Q=4096, S_KV=128: STEP=64, STEP NL=64
    # HEAD_SIZE: 64
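    # (In the cross-MHA specs below, seq_len refers to the KV sequence length
    # S_KV; the Q sequence length is not fixed by the spec.)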
    # SM 70
    if 'ENABLE_SM70' in os.environ:
        specs.append(
            kernel_spec(sm=70,
                        dtype='fp16',
                        seq_len=128,
                        head_size=64,
                        warps_m=2,
                        warps_n=2,
                        version=2,
                        interleaved=False,
                        ldgsts_q=False,
                        ldgsts_k=False,
                        ldgsts_v=False,
                        share_smem_k_v=False,
                        loop_step=64,
                        has_noloop=1,
                        noloop_step=32,
                        unroll_threshold=1,
                        has_scale_max=False,
                        cross_mha=1))

    # SM 75
    specs.append(
        kernel_spec(sm=75,
                    dtype='fp16',
                    seq_len=128,
                    head_size=64,
                    warps_m=2,
                    warps_n=2,
                    version=2,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 80
    specs.append(
        kernel_spec(sm=80,
                    dtype='fp16',
                    seq_len=128,
                    head_size=64,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=64,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 86
    specs.append(
        kernel_spec(sm=86,
                    dtype='fp16',
                    seq_len=128,
                    head_size=64,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=64,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 89
    specs.append(
        kernel_spec(sm=89,
                    dtype='fp16',
                    seq_len=128,
                    head_size=64,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=64,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    #- S_Q=1024, S_KV=128: STEP=64, STEP NL=32
    # HEAD_SIZE: 128
    # SM 70
    if 'ENABLE_SM70' in os.environ:
        specs.append(
            kernel_spec(sm=70,
                        dtype='fp16',
                        seq_len=128,
                        head_size=128,
                        warps_m=2,
                        warps_n=2,
                        version=2,
                        interleaved=False,
                        ldgsts_q=False,
                        ldgsts_k=False,
                        ldgsts_v=False,
                        share_smem_k_v=False,
                        loop_step=64,
                        has_noloop=1,
                        noloop_step=32,
                        unroll_threshold=1,
                        has_scale_max=False,
                        cross_mha=1))

    # SM 75
    specs.append(
        kernel_spec(sm=75,
                    dtype='fp16',
                    seq_len=128,
                    head_size=128,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=False,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 80
    specs.append(
        kernel_spec(sm=80,
                    dtype='fp16',
                    seq_len=128,
                    head_size=128,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 86
    specs.append(
        kernel_spec(sm=86,
                    dtype='fp16',
                    seq_len=128,
                    head_size=128,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=64,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 89
    specs.append(
        kernel_spec(sm=89,
                    dtype='fp16',
                    seq_len=128,
                    head_size=128,
                    warps_m=1,
                    warps_n=4,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=64,
                    has_noloop=1,
                    noloop_step=32,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    #- S_KV=128: STEP=32, STEP NL=32
    # HEAD_SIZE: 256
    # SM 70
    # specs.append(kernel_spec(sm=70,
    #                          dtype='fp16',
    #                          seq_len=128,
    #                          head_size=256,
    #                          warps_m=2,
    #                          warps_n=2,
    #                          version=2,
    #                          interleaved=False,
    #                          ldgsts_q=False,
    #                          ldgsts_k=False,
    #                          ldgsts_v=False,
    #                          share_smem_k_v=True,
    #                          loop_step= 32,
    #                          has_noloop=1,
    #                          noloop_step=16,
    #                          unroll_threshold=1,
    #                          has_scale_max=False,
    #                          cross_mha=1))

    # # SM 75
    # specs.append(kernel_spec(sm=75,
    #                          dtype='fp16',
    #                          seq_len=128,
    #                          head_size=256,
    #                          warps_m=1,
    #                          warps_n=8,
    #                          version=2,
    #                          interleaved=False,
    #                          ldgsts_q=False,
    #                          ldgsts_k=False,
    #                          ldgsts_v=False,
    #                          share_smem_k_v=True,
    #                          loop_step= 32,
    #                          has_noloop=1,
    #                          noloop_step=16,
    #                          unroll_threshold=1,
    #                          has_scale_max=False,
    #                          cross_mha=1))

    # SM 80
    specs.append(
        kernel_spec(sm=80,
                    dtype='fp16',
                    seq_len=128,
                    head_size=256,
                    warps_m=1,
                    warps_n=8,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=True,
                    ldgsts_v=True,
                    share_smem_k_v=False,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 86
    specs.append(
        kernel_spec(sm=86,
                    dtype='fp16',
                    seq_len=128,
                    head_size=256,
                    warps_m=1,
                    warps_n=8,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))

    # SM 89
    specs.append(
        kernel_spec(sm=89,
                    dtype='fp16',
                    seq_len=128,
                    head_size=256,
                    warps_m=1,
                    warps_n=8,
                    version=2,
                    interleaved=False,
                    ldgsts_q=True,
                    ldgsts_k=False,
                    ldgsts_v=False,
                    share_smem_k_v=True,
                    loop_step=32,
                    has_noloop=1,
                    noloop_step=16,
                    unroll_threshold=1,
                    has_scale_max=False,
                    cross_mha=1))


def enumerate_kernels():
    if not os.path.exists('./generated'):
        os.mkdir('./generated')

    specs = []

    # TODO we have to select the unroll_threshold over a grid of b and h for each arch

    # Current fp16 384 kernel does 1x8 (smem limit), STEP=48. FP16 does not currently have noloop.

    # SM 90
    enumerate_hgmma_tma_kernels(specs, sm=90)
    enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16')
    enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='bf16')
    if 'ENABLE_HMMA_FP32' in os.environ:
        enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16_fp32')
    enumerate_igmma_kernels(specs, sm=90)
    enumerate_qgmma_kernels(specs, sm=90)
    # Add bf16 kernels here if needed.
    for enable_skip_softmax in [False, True]:
        if enable_skip_softmax and 'DISABLE_SKIP_SOFTMAX' in os.environ:
            continue
        enumerate_hgmma_flash_warpspec_kernels(
            specs, sm=90, dtype='fp16', enable_skip_softmax=enable_skip_softmax)
        enumerate_hgmma_flash_warpspec_kernels(
            specs, sm=90, dtype='bf16', enable_skip_softmax=enable_skip_softmax)
        enumerate_qgmma_flash_warpspec_kernels(
            specs, sm=90, dtype='e4m3', enable_skip_softmax=enable_skip_softmax)
        enumerate_qgmma_flash_warpspec_kernels(
            specs,
            sm=90,
            dtype='e4m3',
            output_dtype="bf16",
            enable_skip_softmax=enable_skip_softmax)

    # For now, SageAttention only needs BF16 output.
    # block_size_q should be divisible by 64
    # block_size_k should be divisible by 8
    # block_size_v should be divisible by 32
    for sage_block_sizes in [(64, 64, 64), (64, 64, 128), (64, 64, 256),
                             (64, 128, 64), (64, 128, 128), (64, 128, 256)]:
        enumerate_qgmma_flash_warpspec_kernels(
            specs,
            sm=90,
            dtype='e4m3',
            sage_block_sizes=sage_block_sizes,
            output_dtype="bf16")

    if 'ENABLE_HMMA_FP32' in os.environ:
        enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16_fp32')
    # Optionally generate HMMA kernels on SM90 for comparison.
    if 'SM90_USE_HMMA' in os.environ:
        print("WARNING: GENERATING HMMA KERNELS INSTEAD OF HGMMA FOR SM90")
        enumerate_hmma_kernels(specs, sm=90, dtype='fp16')
        enumerate_hmma_kernels(specs, sm=90, dtype='bf16')

    # SM90 IGMMA
    if 'SM90_USE_IMMA' in os.environ:
        print("WARNING: GENERATING IMMA KERNELS INSTEAD OF IGMMA FOR SM90")
        enumerate_imma_kernels(specs, sm=90)

    # SM 89
    if 'ENABLE_SM89_QMMA' in os.environ:
        enumerate_qmma_kernels(specs, sm=89)
        enumerate_qmma_flash_kernels(specs, sm=89, dtype='e4m3_fp32')
        # Add bf16 output MLA kernels.
        enumerate_qmma_flash_kernels(specs,
                                     sm=89,
                                     dtype='e4m3_fp32',
                                     head_sizes=[192, 576],
                                     output_dtype="bf16")
        # Sage Attention on Ada only supports block_size = (64, 32, 32)
        enumerate_qmma_flash_kernels(specs,
                                     sm=89,
                                     dtype='e4m3_fp32',
                                     sage_block_sizes=(64, 32, 32),
                                     output_dtype="bf16")
        enumerate_qmma_flash_kernels(specs,
                                     sm=89,
                                     dtype='e4m3_fp32',
                                     sage_block_sizes=(64, 32, 32),
                                     output_dtype="fp16")

    enumerate_imma_kernels(specs, sm=89)
    enumerate_hmma_kernels(specs, sm=89, dtype='fp16')
    enumerate_hmma_kernels(specs, sm=89, dtype='bf16')
    enumerate_hmma_flash_kernels(specs, sm=89, dtype='fp16')
    enumerate_hmma_flash_kernels(specs, sm=89, dtype='bf16')

    # SM 80 / 86
    enumerate_imma_kernels(specs, sm=80)
    enumerate_hmma_kernels(specs, sm=80, dtype='fp16')
    enumerate_hmma_kernels(specs, sm=80, dtype='bf16')
    enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16')
    enumerate_hmma_flash_kernels(specs, sm=80, dtype='bf16')

    enumerate_imma_kernels(specs, sm=86)
    enumerate_hmma_kernels(specs, sm=86, dtype='fp16')
    enumerate_hmma_kernels(specs, sm=86, dtype='bf16')
    enumerate_hmma_flash_kernels(specs, sm=86, dtype='fp16')
    enumerate_hmma_flash_kernels(specs, sm=86, dtype='bf16')

    # SM 90 (only generate paged_kv_fmha hmma kernels)
    enumerate_hmma_paged_kv_flash_kernels(specs, sm=90, dtype='fp16')
    enumerate_hmma_paged_kv_flash_kernels(specs, sm=90, dtype='bf16')

    if 'ENABLE_SM100' in os.environ:
        # SM 100
        enumerate_hmma_flash_kernels(specs, sm=100, dtype='fp16')
        enumerate_hmma_flash_kernels(specs, sm=100, dtype='bf16')
        enumerate_hmma_flash_kernels(specs,
                                     sm=100,
                                     dtype='bf16',
                                     head_size_v=128)
        enumerate_hmma_flash_kernels(specs,
                                     sm=100,
                                     dtype='bf16',
                                     head_size_v=512)

    if 'ENABLE_SM120' in os.environ:
        # SM 120
        enumerate_hmma_flash_kernels(specs, sm=120, dtype='fp16')
        enumerate_hmma_flash_kernels(specs, sm=120, dtype='bf16')
        enumerate_hmma_flash_kernels(specs,
                                     sm=120,
                                     dtype='bf16',
                                     head_size_v=128)
        enumerate_hmma_flash_kernels(specs,
                                     sm=120,
                                     dtype='bf16',
                                     head_size_v=512)
        enumerate_qmma_kernels(specs, sm=120)
        enumerate_qmma_flash_kernels(specs, sm=120, dtype='e4m3_fp32')
        # Add bf16 output MLA kernels.
        enumerate_qmma_flash_kernels(specs,
                                     sm=120,
                                     dtype='e4m3_fp32',
                                     head_sizes=[192, 576],
                                     output_dtype="bf16")

    if 'ENABLE_HMMA_FP32' in os.environ:
        enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16_fp32')
        enumerate_hmma_flash_kernels(specs, sm=86, dtype='fp16_fp32')
        enumerate_hmma_flash_kernels(specs, sm=89, dtype='fp16_fp32')
        # SM 90 (only generate paged_kv_fmha hmma kernels)
        enumerate_hmma_paged_kv_flash_kernels(specs, sm=90, dtype='fp16_fp32')
        if 'ENABLE_SM100' in os.environ:
            # SM 100
            enumerate_hmma_flash_kernels(specs, sm=100, dtype='fp16_fp32')
        if 'ENABLE_SM120' in os.environ:
            # SM 120
            enumerate_hmma_flash_kernels(specs, sm=120, dtype='fp16_fp32')

    for sm in [80, 86, 89, 90]:
        if not (sm == 90 and "GENERATE_CUBIN" in os.environ):
            # Hopper uses warp-specialized kernels instead (hasn't been merged yet).
            enumerate_hmma_flash_kernels(specs,
                                         sm=sm,
                                         dtype='bf16',
                                         head_size_v=128)
            enumerate_hmma_flash_kernels(specs,
                                         sm=sm,
                                         dtype='bf16',
                                         head_size_v=512)

    # SM 75
    enumerate_imma_kernels(specs, sm=75)
    enumerate_hmma_kernels(specs, sm=75)
    enumerate_hmma_flash_kernels(specs, sm=75)

    # SM 70
    if 'ENABLE_SM70' in os.environ:
        enumerate_hmma884_kernels(specs, sm=70)
        enumerate_hmma_flash_kernels(specs, sm=70)

    # TODO: refactor this; maybe add an option to enumerate_*mma_kernels()
    enumerate_cross_mha_kernels(specs)

    # Expand the cartesian product of the list fields "seq_len", "head_size"
    # and "dtype".
    specs_expanded = []
    list_like = lambda x: isinstance(x, list) or isinstance(x, tuple)
    for kspec in specs:
        tmp_s = kspec.seq_len
        tmp_d = kspec.head_size
        tmp_dtype = kspec.dtype
        tmp_exp = [kspec._replace(seq_len=s)
                   for s in tmp_s] if list_like(tmp_s) else [kspec]
        tmp_exp = [
            tmp_ks._replace(head_size=d) for d in tmp_d for tmp_ks in tmp_exp
        ] if list_like(tmp_d) else tmp_exp
        tmp_exp = [
            tmp_ks._replace(dtype=dt) for dt in tmp_dtype for tmp_ks in tmp_exp
        ] if list_like(tmp_dtype) else tmp_exp
        specs_expanded.extend(tmp_exp)
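    # For example, a single spec with seq_len=[384, 512] and head_size=[32, 64]
    # expands into the four concrete (seq_len, head_size) combinations.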

    # Sanitize kernel specs
    specs_expanded = [
        kspec for kspec in specs_expanded if kspec.sm >= kspec.sm_mma
    ]

    # Expand the list for the cross-MHA kernels.
    # TRT-LLM uses the head_interleaved=False mode.
    if 'GENERATE_CUBIN' in os.environ:
        specs_expanded = [
            kspec._replace(head_interleaved=False) for kspec in specs_expanded
        ]
    # yapf: disable
    specs_names = [(kspec, *encode_name(kspec)) for kspec in specs_expanded
                   # Volta is deprecated in TRT-LLM.
                   if (kspec.sm in [80, 86, 89, 90, 120]
                       and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
                       and kspec.head_size <= 256
                       and kspec.head_size_v == 0
                       and kspec.sage_block_sizes is None
                       and kspec.version == 2
                       and kspec.cross_mha == False
                       and kspec.flash_attention == True
                       and kspec.input_layout != InputLayout.SEPARATE_Q_K_V
                       or (kspec.sm == 90
                           and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32']
                           and kspec.head_size <= 256
                           and kspec.ldgsts_q == True
                           and kspec.version == 2
                           and kspec.cross_mha == False
                           and kspec.flash_attention == False)
                       # Clip/SigLip support.
                       or (kspec.sm == 100
                           and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
                           and kspec.head_size == 80
                           and kspec.head_size_v == 0
                           and kspec.sage_block_sizes is None
                           and kspec.version == 2
                           and kspec.cross_mha == False
                           and kspec.flash_attention == True
                           and kspec.input_layout != InputLayout.SEPARATE_Q_K_V)
                       # Gemma3 VL support.
                       or (kspec.sm == 100
                           and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
                           and kspec.head_size == 72
                           and kspec.head_size_v == 0
                           and kspec.sage_block_sizes is None
                           and kspec.version == 2
                           and kspec.cross_mha == False
                           and kspec.flash_attention == True
                           and kspec.input_layout != InputLayout.SEPARATE_Q_K_V)
                       # Deepseek MLA (generation 576/512 paged)
                       or (kspec.sm in [90, 100, 120]
                           and kspec.dtype in ['bf16', 'e4m3_fp32']
                           and kspec.head_size == 576
                           and kspec.head_size_v == 512
                           and kspec.input_layout == InputLayout.Q_PAGED_KV
                           and kspec.sage_block_sizes is None
                           and kspec.version == 2
                           and kspec.cross_mha == False
                           and kspec.flash_attention == True
                           and kspec.warp_specialization == False
                           and kspec.tiled == True)
                       # Deepseek MLA (context 192/128 separate-q-k-v)
                       or (kspec.sm in [90, 100, 120]
                           and kspec.dtype in ['bf16', 'e4m3', 'e4m3_fp32']
                           and kspec.head_size == 192
                           and kspec.head_size_v == 128
                           and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
                           and kspec.sage_block_sizes is None
                           and kspec.version == 2
                           and kspec.cross_mha == False
                           and kspec.flash_attention == True
                           and ((kspec.warp_specialization == True and kspec.alibi == False)  # sm90
                                or (kspec.warp_specialization == False and kspec.tiled == True))  # non-sm90
                           and kspec.enable_attn_logit_softcapping == False)
                       # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
                       or (kspec.sm == 90
                           and kspec.head_size in [80, 128]
                           and kspec.version == 2
                           and kspec.sage_block_sizes in [(64, 64, 256)]
                           and kspec.cross_mha == False
                           and kspec.flash_attention == True
                           and kspec.warp_specialization == True
                           and kspec.input_layout == InputLayout.PACKED_QKV
                           and kspec.alibi == False
                           and kspec.enable_attn_logit_softcapping == False)
                       # SageAttention on Ada (head_size in (80, 128), packed QKV, padding mask)
                       or (kspec.sm == 89
                           and kspec.head_size in [80, 128]
                           and kspec.sage_block_sizes in [(64, 32, 32)]
                           and kspec.output_dtype in ['fp16', 'bf16']
                           and kspec.version == 2
                           and kspec.cross_mha == False
                           and kspec.flash_attention == True
                           and kspec.warp_specialization == False
                           and kspec.input_layout == InputLayout.PACKED_QKV))
                   # only generate head_size = 128/256 for attn_logit_softcapping operation.
                   and (kspec.head_size == 128 or kspec.head_size == 256 or not kspec.enable_attn_logit_softcapping)]
    # yapf: enable

    generate_files(specs_names)


if __name__ == '__main__':
    enumerate_kernels()

# General restrictions
# FP16: no s=192
# FP16: no Volta
# Interleaved only for Int8

# v1:
# 384 should have 1x8 kernels not to exceed xmmas_n = 4
# No support for interleaved

# v2:
#

# TODO record all step and smem configs