// TensorRT-LLM/cpp/kernels/fmha_v2/src/fused_multihead_attention.h

/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda.h>
#include <fmha/alibi_params.h>
#include <fmha/hopper/tma_types.h>
#include <fmha/paged_kv_cache.h>
#include <fused_multihead_attention_utils.h>
#include <vector>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Make sure the mask input is padded to 128 x 256 tile size in order to
// match all Ampere/Hopper kernels.
static constexpr int FLASH_ATTEN_MASK_M_ALIGNMENT = 128;
static constexpr int FLASH_ATTEN_MASK_N_ALIGNMENT = 256;
// The packed mask's MMA tile size is 64 x 64.
static constexpr int FLASH_ATTEN_MASK_MMA_M = 64;
static constexpr int FLASH_ATTEN_MASK_MMA_N = 64;
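// For illustration: an actual mask of 100 (q) x 300 (kv) rounds up to 128 x 512 under the
// alignments above. A minimal sketch, assuming a hypothetical round_up helper (not part of
// this header):
//
//   constexpr int round_up(int x, int m) { return (x + m - 1) / m * m; }
//   // round_up(100, FLASH_ATTEN_MASK_M_ALIGNMENT) == 128
//   // round_up(300, FLASH_ATTEN_MASK_N_ALIGNMENT) == 512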
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class Attention_mask_type
{
// Mask the padded tokens.
PADDING = 0,
// Mask the padded tokens and all the tokens that come after in a sequence.
CAUSAL,
// Causal mask + attend only to a specific sliding window or chunk (see the predicate sketch
// after this enum).
SLIDING_OR_CHUNKED_CAUSAL,
// The custom mask input.
CUSTOM_MASK,
};
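// Illustrative predicates for whether a (q_idx, kv_idx) pair is attended under the mask types
// above (a sketch with hypothetical names, not part of this header; PADDING is handled via the
// sequence lengths and CUSTOM_MASK via packed_mask_ptr, so both are omitted here):
//
//   inline bool is_attended(Attention_mask_type type, int q_idx, int kv_idx, int window)
//   {
//       switch (type)
//       {
//       case Attention_mask_type::CAUSAL: return kv_idx <= q_idx;
//       case Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL:
//           // Only attend to [max(0, q_idx - window), q_idx].
//           return kv_idx <= q_idx && kv_idx >= q_idx - window;
//       default: return true;
//       }
//   }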
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline std::string mask_type_to_string(Attention_mask_type mask_type)
{
switch (mask_type)
{
case Attention_mask_type::PADDING: return "padding";
case Attention_mask_type::CAUSAL: return "causal";
case Attention_mask_type::SLIDING_OR_CHUNKED_CAUSAL: return "sliding_or_chunked_causal";
case Attention_mask_type::CUSTOM_MASK: return "custom_mask";
default: assert(false); return "";
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class Attention_input_layout
{
// QKV are packed into [B, S, 3, H, D] layout.
PACKED_QKV = 0,
// Q has contiguous [B, S, H, D] layout, while KV has contiguous [B, 2, H, S, D] layout.
CONTIGUOUS_Q_KV,
// Q has contiguous [B, S, H, D] layout, while the paged KV layout is blocks of indices with
// shape [B, 2, Blocks_per_Seq], where each index gives the block's distance from the pool ptr
// in global memory (see the address sketch after this enum).
Q_PAGED_KV,
// Q has [B, S, H, D] layout,
// K has [B, S, H_kv, D] layout,
// V has [B, S, H_kv, Dv] layout,
SEPARATE_Q_K_V,
};
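// For the Q_PAGED_KV layout above, a block index translates to an address roughly as follows
// (a sketch; the actual accessor lives in fmha::Kv_block_array):
//
//   char* block_ptr = pool_ptr + (int64_t) block_idx * bytes_per_block;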
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline std::string attention_input_layout_to_string(Attention_input_layout layout)
{
switch (layout)
{
case Attention_input_layout::PACKED_QKV: return "packed_qkv";
case Attention_input_layout::CONTIGUOUS_Q_KV: return "contiguous_q_kv";
case Attention_input_layout::Q_PAGED_KV: return "contiguous_q_paged_kv";
case Attention_input_layout::SEPARATE_Q_K_V: return "separate_q_k_v";
default: assert(false); return "";
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace bert
{
////////////////////////////////////////////////////////////////////////////////////////////////////
#if USE_DEMO_BERT_PARAMS
// TODO TRT plugins use a different parameter struct taken from the old XMMA fork.
// Until all cubins in the plugin are replaced with new kernels, we need to conform to that.
#include <fused_multihead_attention_demo_bert_params.h>
#else
struct Fused_multihead_attention_params_base
{
// The QKV matrices.
void* qkv_ptr;
// The O matrix (output).
void* o_ptr;
// The stride between rows of O.
int64_t o_stride_in_bytes;
#if defined(STORE_P)
// The pointer to the P matrix (for debugging).
void* p_ptr;
// The stride between rows of the P matrix (for debugging).
int64_t p_stride_in_bytes;
#endif // defined(STORE_P)
#if defined(STORE_S)
// The pointer to the S matrix (for debugging).
void* s_ptr;
// The stride between rows of the S matrix (for debugging).
int64_t s_stride_in_bytes;
#endif // defined(STORE_S)
#if defined(DEBUG_HAS_PRINT_BUFFER)
void* print_ptr;
#endif
// The dimensions.
int b, h, s, d;
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// The bmm1 and bmm2 scaling factors on the device.
uint32_t* scale_bmm1_d;
uint32_t* scale_bmm2_d;
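// Note: the uint32_t scales above typically hold a raw bit pattern (e.g. the IEEE-754 bits of
// a float; int8 paths may use other encodings). A host-side sketch under that assumption
// (this header does not mandate the encoding):
//
//   float scale = 1.f / sqrtf((float) d);
//   uint32_t scale_bmm1;
//   memcpy(&scale_bmm1, &scale, sizeof(scale));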
// Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
bool enable_i2f_trick;
// true: for int8, instead of doing max reduce, use max value encoded in scale factor
bool use_int8_scale_max = false;
// Whether the kernel uses ALiBi.
bool has_alibi = false;
fmha::AlibiParams alibi_params;
// The number of heads computed by one iteration of the wave.
int heads_per_wave;
// Buffers to perform a global sync and a critical section.
int *counters, *max_barriers, *sum_barriers, *locks;
// Scratch buffers to finalize softmax.
float *max_scratch_ptr, *sum_scratch_ptr;
// Scratch buffer to finalize the output (not needed for FP16).
int* o_scratch_ptr;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fused_multihead_attention_params_v1 : Fused_multihead_attention_params_base
{
// The stride between rows of the Q, K and V matrices.
int64_t qkv_stride_in_bytes;
// The mask to implement drop-out.
void* packed_mask_ptr;
// The stride between matrices of packed mask.
int64_t packed_mask_stride_in_bytes;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_base
{
// The dimension of V. If unset, dv = d.
int dv = 0;
// The input to support any mask patterns.
void* packed_mask_ptr;
// The mask input's stride in the N (K-seq) dimension.
int64_t packed_mask_stride_in_bytes;
// The Softmax stats vector of layout [total_tokens_q, h, 2], including softmax_max and softmax_sum
void* softmax_stats_ptr;
// The stride between rows of softmax_stats_ptr, default: h * sizeof(float2)
int64_t softmax_stats_stride_in_bytes;
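// E.g. with h = 16 heads, the default stride is 16 * sizeof(float2) = 128 bytes per q token.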
// The attention sinks (per head).
float* attention_sinks;
// Array of length b+1 holding the prefix sum of actual q sequence lengths.
int* cu_q_seqlens;
// Array of length b+1 holding the prefix sum of actual kv sequence lengths.
int* cu_kv_seqlens;
// Array of length b+1 holding the prefix sum of actual mask sequence lengths.
// It might not be the same as cu_q_seqlens, as the mask seqlens will be padded.
int* cu_mask_rows;
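// For example, with b = 3 and q lengths {5, 7, 4}, cu_q_seqlens holds {0, 5, 12, 16}.
// A host-side sketch to build such a prefix sum (illustrative names):
//
//   std::vector<int> cu(b + 1, 0);
//   for (int i = 0; i < b; ++i)
//       cu[i + 1] = cu[i] + seqlens[i];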
// TMA descriptors on device.
// Either q in the packed qkv layout [B, S, 3, H, D] or the separate q layout [B, S, H, D].
fmha::cudaTmaDesc tma_desc_q;
// TMA descriptors for the packed/contiguous/paged kv cache.
// KV in the packed qkv layout: [B, S, 3, H, D].
// Contiguous kv layout: [B, 2, H, S, D].
// Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
fmha::cudaTmaDesc tma_desc_k;
fmha::cudaTmaDesc tma_desc_v;
// TMA descriptor for o.
fmha::cudaTmaDesc tma_desc_o;
// Contiguous Q buffer pointer [B, S, H, D].
void* q_ptr;
// The separate K matrix.
void* k_ptr;
// The separate V matrix.
void* v_ptr;
// Contiguous KV buffer pointer [B, 2, H, S, D].
void* kv_ptr;
// Paged KV Cache buffer.
fmha::Kv_block_array paged_kv_cache;
// Q and KV stride (used by LDGSTS).
int64_t q_stride_in_bytes;
int64_t k_stride_in_bytes;
int64_t v_stride_in_bytes;
// Paged KV load.
int blocks_per_tma_load;
int blocks_per_tma_load_log2;
// M tile id counter for dynamic scheduling
uint32_t* tile_id_counter_ptr;
uint32_t num_tiles;
uint32_t num_tiles_per_head;
bool use_balanced_scheduling;
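// With dynamic scheduling, CTAs typically grab the next M tile via an atomic counter, e.g.
// (a common pattern, shown here as an assumption about usage):
//
//   uint32_t tile_id = atomicAdd(tile_id_counter_ptr, 1u); // done once tile_id >= num_tiles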
// In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one
// KV head (see the mapping example below).
int h_kv = 0;
// h_q_per_kv is sometimes rematerialized in the kernel via the formula h / h_kv to reclaim one
// register.
int h_q_per_kv = 1;
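// E.g. h = 32 q heads with h_kv = 8 kv heads gives h_q_per_kv = 4, and q head i reads kv head
// i / h_q_per_kv.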
// The number of grouped heads in the seqlen dimension.
int num_grouped_heads = 1;
// Sliding Window Attention
// Only pay attention to [max(0, query_idx - sliding_window_size), query_idx].
int sliding_window_size = INT_MAX;
// The log2 of the chunked attention size (<= 0 means no chunked attention).
int log2_chunked_attention_size = 0;
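// Illustrative chunked predicate: with log2_chunked_attention_size = 6 (chunks of 64 tokens),
// kv_idx is attended iff (kv_idx >> 6) == (q_idx >> 6), in addition to the causal condition.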
// The softcapping scale (scale * tanh(x / scale)) applied to the bmm1 output.
float softcapping_scale_bmm1 = 0.0f;
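// E.g. softcapping_scale_bmm1 = 50.f smoothly bounds the logits to (-50, 50), since
// scale * tanh(x / scale) is confined to (-scale, scale); 0.0f disables softcapping.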
// Whether the input/output is padded.
bool is_s_padded = false;
struct SageAttention
{
struct Scales
{
// This field is only used by bin/fmha.exe; it is omitted in the exported cubin.
int block_size;
// ceil(max_seqlen / block_size)
int max_nblock;
// The scale of each block, layout: (B, H, max_nblock)
float* scales;
} q, k, v;
} sage;
// Skip softmax when exp(local_max - global_max) < skip_softmax_threshold_scale_factor / seqlen.
// A positive value means skip-softmax is enabled.
float skip_softmax_threshold_scale_factor = 0;
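// E.g. with skip_softmax_threshold_scale_factor = 1e-3f and a kv seqlen of 4096, a block is
// skipped once exp(local_max - global_max) < 1e-3f / 4096 ~= 2.44e-7, i.e. when its
// contribution to the softmax sum is negligible.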
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax; pointers to device memory for output.
uint32_t* skip_softmax_total_blocks;
uint32_t* skip_softmax_skipped_blocks;
#endif
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// flags to control kernel choice
struct Fused_multihead_attention_launch_params
{
// flags to control small batch kernel choice
// true: never unroll
bool ignore_b1opt = false;
// true: always unroll
bool force_unroll = false;
// use fp32 accumulation
bool force_fp32_acc = false;
// the C/32 format
bool interleaved = false;
// by default TMA is not used.
bool use_tma = false;
// total number of q tokens to set tma descriptors
int total_q_seqlen = 0;
// total number of kv tokens to set tma descriptors
int total_kv_seqlen = 0;
// if flash attention is used (only FP16)
bool flash_attention = false;
// if warp_specialized kernels are used (only SM90 HGMMA + TMA)
bool warp_specialization = false;
// granular tiling flash attention kernels
bool use_granular_tiling = false;
// causal masking, sliding_or_chunked_causal masking, or dense (padding) masking.
fmha::Attention_mask_type attention_mask_type = fmha::Attention_mask_type::PADDING;
// the attention input layout.
fmha::Attention_input_layout attention_input_layout = fmha::Attention_input_layout::PACKED_QKV;
// enable_attn_logit_softcapping (choose kernels with softcapping_scale_bmm1).
bool enable_attn_logit_softcapping = false;
// hardware properties to determine how to launch blocks
int multi_processor_count = 0;
int device_l2_cache_size = 0;
// skip-softmax attention
bool enable_skip_softmax = false;
};
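////////////////////////////////////////////////////////////////////////////////////////////////////
// A minimal host-side setup sketch (illustrative only; the launcher API that consumes these
// structs lives elsewhere in fmha_v2):
//
//   bert::Fused_multihead_attention_launch_params launch_params{};
//   launch_params.flash_attention = true;
//   launch_params.force_unroll = true;
//   launch_params.attention_mask_type = fmha::Attention_mask_type::CAUSAL;
//   launch_params.attention_input_layout = fmha::Attention_input_layout::PACKED_QKV;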
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace bert