/*
 * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include <cuda.h>
#include <fmha/alibi_params.h>
#include <fmha/hopper/tma_types.h>
#include <fmha/paged_kv_cache.h>
#include <fused_multihead_attention_utils.h>

#include <cassert>
#include <climits>
#include <cstdint>
#include <string>
#include <vector>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace fmha
{

////////////////////////////////////////////////////////////////////////////////////////////////////

// Make sure the mask input is padded to 128 x 256 tile size in order to
// match all Ampere/Hopper kernels.
static constexpr int FLASH_ATTEN_MASK_M_ALIGNMENT = 128;
static constexpr int FLASH_ATTEN_MASK_N_ALIGNMENT = 256;
// The packed mask's MMA tile size is 64 x 64.
static constexpr int FLASH_ATTEN_MASK_MMA_M = 64;
static constexpr int FLASH_ATTEN_MASK_MMA_N = 64;
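
// Illustrative sketch (not part of the original header): the alignment constants above imply that a
// caller building the packed mask rounds the Q/KV sequence lengths up to multiples of 128 and 256
// before allocating it. The helper name below is hypothetical and only shows that arithmetic.
//
//   constexpr int round_up(int x, int alignment)
//   {
//       return (x + alignment - 1) / alignment * alignment;
//   }
//   // padded rows = round_up(s_q,  FLASH_ATTEN_MASK_M_ALIGNMENT);  // multiple of 128
//   // padded cols = round_up(s_kv, FLASH_ATTEN_MASK_N_ALIGNMENT);  // multiple of 256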

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class Attention_mask_type
{
    // Mask the padded tokens.
    PADDING = 0,
    // Mask the padded tokens and all the tokens that come after in a sequence.
    CAUSAL,
    // Causal mask + attend to the specific sliding window or chunk.
    SLIDING_OR_CHUNKED_CAUSAL,
    // The custom mask input.
    CUSTOM_MASK,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline std::string mask_type_to_string(Attention_mask_type mask_type)
{
    switch (mask_type)
    {
    case Attention_mask_type::PADDING: return "padding";
    case Attention_mask_type::CAUSAL: return "causal";
    case Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL: return "sliding_or_chunked_causal";
    case Attention_mask_type::CUSTOM_MASK: return "custom_mask";
    default: assert(false); return "";
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class Attention_input_layout
{
    // QKV are packed into [B, S, 3, H, D] layout.
    PACKED_QKV = 0,
    // Q has contiguous [B, S, H, D] layout, while KV has contiguous [B, 2, H, S, D] layout.
    CONTIGUOUS_Q_KV,
    // Q has contiguous [B, S, H, D] layout, while the paged KV cache is given as blocks of indices
    // with shape [B, 2, Blocks_per_Seq], where each index gives the block's offset from the pool
    // pointer in global memory.
    Q_PAGED_KV,
    // Q has [B, S, H, D] layout,
    // K has [B, S, H_kv, D] layout,
    // V has [B, S, H_kv, Dv] layout.
    SEPARATE_Q_K_V,
};
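
// Illustrative sketch (assumption, not from the original header): with the Q_PAGED_KV layout each
// page index is an offset in blocks from the pool base pointer, so a K block would be located
// roughly as follows; `pool_ptr`, `block_idx` and `bytes_per_block` are hypothetical names.
//
//   char* k_block_ptr = static_cast<char*>(pool_ptr) + int64_t(block_idx) * bytes_per_block;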

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline std::string attention_input_layout_to_string(Attention_input_layout layout)
{
    switch (layout)
    {
    case Attention_input_layout::PACKED_QKV: return "packed_qkv";
    case Attention_input_layout::CONTIGUOUS_Q_KV: return "contiguous_q_kv";
    case Attention_input_layout::Q_PAGED_KV: return "contiguous_q_paged_kv";
    case Attention_input_layout::SEPARATE_Q_K_V: return "separate_q_k_v";
    default: assert(false); return "";
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace bert
{

////////////////////////////////////////////////////////////////////////////////////////////////////

#if USE_DEMO_BERT_PARAMS

// TODO TRT plugins use a different parameter struct taken from the old XMMA fork.
// Until all cubins in the plugin are replaced with new kernels, we need to conform to that.
#include <fused_multihead_attention_demo_bert_params.h>

#else
struct Fused_multihead_attention_params_base
{
    // The QKV matrices.
    void* qkv_ptr;
    // The O matrix (output).
    void* o_ptr;

    // The stride between rows of O.
    int64_t o_stride_in_bytes;

#if defined(STORE_P)
    // The pointer to the P matrix (for debugging).
    void* p_ptr;
    // The stride between rows of the P matrix (for debugging).
    int64_t p_stride_in_bytes;
#endif // defined(STORE_P)

#if defined(STORE_S)
    // The pointer to the S matrix (for debugging).
    void* s_ptr;
    // The stride between rows of the S matrix (for debugging).
    int64_t s_stride_in_bytes;
#endif // defined(STORE_S)

#if defined(DEBUG_HAS_PRINT_BUFFER)
    void* print_ptr;
#endif

    // The dimensions.
    int b, h, s, d;
    // The scaling factors for the kernel.
    uint32_t scale_bmm1, scale_softmax, scale_bmm2;
    // The bmm1 and bmm2 scaling factors in device memory.
    uint32_t* scale_bmm1_d;
    uint32_t* scale_bmm2_d;

    // Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
    bool enable_i2f_trick;

    // true: for int8, instead of doing max reduce, use the max value encoded in the scale factor.
    bool use_int8_scale_max = false;

    // Whether the kernel uses ALiBi.
    bool has_alibi = false;
    fmha::AlibiParams alibi_params;

    // The number of heads computed by one iteration of the wave.
    int heads_per_wave;
    // Buffers to perform a global sync and a critical section.
    int *counters, *max_barriers, *sum_barriers, *locks;
    // Scratch buffers to finalize softmax.
    float *max_scratch_ptr, *sum_scratch_ptr;
    // Scratch buffer to finalize the output (not needed for FP16).
    int* o_scratch_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Fused_multihead_attention_params_v1 : Fused_multihead_attention_params_base
{
    // The stride between rows of the Q, K and V matrices.
    int64_t qkv_stride_in_bytes;
    // The mask to implement drop-out.
    void* packed_mask_ptr;

    // The stride between matrices of packed mask.
    int64_t packed_mask_stride_in_bytes;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_base
{
    // The dimension of V. If unset, dv = d.
    int dv = 0;

    // The input to support any mask patterns.
    void* packed_mask_ptr;
    // The mask input's stride in the N (K-seq) dimension.
    int64_t packed_mask_stride_in_bytes;
    // The softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max.
    void* softmax_stats_ptr;
    // The stride between rows of softmax_stats_ptr.
    int64_t softmax_stats_stride_in_bytes;

    // Array of length b+1 holding the prefix sum of actual q sequence lengths.
    int* cu_q_seqlens;
    // Array of length b+1 holding the prefix sum of actual kv sequence lengths.
    int* cu_kv_seqlens;
    // Array of length b+1 holding the prefix sum of actual mask sequence lengths.
    // It might not be the same as cu_q_seqlens as the mask seqlens will be padded.
    int* cu_mask_rows;
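
    // Illustrative example (not from the original header): for a batch of b = 3 sequences with
    // q lengths {3, 5, 2}, the prefix-sum array would be cu_q_seqlens = {0, 3, 8, 10}, so sequence
    // i occupies token rows [cu_q_seqlens[i], cu_q_seqlens[i + 1]) of the packed input.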

    // TMA descriptors on device.
    // Either q in packed qkv layout [B, S, 3, H, D] or separate q layout [B, S, H, D].
    fmha::cudaTmaDesc tma_desc_q;
    // TMA descriptors for the packed/contiguous/paged kv cache.
    // Kv in packed qkv layout: [B, S, 3, H, D].
    // Contiguous kv layout: [B, 2, H, S, D].
    // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
    fmha::cudaTmaDesc tma_desc_k;
    fmha::cudaTmaDesc tma_desc_v;
    // TMA descriptor for O.
    fmha::cudaTmaDesc tma_desc_o;

    // Contiguous Q buffer pointer [B, S, H, D].
    void* q_ptr;
    // The separate K matrix.
    void* k_ptr;
    // The separate V matrix.
    void* v_ptr;
    // Contiguous KV buffer pointer [B, 2, H, S, D].
    void* kv_ptr;
    // Paged KV cache buffer.
    fmha::Kv_block_array paged_kv_cache;
    // Q and KV strides (used by LDGSTS).
    int64_t q_stride_in_bytes;
    int64_t k_stride_in_bytes;
    int64_t v_stride_in_bytes;

    // Paged KV load.
    int blocks_per_tma_load;
    int blocks_per_tma_load_log2;

    // M tile id counter for dynamic scheduling.
    uint32_t* tile_id_counter_ptr;
    uint32_t num_tiles;
    uint32_t num_tiles_per_head;
    bool use_balanced_scheduling;

    // In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head.
    int h_kv = 0;
    // h_q_per_kv is sometimes rematerialized in the kernel by the formula h / h_kv to reclaim one register.
    int h_q_per_kv = 1;
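
    // Illustrative example (not from the original header): a model with h = 32 query heads sharing
    // h_kv = 8 key/value heads would set h_q_per_kv = 32 / 8 = 4; plain multi-head attention has
    // h_kv == h and h_q_per_kv == 1.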

    // The number of grouped heads in the seqlen dimension.
    int num_grouped_heads = 1;

    // Sliding window attention:
    // only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
    int sliding_window_size = INT_MAX;

    // The chunked attention size in log2 (<= 0 means no chunked attention).
    int log2_chunked_attention_size = 0;

    // The softcapping scale (scale * tanh(x / scale)) applied to the bmm1 output.
    float softcapping_scale_bmm1 = 0.0f;
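
    // Illustrative sketch (not from the original header) of how the softcapping scale is applied to
    // a bmm1 logit x when the scale is non-zero; a value of 0.0f disables softcapping:
    //
    //   float softcap(float x, float scale)
    //   {
    //       return scale > 0.0f ? scale * tanhf(x / scale) : x;
    //   }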

    // Whether the input/output is padded.
    bool is_s_padded = false;

    struct SageAttention
    {
        struct Scales
        {
            // This field is only used in bin/fmha.exe and is omitted in the exported cubin.
            int block_size;
            // ceil(max_seqlen / block_size).
            int max_nblock;
            // The scale of each block, layout: (B, H, max_nblock).
            float* scales;
        } q, k, v;
    } sage;
};
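
// Illustrative sketch (assumption, not from the original header): with the (B, H, max_nblock)
// layout above, the SageAttention scale of the block containing token position `pos` of head `hi`
// in batch `bi` would be indexed roughly as follows; `num_heads`, `bi`, `hi` and `pos` are
// hypothetical names.
//
//   float q_scale = sage.q.scales[(bi * num_heads + hi) * sage.q.max_nblock + pos / sage.q.block_size];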

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Flags to control kernel choice.
struct Fused_multihead_attention_launch_params
{
    // Flags to control the small batch kernel choice.
    // true: never unroll.
    bool ignore_b1opt = false;
    // true: always unroll.
    bool force_unroll = false;
    // Use fp32 accumulation.
    bool force_fp32_acc = false;
    // The C/32 format.
    bool interleaved = false;
    // By default TMA is not used.
    bool use_tma = false;
    // Total number of q tokens, needed to set up the TMA descriptors.
    int total_q_seqlen = 0;
    // Total number of kv tokens, needed to set up the TMA descriptors.
    int total_kv_seqlen = 0;
    // Whether flash attention is used (only FP16).
    bool flash_attention = false;
    // Whether warp-specialized kernels are used (only SM90 HGMMA + TMA).
    bool warp_specialization = false;
    // Granular tiling flash attention kernels.
    bool use_granular_tiling = false;
    // Causal masking, sliding_or_chunked_causal masking, or dense (padding) masking.
    fmha::Attention_mask_type attention_mask_type = fmha::Attention_mask_type::PADDING;
    // The attention input layout.
    fmha::Attention_input_layout attention_input_layout = fmha::Attention_input_layout::PACKED_QKV;
    // enable_attn_logit_softcapping (choose kernels with softcapping_scale_bmm1).
    bool enable_attn_logit_softcapping = false;
    // Hardware properties used to determine how to launch blocks.
    int multi_processor_count = 0;
    int device_l2_cache_size = 0;
};
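
// Illustrative sketch (assumption, not part of the original header): how a caller might fill the
// launch flags for a Hopper warp-specialized flash-attention run over packed QKV with a causal
// mask. The surrounding setup (`device_prop` from cudaGetDeviceProperties, kernel selection) is
// hypothetical.
//
//   Fused_multihead_attention_launch_params launch_params{};
//   launch_params.flash_attention = true;
//   launch_params.warp_specialization = true;
//   launch_params.use_tma = true;
//   launch_params.force_unroll = true;
//   launch_params.attention_mask_type = fmha::Attention_mask_type::CAUSAL;
//   launch_params.attention_input_layout = fmha::Attention_input_layout::PACKED_QKV;
//   launch_params.multi_processor_count = device_prop.multiProcessorCount;
//   launch_params.device_l2_cache_size = device_prop.l2CacheSize;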

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace bert