/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include <fmha/alibi_params.h>
#include <fmha/fragment.h>
#include <fmha/gemm.h>
#include <fmha/gmem_tile_o.h>
#include <fmha/gmem_tile_o_packed.h>
#include <fmha/gmem_tile_ps.h>
#include <fmha/gmem_tile_qkv.h>
#include <fmha/gmem_tile_qkv_packed.h>
#include <fmha/smem_tile_o.h>
#include <fmha/smem_tile_qkv.h>
#include <fmha/smem_tile_v.h>
#include <fmha/softmax.h>
#include <fmha/traits.h>
#include <fmha/utils.h>
namespace fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// Ada HMMA/IMMA kernels reuse the Ampere instruction traits.
template <typename Traits_>
struct Traits_reuse
{
using Traits = Traits_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Traits_reuse<fmha::Ada_hmma_fp16_traits>
{
using Traits = fmha::Ampere_hmma_fp16_traits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Traits_reuse<fmha::Ada_hmma_fp32_traits>
{
using Traits = fmha::Ampere_hmma_fp32_traits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Traits_reuse<fmha::Ada_imma_int8_int32_traits>
{
using Traits = fmha::Ampere_imma_int8_int32_traits;
};
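////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch: the specializations above mean a kernel instantiated with Ada traits
// compiles against the corresponding Ampere instruction traits. A hypothetical check, assuming
// <type_traits> is included at the point of use:
//
//   static_assert(std::is_same<Traits_reuse<fmha::Ada_hmma_fp16_traits>::Traits,
//                              fmha::Ampere_hmma_fp16_traits>::value,
//                 "Ada HMMA fp16 reuses the Ampere HMMA fp16 traits");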
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits_p, bool FORCE_EPILOGUE_FP16>
struct Traits_o_adapter
{
using Traits = Traits_p;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool FORCE_EPILOGUE_FP16>
struct Traits_o_adapter<fmha::Volta_hmma_fp16_traits, FORCE_EPILOGUE_FP16>
{
using Traits = fmha::Volta_hmma_fp16_16x16x16_traits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// convert to fp16 before smem_o store
template <>
struct Traits_o_adapter<fmha::Ampere_hmma_fp32_traits, true>
{
using Traits = fmha::Ampere_hmma_fp16_traits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// convert to fp16 before smem_o store
template <>
struct Traits_o_adapter<fmha::Turing_hmma_fp32_traits, true>
{
using Traits = fmha::Turing_hmma_fp16_traits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// convert to bf16 before smem_o store
template <>
struct Traits_o_adapter<fmha::Ampere_hmma_bf16_traits, true>
{
using Traits = fmha::Ampere_hmma_bf16_bf16_traits;
};
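////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch: when BMM2_FP16_EPILOGUE is true, an fp32-accumulated BMM1 is paired with a
// half-precision epilogue for BMM2, e.g.
//
//   using Traits_e = Traits_o_adapter<fmha::Ampere_hmma_fp32_traits, true>::Traits;
//   // Traits_e is fmha::Ampere_hmma_fp16_traits, so O is converted to fp16 before the smem_o store.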
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// Instruction traits.
typename Traits_,
// The global memory tile for Q, K and V.
template <typename, typename, int, int, int, int, bool, bool, int, bool> class Gmem_tile_q_,
template <typename, typename, int, int, int, int, bool, bool, int, bool> class Gmem_tile_k_,
template <typename, typename, int, int, int, int, bool, bool, int, bool> class Gmem_tile_v_,
// The global memory tile for the output.
template <typename, typename, int> class Gmem_tile_o_,
// Sequence length.
int S,
// The valid hidden dimension.
int VALID_D_,
// The valid hidden dimension of V.
int VALID_DV_,
// The iteration step of the outer loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD_,
// The flags to control the behaviour of LDGs.
uint32_t FLAGS,
// The version of the kernel.
int VERSION_,
// The mask version of the kernel
int MASK_VERSION_,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true,
// The SageAttention block sizes for Q, K and V (non-positive means disabled).
int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0>
struct Kernel_traits_
{
// The instruction traits for the Q*K product.
using Traits_p = typename Traits_reuse<Traits_>::Traits;
// The instruction traits for the P*V product. Hack to change the traits for Volta HMMA.
using Traits_o = typename Traits_o_adapter<Traits_p, false>::Traits;
// The instruction traits for the epilogue of the 2nd GEMM. Uses the half-precision (fp16/bf16) epilogue traits when BMM2_FP16_EPILOGUE is set.
using Traits_e = typename Traits_o_adapter<Traits_p, BMM2_FP16_EPILOGUE>::Traits;
// The padded D dimension
enum
{
VALID_D = VALID_D_
};
enum
{
D = Next_power_of_two<VALID_D>::VALUE
};
enum
{
VALID_DV = VALID_DV_ > 0 ? VALID_DV_ : VALID_D
};
enum
{
DV = Next_power_of_two<VALID_DV>::VALUE
};
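// Example of the padding above (illustrative numbers): VALID_D_ = 72 pads to D = 128, and a
// non-positive VALID_DV_ falls back to VALID_D, so DV = 128 as well.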
enum
{
SAGE_ATTENTION = SAGE_BLOCK_SIZE_Q_ > 0 || SAGE_BLOCK_SIZE_K_ > 0 || SAGE_BLOCK_SIZE_V_ > 0
};
enum
{
SAGE_BLOCK_SIZE_Q = SAGE_BLOCK_SIZE_Q_
};
enum
{
SAGE_BLOCK_SIZE_K = SAGE_BLOCK_SIZE_K_
};
enum
{
SAGE_BLOCK_SIZE_V = SAGE_BLOCK_SIZE_V_
};
// TODO: expose these tiling params to the interface
enum
{
USE_GRANULAR_TILING = (FLAGS & 0x1000) != 0u
}; // TODO ANT: check FLAGS
using Traits_tile_size = Traits_tile_size<(bool) USE_GRANULAR_TILING, STEP, S, D, DV, Traits_o::K_PER_MMA>;
enum
{
CTA_P_TILE_M = Traits_tile_size::CTA_P_TILE_M
};
enum
{
CTA_P_TILE_N = Traits_tile_size::CTA_P_TILE_N
};
enum
{
CTA_P_TILE_K = Traits_tile_size::CTA_P_TILE_K
};
enum
{
CTA_O_TILE_M = Traits_tile_size::CTA_O_TILE_M
};
enum
{
CTA_O_TILE_N = Traits_tile_size::CTA_O_TILE_N
};
enum
{
CTA_O_TILE_K = Traits_tile_size::CTA_O_TILE_K
};
// Do we need to reload Q because the D dimension is split across K-slices?
enum
{
RELOAD_Q = static_cast<int>(CTA_P_TILE_K) != static_cast<int>(D)
};
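// For example (illustrative numbers): if granular tiling sets CTA_P_TILE_K = 64 while D = 128, the
// Q fragments have to be re-staged for the second 64-wide K-slice of BMM1, so RELOAD_Q is set.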
// The CTA description for the 1st GEMM.
using Cta_tile_p = typename Traits_p::template Cta_tile_extd<CTA_P_TILE_M, CTA_P_TILE_N, CTA_P_TILE_K, S, VALID_D,
WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = typename Traits_o::template Cta_tile_extd<CTA_O_TILE_M, CTA_O_TILE_N, CTA_O_TILE_K, VALID_DV, S,
WARPS_M, 1, WARPS_N>;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = typename Traits_p::template Mma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = typename Traits_o::template Mma_tile<Cta_tile_o>;
// Compute the total BMM2_MMAS_K (might not be the same as Mma_tile_o::MMAS_K if granular tiling is used).
static_assert(S % CTA_O_TILE_K == 0, "");
enum
{
TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K)
};
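// Worked example (hypothetical numbers): with S = 256 and CTA_O_TILE_K = 64, the K dimension of
// BMM2 is covered in 4 slices, so TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * 4.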
// Constraints on the K dimension.
static_assert(Mma_tile_p::K_PER_MMA <= static_cast<int>(D));
static_assert(Mma_tile_o::K_PER_MMA <= S);
// The version.
enum
{
VERSION = VERSION_
};
// The mask version: padding (2), causal (3), sliding_window_causal (4), custom_mask (5).
enum
{
MASK_VERSION = MASK_VERSION_
};
// Whether to use the causal mask or not.
enum
{
CAUSAL_MASK = MASK_VERSION_ == 3 || MASK_VERSION_ == 4
};
// Whether to use sliding window attention or not.
enum
{
SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4
};
// Whether to use the custom mask or not.
enum
{
CUSTOM_MASK = MASK_VERSION_ == 5
};
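// Quick reference for the FLAGS bits decoded in this struct (see the individual enums):
//   0x0001 USE_LDGSTS_Q              0x0002 USE_LDGSTS_K            0x0004 USE_LDGSTS_V
//   0x0008 SHARE_SMEM_FOR_K_AND_V    0x0010 USE_SCALE_MAX           0x0020 heads NOT interleaved
//   0x0040 K NOT kept in registers   0x0080 LIMIT_QK_FRAGMENTS      0x0100 LIMIT_V_FRAGMENTS
//   0x0200 NO_LOOP                   0x0400 SEQUENCES_INTERLEAVED   0x0800 ENABLE_BMM1_SOFTCAPPING_SCALE
//   0x1000 USE_GRANULAR_TILING       0x2000 IS_MTP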
// Do we use LDGSTS for Q, K or V.
enum
{
USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u
};
enum
{
USE_LDGSTS_K = (FLAGS & 0x2u) != 0u
};
enum
{
USE_LDGSTS_V = (FLAGS & 0x4u) != 0u
};
// Do we use one buffer for K and V.
enum
{
SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u
};
// Do we use the scale max trick.
enum
{
USE_SCALE_MAX = (FLAGS & 0x10u) != 0u
};
// Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d.
enum
{
HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u
};
// Keep full K matrix in registers.
enum
{
K_IN_REGS = (FLAGS & 0x40) == 0u
};
// Do we use only 2 fragments or full fragments for frag_q/k (only used by flash attention)
enum
{
LIMIT_QK_FRAGMENTS = ((FLAGS & 0x80u) != 0u && !SHARE_SMEM_FOR_K_AND_V)
};
// Do we use only 2 fragments or full fragments for frag_v (only used by flash attention)
enum
{
LIMIT_V_FRAGMENTS = ((FLAGS & 0x100u) != 0u && !SHARE_SMEM_FOR_K_AND_V)
};
// Limiting QK fragments requires K to stay resident in its own SMEM buffer, so it cannot share a buffer with V.
static_assert(!(LIMIT_QK_FRAGMENTS && SHARE_SMEM_FOR_K_AND_V), "");
// Indicates that the kernel does not loop over the Q tensor; such kernels usually carry the _nl suffix.
enum
{
NO_LOOP = (FLAGS & 0x200u) != 0u
};
// Are sequences in one batch interleaved, i.e. s x b x ... rather than b x s x ...
enum
{
SEQUENCES_INTERLEAVED = (FLAGS & 0x400) != 0u
};
// Use BMM1 softcapping scale or not.
enum
{
ENABLE_BMM1_SOFTCAPPING_SCALE = (FLAGS & 0x800) != 0u
};
// Use MTP (multi-token prediction for MLA kernels) or not.
enum
{
IS_MTP = (FLAGS & 0x2000) != 0u
};
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
enum
{
CTAS_PER_HEAD = CTAS_PER_HEAD_
};
// The number of shared memory buffers to build a software pipeline for Q, K and V.
enum
{
BUFFERS_PER_TILE_SMEM_Q = (USE_GRANULAR_TILING && D > 64) || (USE_LDGSTS_Q && !NO_LOOP) ? 2 : 1
};
enum
{
BUFFERS_PER_TILE_SMEM_K = USE_GRANULAR_TILING ? 2 : 1
};
enum
{
BUFFERS_PER_TILE_SMEM_V = USE_GRANULAR_TILING ? 2 : 1
};
// The global memory tile to load Q.
using Gmem_tile_q = Gmem_tile_q_<Traits_p, Cta_tile_p, Traits_p::BITS_PER_ELEMENT_A, CTA_P_TILE_M, CTA_P_TILE_K,
VALID_D, USE_LDGSTS_Q, HEADS_INTERLEAVED,
3, // NUM_MATS
SLIDING_WINDOW_ATTENTION // Not used.
>;
// The shared memory tile to swizzle Q.
using Smem_tile_q
= fmha::Smem_tile_a<Traits_p, Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, BUFFERS_PER_TILE_SMEM_Q>;
// The global memory tile to load K.
using Gmem_tile_k = Gmem_tile_k_<Traits_p, Cta_tile_p, Traits_p::BITS_PER_ELEMENT_B, CTA_P_TILE_N, CTA_P_TILE_K,
VALID_D, USE_LDGSTS_K, HEADS_INTERLEAVED,
3, // NUM_MATS
SLIDING_WINDOW_ATTENTION>;
// The shared memory tile to swizzle K.
using Smem_tile_k
= fmha::Smem_tile_b<Traits_p, Cta_tile_p, fmha::Col, Gmem_tile_k::BYTES_PER_LDG, BUFFERS_PER_TILE_SMEM_K>;
// The global memory tile to load V.
using Gmem_tile_v = Gmem_tile_v_<Traits_o, Cta_tile_o, Traits_o::BITS_PER_ELEMENT_B, CTA_O_TILE_K, CTA_O_TILE_N,
VALID_DV, USE_LDGSTS_V, HEADS_INTERLEAVED,
3, // NUM_MATS
SLIDING_WINDOW_ATTENTION>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v<Traits_o, Cta_tile_o, BUFFERS_PER_TILE_SMEM_V>;
// The global memory tile to store O.
using Gmem_tile_o = Gmem_tile_o_<Traits_e, Cta_tile_o, CTAS_PER_HEAD>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o<Traits_e, Cta_tile_o>;
// Make sure the number of threads per row matches.
static_assert((int) Gmem_tile_o::THREADS_PER_ROW == (int) Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
enum
{
THREADS = Cta_tile_p::THREADS_PER_CTA
};
// Make sure the number of threads matches both CTAs.
static_assert((int) THREADS == (int) Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
enum
{
BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE
};
// The extra amount of shared memory needed to load V.
enum
{
BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE
};
// The amount of shared memory needed for Q, K and V.
enum
{
BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V
};
// The amount of shared memory needed to load/store O.
enum
{
BYTES_PER_SMEM_O = Smem_tile_o::BYTES_PER_TILE
};
// The amount of shared memory needed to load Q and store O.
enum
{
BYTES_PER_SMEM_QO = NO_LOOP ? Smem_tile_o::BYTES_PER_TILE : Smem_tile_q::BYTES_PER_TILE + BYTES_PER_SMEM_O
};
// The amount of shared memory needed for Q, K, V and O.
enum
{
BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE
};
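// Worked example (hypothetical sizes, NO_LOOP disabled): if Smem_tile_q, Smem_tile_k, Smem_tile_v
// and Smem_tile_o each take 16 KiB and K/V do not share a buffer, then BYTES_PER_SMEM_QKV = 48 KiB,
// BYTES_PER_SMEM_QO = 32 KiB and BYTES_PER_SMEM = 48 KiB.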
// Make sure we have enough shared memory.
static_assert((NO_LOOP ? Smem_tile_o::BYTES_PER_TILE : Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE)
<= BYTES_PER_SMEM,
"");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// Instruction traits.
typename Traits_,
// The global memory tile for Q, K and V.
template <typename, typename, int, int, int, bool, bool, int> class Gmem_tile_q_,
// The global memory tile for the output.
template <typename, typename, int> class Gmem_tile_o_,
// Sequence length for K/V.
int S_KV,
// The hidden dimension.
int D,
// The iteration step of the outer loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD_,
// The flags to control the behaviour of LDGs.
uint32_t FLAGS,
// The version of the kernel.
int VERSION_,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true>
struct Kernel_traits_fmhca_
{
// The instruction traits for the Q*K product.
using Traits_p = typename Traits_reuse<Traits_>::Traits;
// The instruction traits for the P*V product. Hack to change the traits for Volta HMMA.
using Traits_o = typename Traits_o_adapter<Traits_p, false>::Traits;
// The instruction traits for the epilogue of the 2nd GEMM. Uses the half-precision (fp16/bf16) epilogue traits when BMM2_FP16_EPILOGUE is set.
using Traits_e = typename Traits_o_adapter<Traits_p, BMM2_FP16_EPILOGUE>::Traits;
// The CTA description for the 1st GEMM.
using Cta_tile_p = typename Traits_p::template Cta_tile_extd<STEP, S_KV, D, S_KV, D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = typename Traits_o::template Cta_tile_extd<STEP, D, S_KV, D, S_KV, WARPS_M, 1, WARPS_N>;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = typename Traits_p::template Mma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = typename Traits_o::template Mma_tile<Cta_tile_o>;
// Constraints on the K dimension.
static_assert(Mma_tile_p::K_PER_MMA <= D, "");
static_assert(Mma_tile_o::K_PER_MMA <= S_KV, "");
// The version.
enum
{
VERSION = VERSION_
};
// The mask version (takes the same value as the kernel version).
enum
{
MASK_VERSION = VERSION_
};
// Whether to use the causal mask or not.
enum
{
CAUSAL_MASK = MASK_VERSION >= 3
};
// Whether to use sliding window attention or not.
enum
{
SLIDING_WINDOW_ATTENTION = MASK_VERSION == 4
};
// Do we use LDGSTS for Q, K or V.
enum
{
USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u
};
enum
{
USE_LDGSTS_K = (FLAGS & 0x2u) != 0u
};
enum
{
USE_LDGSTS_V = (FLAGS & 0x4u) != 0u
};
// Do we use one buffer for K and V.
enum
{
SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u
};
// Do we use the scale max trick.
enum
{
USE_SCALE_MAX = (FLAGS & 0x10u) != 0u
};
// Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d.
enum
{
HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u
};
// Keep full K matrix in registers.
enum
{
K_IN_REGS = (FLAGS & 0x40) == 0u
};
// Use BMM1 softcapping scale or not.
enum
{
ENABLE_BMM1_SOFTCAPPING_SCALE = 0
};
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
enum
{
CTAS_PER_HEAD = CTAS_PER_HEAD_
};
// The global memory tile to load Q.
using Gmem_tile_q
= Gmem_tile_q_<Traits_p, Cta_tile_p, Traits_p::BITS_PER_ELEMENT_A, STEP, D, USE_LDGSTS_Q, HEADS_INTERLEAVED,
1 // NUM_MATS
>;
// The shared memory tile to swizzle Q.
using Smem_tile_q
= fmha::Smem_tile_a<Traits_p, Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, USE_LDGSTS_Q ? 2 : 1>;
// The global memory tile to load K.
using Gmem_tile_k
= Gmem_tile_q_<Traits_p, Cta_tile_p, Traits_p::BITS_PER_ELEMENT_B, S_KV, D, USE_LDGSTS_K, HEADS_INTERLEAVED,
2 // NUM_MATS
>;
// The shared memory tile to swizzle K.
using Smem_tile_k = fmha::Smem_tile_b<Traits_p, Cta_tile_p, fmha::Col>;
// The global memory tile to load V.
using Gmem_tile_v
= Gmem_tile_q_<Traits_o, Cta_tile_o, Traits_o::BITS_PER_ELEMENT_B, S_KV, D, USE_LDGSTS_V, HEADS_INTERLEAVED,
2 // NUM_MATS
>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v<Traits_o, Cta_tile_o>;
// The global memory tile to store O.
using Gmem_tile_o = Gmem_tile_o_<Traits_e, Cta_tile_o, CTAS_PER_HEAD>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o<Traits_e, Cta_tile_o>;
// Make sure the number of threads per row matches.
static_assert((int) Gmem_tile_o::THREADS_PER_ROW == (int) Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
enum
{
THREADS = Cta_tile_p::THREADS_PER_CTA
};
// Make sure the number of threads matches both CTAs.
static_assert((int) THREADS == (int) Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
enum
{
BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE
};
// The extra amount of shared memory needed to load V.
enum
{
BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE
};
// The amount of shared memory needed for Q, K and V.
enum
{
BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V
};
// The amount of shared memory needed to load Q and store O.
enum
{
BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE
};
// The amount of shared memory needed for Q, K, V and O.
enum
{
BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE
};
// Make sure we have enough shared memory.
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits_,
// The sequence length.
int S,
// The hidden size per head.
int VALID_D,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD_,
// The flags.
uint32_t FLAGS = 0x8,
// The mask version of the kernel
int MASK_VERSION_ = 2>
struct Kernel_traits_interleaved_v2_
{
// The instruction traits.
using Traits = typename Traits_reuse<Traits_>::Traits;
using Traits_p = Traits;
using Traits_o = Traits;
// The padded D dimension
enum
{
D = Next_power_of_two<VALID_D>::VALUE
};
// The CTA description for the 1st GEMM.
using Cta_tile_p = typename Traits::template Cta_tile_extd<STEP, S, D, S, VALID_D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = typename Traits::template Cta_tile_extd<STEP, D, S, VALID_D, S, WARPS_M, 1, WARPS_N>;
// The version.
enum
{
VERSION = 2
};
enum
{
MASK_VERSION = MASK_VERSION_
};
// Whether to use the causal mask or not.
enum
{
CAUSAL_MASK = MASK_VERSION_ >= 3
};
// Whether to use sliding window attention or not.
enum
{
SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4
};
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
enum
{
CTAS_PER_HEAD = CTAS_PER_HEAD_
};
// Do we use LDGSTS for Q, K or V.
enum
{
USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u
};
enum
{
USE_LDGSTS_K = (FLAGS & 0x2u) != 0u
};
enum
{
USE_LDGSTS_V = (FLAGS & 0x4u) != 0u
};
// Do we use one buffer for K and V.
enum
{
SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u
};
// Do we use the scale max trick.
enum
{
USE_SCALE_MAX = (FLAGS & 0x10u) != 0u
};
// The global memory tile to load Q.
using Gmem_tile_q
= fmha::v2::Gmem_tile_qkv_interleaved<Traits, Cta_tile_p, Traits::BITS_PER_ELEMENT_A, STEP, D, USE_LDGSTS_Q>;
// The shared memory tile to swizzle Q.
using Smem_tile_q = fmha::Smem_tile_qk_interleaved_a<Traits, Cta_tile_p>;
// The global memory tile to load K.
using Gmem_tile_k
= fmha::v2::Gmem_tile_qkv_interleaved<Traits, Cta_tile_p, Traits::BITS_PER_ELEMENT_B, S, D, USE_LDGSTS_K>;
// The shared memory tile to swizzle K.
using Smem_tile_k = fmha::Smem_tile_qk_interleaved_b<Traits, Cta_tile_p>;
// The global memory tile to load V.
using Gmem_tile_v
= fmha::v2::Gmem_tile_qkv_interleaved<Traits, Cta_tile_o, Traits::BITS_PER_ELEMENT_B, S, D, USE_LDGSTS_V>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v_interleaved_b<Traits, Cta_tile_o>;
// The global memory tile to store O.
using Gmem_tile_o = fmha::v2::Imma_gmem_tile_o_interleaved<Traits, Cta_tile_o, CTAS_PER_HEAD>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o_interleaved<Traits, Cta_tile_o>;
// Make sure the number of threads per row matches.
static_assert((int) Gmem_tile_o::THREADS_PER_ROW == (int) Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
enum
{
THREADS = Cta_tile_p::THREADS_PER_CTA
};
// Make sure the number of threads matches both CTAs.
static_assert((int) THREADS == (int) Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
enum
{
BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE
};
// The extra amount of shared memory needed to load V.
enum
{
BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE
};
// The amount of shared memory needed for Q, K and V.
enum
{
BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V
};
// The amount of shared memory needed to load Q and store O.
enum
{
BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE
};
// The amount of shared memory needed for Q, K, V and O.
enum
{
BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE
};
// Make sure we have enough shared memory.
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits_,
// The sequence length.
int S,
// The hidden size per head.
int VALID_D,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD_,
// The flags.
uint32_t FLAGS = 0x8,
// The mask version of the kernel
int MASK_VERSION_ = 2>
using Kernel_traits_interleaved_v2
= Kernel_traits_interleaved_v2_<Traits_, S, VALID_D, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD_, FLAGS, MASK_VERSION_>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8>
using Kernel_traits_v1 = Kernel_traits_<Traits, fmha::v1::Gmem_tile_qkv, fmha::v1::Gmem_tile_qkv,
fmha::v1::Gmem_tile_qkv, fmha::v1::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 1, 1>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8>
using Kernel_traits_v1_causal_mask = Kernel_traits_<Traits, fmha::v1::Gmem_tile_qkv, fmha::v1::Gmem_tile_qkv,
fmha::v1::Gmem_tile_qkv, fmha::v1::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
1, // VERSION_
3>; // MASK_VERSION_
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Traits_, typename OutputType>
struct Gmem_tile_o_dispatcher
{
template <typename Traits, typename Cta_tile, int CTAS_PER_HEAD>
using Gmem_tile_o = fmha::v2::Gmem_tile_o<Traits, Cta_tile, CTAS_PER_HEAD>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Gmem_tile_o_dispatcher<fmha::Ada_qmma_e4m3_fp32_traits, uint16_t>
{
template <typename Traits, typename Cta_tile, int CTAS_PER_HEAD>
using Gmem_tile_o = fmha::v2::Gmem_tile_o_uint16<Traits, Cta_tile, CTAS_PER_HEAD>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct Gmem_tile_o_dispatcher<fmha::Ada_qmma_e4m3_fp32_traits, nv_bfloat16>
{
template <typename Traits, typename Cta_tile, int CTAS_PER_HEAD>
using Gmem_tile_o = fmha::v2::Gmem_tile_o_bfloat16<Traits, Cta_tile, CTAS_PER_HEAD>;
};
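////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (Some_traits and Some_cta_tile are placeholder names): the primary template
// yields the default fmha::v2::Gmem_tile_o, while the e4m3 QMMA specializations select a 16-bit
// output tile that matches the requested output type, e.g.
//
//   using O_tile = Gmem_tile_o_dispatcher<fmha::Ada_qmma_e4m3_fp32_traits, nv_bfloat16>
//       ::Gmem_tile_o<Some_traits, Some_cta_tile, /*CTAS_PER_HEAD=*/1>; // -> Gmem_tile_o_bfloat16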
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The hidden dimension of V.
int DV,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8,
// The attention mask version (see src/mask.h).
int MASK_VERSION = 2,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true,
// The output type.
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2 = Kernel_traits_<Traits, fmha::v2::Gmem_tile_qkv, fmha::v2::Gmem_tile_qkv,
fmha::v2::Gmem_tile_qkv, Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N,
CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
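////////////////////////////////////////////////////////////////////////////////////////////////////
// Hypothetical instantiation sketch (parameter values are illustrative, not taken from a shipped
// kernel): a packed-QKV fp16 v2 configuration might look like
//
//   using My_kernel_traits = fmha::Kernel_traits_v2<fmha::Ampere_hmma_fp16_traits,
//       /*S=*/256, /*D=*/64, /*DV=*/0, /*STEP=*/64, /*WARPS_M=*/4, /*WARPS_N=*/1,
//       /*CTAS_PER_HEAD=*/1, /*FLAGS=*/0x0Fu>;
//
// where FLAGS = 0x0F enables LDGSTS for Q/K/V and shares the K/V SMEM buffer (see the FLAGS bit
// summary in Kernel_traits_).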
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The hidden dimension of V.
int DV,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8,
// The attention mask version (see src/mask.h).
int MASK_VERSION = 2,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true,
// The output type.
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_q_k_v
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v,
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
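////////////////////////////////////////////////////////////////////////////////////////////////////
// Note (inferred from the tile names): Kernel_traits_v2 above reads from a packed QKV tensor via
// fmha::v2::Gmem_tile_qkv, whereas this alias uses fmha::v2::Gmem_tile_q_k_v, which presumably
// loads Q, K and V from separate buffers (the separate-QKV input layout used by the context MLA
// path).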
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The hidden dimension of V.
int DV,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8,
// The attention mask version (see src/mask.h).
int MASK_VERSION = 2,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true,
// The output type.
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_paged_kv_cache
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_paged_kv, fmha::v2::Gmem_tile_paged_kv,
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The hidden dimension of V.
int DV,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8,
// The attention mask version (see src/mask.h).
int MASK_VERSION = 2,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true,
// The output type.
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v,
fmha::v2::Gmem_tile_contiguous_kv, fmha::v2::Gmem_tile_contiguous_kv,
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2,
MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length for K and V.
int S_KV,
// The hidden size per head.
int D,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8>
using Kernel_traits_fmhca = Kernel_traits_fmhca_<Traits, fmha::v2::Gmem_tile_q_kv, fmha::v2::Gmem_tile_o, S_KV, D, STEP,
WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2>;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha