TensorRT-LLMs/cpp/kernels/fmha_v2/src/fused_multihead_attention.h
Zhou Yuxin fca13b8c95
hopper-style context MLA (#5713)
Signed-off-by: Yuxin <yuxinz@nvidia.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: Yiqing Yan <yiqingy@nvidia.com>
Signed-off-by: qqiao <qqiao@nvidia.com>
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: Rashid K <rkaleem@nvidia.com>
Signed-off-by: Zhenhuan Chen <chenzhh3671@gmail.com>
Signed-off-by: Po-Wei Wang (Vincent) <poweiw@nvidia.com>
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Clay <ccs96307@gmail.com>
Signed-off-by: Venky <23023424+venkywonka@users.noreply.github.com>
Signed-off-by: Xin He (SW-GPU) <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Signed-off-by: Frank Di Natale <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Signed-off-by: Hui Gao <huig@nvidia.com>
Signed-off-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Signed-off-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
Signed-off-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Signed-off-by: Julien Debache <julien.debache@hotmail.com>
Signed-off-by: Yanchao Lu <yanchaol@nvidia.com>
Signed-off-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com>
Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com>
Signed-off-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Signed-off-by: David Clark <215764518+davidclark-nv@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Signed-off-by: Yegor <75512761+Wokzy@users.noreply.github.com>
Signed-off-by: Yegor Yershov <yegor6741@gmail.com>
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Signed-off-by: raayandhar <rdhar@nvidia.com>
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Signed-off-by: xsimmons <xsimmons@nvidia.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Ubuntu <ubuntu@ip-10-0-20-146.us-west-2.compute.internal>
Signed-off-by: Hanjun Cho <46752251+gkswns0531@users.noreply.github.com>
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Signed-off-by: narutolhy <582909902@qq.com>
Signed-off-by: ZhanruiSunCh <184402041+ZhanruiSunCh@users.noreply.github.com>
Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Signed-off-by: Frank <3429989+FrankD412@users.noreply.github.com>
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
Signed-off-by: William Tambellini <wtambellini@sdl.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: Yiqing Yan <yiqingy@nvidia.com>
Co-authored-by: Emma Qiao <qqiao@nvidia.com>
Co-authored-by: WeiHaocheng <20514172+WeiHaocheng@users.noreply.github.com>
Co-authored-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Co-authored-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Co-authored-by: Rashid Kaleem <4079439+arekay@users.noreply.github.com>
Co-authored-by: Zhihan Jiang <68881590+nvzhihanj@users.noreply.github.com>
Co-authored-by: Zhenhuan Chen <chenzhh3671@gmail.com>
Co-authored-by: Po-Wei (Vincent) <poweiw@nvidia.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: Neta Zmora <nzmora@nvidia.com>
Co-authored-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Co-authored-by: Clay <ccs96307@gmail.com>
Co-authored-by: Venky <23023424+venkywonka@users.noreply.github.com>
Co-authored-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>
Co-authored-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
Co-authored-by: Zheng Duan <200704041+zhengd-nv@users.noreply.github.com>
Co-authored-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Co-authored-by: Frank <3429989+FrankD412@users.noreply.github.com>
Co-authored-by: brb-nv <169953907+brb-nv@users.noreply.github.com>
Co-authored-by: Linda <57756729+Linda-Stadter@users.noreply.github.com>
Co-authored-by: Shunkangz <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Tailing Yuan <yuantailing@gmail.com>
Co-authored-by: Faraz <58580514+farazkh80@users.noreply.github.com>
Co-authored-by: peaceh-nv <103117813+peaceh-nv@users.noreply.github.com>
Co-authored-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
Co-authored-by: HuiGao-NV <huig@nvidia.com>
Co-authored-by: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Co-authored-by: ShiXiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Co-authored-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Xianjie Qiao <5410381+qiaoxj07@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Julien Debache <jdebache@nvidia.com>
Co-authored-by: Yanchao Lu <yanchaol@nvidia.com>
Co-authored-by: Yiteng Niu <6831097+niukuo@users.noreply.github.com>
Co-authored-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Co-authored-by: bhsueh_NV <11360707+byshiue@users.noreply.github.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: ChristinaZ <83400082+ChristinaZ@users.noreply.github.com>
Co-authored-by: Larry <197874197+LarryXFly@users.noreply.github.com>
Co-authored-by: DylanChen-NV <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Daniel Cámpora <961215+dcampora@users.noreply.github.com>
Co-authored-by: davidclark-nv <215764518+davidclark-nv@users.noreply.github.com>
Co-authored-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
Co-authored-by: Yechan Kim <161688079+yechank-nvidia@users.noreply.github.com>
Co-authored-by: liji-nv <59594262+liji-nv@users.noreply.github.com>
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: xiweny <13230610+VALLIS-NERIA@users.noreply.github.com>
Co-authored-by: Yegor <75512761+Wokzy@users.noreply.github.com>
Co-authored-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Co-authored-by: Raayan Dhar <58057652+raayandhar@users.noreply.github.com>
Co-authored-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Co-authored-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Co-authored-by: Pamela Peng <179191831+pamelap-nvidia@users.noreply.github.com>
Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Co-authored-by: xavier-nvidia <xsimmons@nvidia.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Co-authored-by: Erin <14718778+hchings@users.noreply.github.com>
Co-authored-by: chenfeiz0326 <chenfeiz@nvidia.com>
Co-authored-by: dongxuy04 <78518666+dongxuy04@users.noreply.github.com>
Co-authored-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com>
Co-authored-by: Hanjun Cho <46752251+gkswns0531@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-10-0-20-146.us-west-2.compute.internal>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Co-authored-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Co-authored-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
Co-authored-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Co-authored-by: narutolhy <582909902@qq.com>
Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
Co-authored-by: wili <98001977+wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: Void <18275976+yilin-void@users.noreply.github.com>
Co-authored-by: William Tambellini <wtambellini@sdl.com>
2025-07-23 14:37:20 +08:00

322 lines
11 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include <cuda.h>
#include <fmha/alibi_params.h>
#include <fmha/hopper/tma_types.h>
#include <fmha/paged_kv_cache.h>
#include <fused_multihead_attention_utils.h>
#include <vector>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Padding granularity of the mask input: it has to be padded to a 128 x 256 tile
// so that it matches all Ampere/Hopper kernels.
static constexpr int FLASH_ATTEN_MASK_M_ALIGNMENT{128};
static constexpr int FLASH_ATTEN_MASK_N_ALIGNMENT{256};
// MMA tile size used by the packed mask (64 x 64).
static constexpr int FLASH_ATTEN_MASK_MMA_M{64};
static constexpr int FLASH_ATTEN_MASK_MMA_N{64};
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class Attention_mask_type
{
    // Only the padded tokens are masked out.
    PADDING = 0,
    // The padded tokens are masked out, together with every token that comes later in the sequence.
    CAUSAL = 1,
    // Causal masking, additionally restricted to a specific sliding window or chunk.
    SLIDING_OR_CHUNKED_CAUSAL = 2,
    // The mask is provided as a custom mask input.
    CUSTOM_MASK = 3,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Returns the lowercase name of the given attention mask type.
// A value that matches no enumerator trips the assert (when asserts are enabled)
// and otherwise yields an empty string.
static inline std::string mask_type_to_string(Attention_mask_type mask_type)
{
    if (mask_type == Attention_mask_type::PADDING)
    {
        return "padding";
    }
    if (mask_type == Attention_mask_type::CAUSAL)
    {
        return "causal";
    }
    if (mask_type == Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL)
    {
        return "sliding_or_chunked_causal";
    }
    if (mask_type == Attention_mask_type::CUSTOM_MASK)
    {
        return "custom_mask";
    }
    // Unknown mask type.
    assert(false);
    return "";
}
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class Attention_input_layout
{
    // Q, K and V are packed together with layout [B, S, 3, H, D].
    PACKED_QKV = 0,
    // Q is contiguous with layout [B, S, H, D]; KV is contiguous with layout [B, 2, H, S, D].
    CONTIGUOUS_Q_KV = 1,
    // Q is contiguous with layout [B, S, H, D]; the paged KV layout consists of blocks of indices
    // with shape [B, 2, Blocks_per_Seq], where each index gives the block distance to the pool
    // pointer in global memory.
    Q_PAGED_KV = 2,
    // Q, K and V are stored separately:
    //   Q has [B, S, H, D] layout,
    //   K has [B, S, H_kv, D] layout,
    //   V has [B, S, H_kv, Dv] layout.
    SEPARATE_Q_K_V = 3,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Returns the lowercase name of the given attention input layout.
// A value that matches no enumerator trips the assert (when asserts are enabled)
// and otherwise yields an empty string.
static inline std::string attention_input_layout_to_string(Attention_input_layout layout)
{
    if (layout == Attention_input_layout::PACKED_QKV)
    {
        return "packed_qkv";
    }
    if (layout == Attention_input_layout::CONTIGUOUS_Q_KV)
    {
        return "contiguous_q_kv";
    }
    if (layout == Attention_input_layout::Q_PAGED_KV)
    {
        return "contiguous_q_paged_kv";
    }
    if (layout == Attention_input_layout::SEPARATE_Q_K_V)
    {
        return "separate_q_k_v";
    }
    // Unknown input layout.
    assert(false);
    return "";
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace bert
{
////////////////////////////////////////////////////////////////////////////////////////////////////
#if USE_DEMO_BERT_PARAMS
// TODO TRT plugins use a different parameter struct taken from the old XMMA fork.
// Until all cubins in the plugin are replaced with new kernels, we need to conform to that.
#include <fused_multihead_attention_demo_bert_params.h>
#else
// Parameter block that the v1 and v2 parameter structs in this file derive from.
//
// Every scalar and pointer member has a default member initializer, so a
// default-constructed object never holds indeterminate values. Member order and
// types are unchanged. alibi_params keeps whatever default its own type provides.
struct Fused_multihead_attention_params_base
{
    // The QKV matrices.
    void* qkv_ptr = nullptr;
    // The O matrix (output).
    void* o_ptr = nullptr;
    // The stride between rows of O.
    int64_t o_stride_in_bytes = 0;
#if defined(STORE_P)
    // The pointer to the P matrix (for debugging).
    void* p_ptr = nullptr;
    // The stride between rows of the P matrix (for debugging).
    int64_t p_stride_in_bytes = 0;
#endif // defined(STORE_P)
#if defined(STORE_S)
    // The pointer to the S matrix (for debugging).
    void* s_ptr = nullptr;
    // The stride between rows of the S matrix (for debugging).
    int64_t s_stride_in_bytes = 0;
#endif // defined(STORE_S)
#if defined(DEBUG_HAS_PRINT_BUFFER)
    // Debug print buffer (only present when DEBUG_HAS_PRINT_BUFFER is defined).
    void* print_ptr = nullptr;
#endif
    // The dimensions.
    // NOTE(review): presumably batch size, number of heads, sequence length and head size,
    // following the [B, S, H, D] layouts used elsewhere in this file - confirm.
    int b = 0;
    int h = 0;
    int s = 0;
    int d = 0;
    // The scaling factors for the kernel.
    uint32_t scale_bmm1 = 0;
    uint32_t scale_softmax = 0;
    uint32_t scale_bmm2 = 0;
    // The bmm1 and bmm2 scaling factors on the device.
    uint32_t* scale_bmm1_d = nullptr;
    uint32_t* scale_bmm2_d = nullptr;
    // Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
    bool enable_i2f_trick = false;
    // true: for int8, instead of doing max reduce, use max value encoded in scale factor
    bool use_int8_scale_max = false;
    // If the kernel is using alibi or not
    bool has_alibi = false;
    fmha::AlibiParams alibi_params;
    // The number of heads computed by one iteration of the wave.
    int heads_per_wave = 0;
    // Buffers to perform a global sync and a critical section.
    int* counters = nullptr;
    int* max_barriers = nullptr;
    int* sum_barriers = nullptr;
    int* locks = nullptr;
    // Scratch buffers to finalize softmax.
    float* max_scratch_ptr = nullptr;
    float* sum_scratch_ptr = nullptr;
    // Scratch buffer to finalize the output (not needed for FP16).
    int* o_scratch_ptr = nullptr;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Version-1 kernel parameters: the base parameter block plus the packed QKV row
// stride and the packed mask.
//
// All members have default member initializers so that a default-constructed
// object never holds indeterminate values; member order and types are unchanged.
struct Fused_multihead_attention_params_v1 : Fused_multihead_attention_params_base
{
    // The stride between rows of the Q, K and V matrices.
    int64_t qkv_stride_in_bytes = 0;
    // The mask to implement drop-out.
    void* packed_mask_ptr = nullptr;
    // The stride between matrices of packed mask.
    int64_t packed_mask_stride_in_bytes = 0;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Version-2 kernel parameters: the base parameter block plus mask, softmax statistics,
// sequence-length prefix sums, TMA descriptors, separate/contiguous/paged Q-K-V buffers,
// scheduling counters and the attention variants (MQA/GQA, sliding window, chunked
// attention, softcapping, SageAttention scales).
//
// Every scalar and pointer member has a default member initializer, so a
// default-constructed object never holds indeterminate values. Member order and types
// are unchanged. The TMA descriptors and the paged KV cache keep whatever default their
// own types provide, since those types are declared in other headers.
struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_base
{
    // The dimension of V. If unset, dv = d.
    int dv = 0;
    // The input to support any mask patterns.
    void* packed_mask_ptr = nullptr;
    // The mask input's stride in the N (K-seq) dimension.
    int64_t packed_mask_stride_in_bytes = 0;
    // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max
    void* softmax_stats_ptr = nullptr;
    // The stride between rows of softmax_stats_ptr
    int64_t softmax_stats_stride_in_bytes = 0;
    // array of length b+1 holding prefix sum of actual q sequence lengths.
    int* cu_q_seqlens = nullptr;
    // array of length b+1 holding prefix sum of actual kv sequence lengths.
    int* cu_kv_seqlens = nullptr;
    // array of length b+1 holding prefix sum of actual mask sequence lengths.
    // it might not be the same as cu_q_seqlens as the mask seqlens will be padded.
    int* cu_mask_rows = nullptr;
    // tma descriptors on device.
    // Either q in packed qkv [B, S, 3, H, D] or separate q layout [B, S, H, D].
    fmha::cudaTmaDesc tma_desc_q;
    // Tma descriptors for packed/contiguous/paged kv cache.
    // Kv in packed qkv layout: [B, S, 3, H, D]
    // Contiguous kv layout: [B, 2, H, S, D].
    // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
    fmha::cudaTmaDesc tma_desc_k;
    fmha::cudaTmaDesc tma_desc_v;
    // Tma descriptor for o
    fmha::cudaTmaDesc tma_desc_o;
    // Contiguous Q buffer pointer [B, S, H, D].
    void* q_ptr = nullptr;
    // The separate K matrix.
    void* k_ptr = nullptr;
    // The separate V matrix.
    void* v_ptr = nullptr;
    // Contiguous KV buffer pointer [B, 2, H, S, D].
    void* kv_ptr = nullptr;
    // Paged KV Cache buffer.
    fmha::Kv_block_array paged_kv_cache;
    // Q and KV stride (used by LDGSTS).
    int64_t q_stride_in_bytes = 0;
    int64_t k_stride_in_bytes = 0;
    int64_t v_stride_in_bytes = 0;
    // Paged KV load. Both values default to zero until the caller configures them.
    int blocks_per_tma_load = 0;
    int blocks_per_tma_load_log2 = 0;
    // M tile id counter for dynamic scheduling
    uint32_t* tile_id_counter_ptr = nullptr;
    uint32_t num_tiles = 0;
    uint32_t num_tiles_per_head = 0;
    bool use_balanced_scheduling = false;
    // In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head
    int h_kv = 0;
    // h_q_per_kv is sometimes rematerialized in the kernel by formula h / h_kv to reclaim one register
    int h_q_per_kv = 1;
    // The number of grouped heads in the seqlen dimension.
    int num_grouped_heads = 1;
    // Sliding Window Attention
    // Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
    int sliding_window_size = INT_MAX;
    // The chunked attention size, stored as its log2 per the field name
    // (<= 0 means no chunked attention).
    int log2_chunked_attention_size = 0;
    // The softcapping scale (scale * tanh (x / scale)) applied to bmm1 output.
    float softcapping_scale_bmm1 = 0.0f;
    // is input/output padded
    bool is_s_padded = false;

    struct SageAttention
    {
        struct Scales
        {
            // this field is only used in bin/fmha.exe, will be omitted in exported cubin
            int block_size = 0;
            // ceil(max_seqlen / block_size)
            int max_nblock = 0;
            // The scale of each block, layout: (B, H, max_nblock)
            float* scales = nullptr;
        } q, k, v;
    } sage;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Launch-time options that steer which kernel gets selected.
struct Fused_multihead_attention_launch_params
{
    // Small-batch kernel selection flags.
    // true: never unroll
    bool ignore_b1opt{false};
    // true: always unroll
    bool force_unroll{false};
    // accumulate in fp32
    bool force_fp32_acc{false};
    // the C/32 format
    bool interleaved{false};
    // TMA is not used by default.
    bool use_tma{false};
    // total number of q tokens, used to set the tma descriptors
    int total_q_seqlen{0};
    // total number of kv tokens, used to set the tma descriptors
    int total_kv_seqlen{0};
    // whether flash attention is used (only FP16)
    bool flash_attention{false};
    // whether warp_specialized kernels are used (only SM90 HGMMA + TMA)
    bool warp_specialization{false};
    // granular tiling flash attention kernels
    bool use_granular_tiling{false};
    // causal masking, sliding_or_chunked_causal masking or dense (padding) mask.
    fmha::Attention_mask_type attention_mask_type{fmha::Attention_mask_type::PADDING};
    // the attention input layout.
    fmha::Attention_input_layout attention_input_layout{fmha::Attention_input_layout::PACKED_QKV};
    // enable_attn_logit_softcapping (choose kernels with softcapping_scale_bmm1).
    bool enable_attn_logit_softcapping{false};
    // hardware properties used to determine how to launch blocks
    int multi_processor_count{0};
    int device_l2_cache_size{0};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace bert