/*
 * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace fmha
{

////////////////////////////////////////////////////////////////////////////////////////////////////

// Ada hmma/imma reuses Ampere.
template <typename Traits_>
struct Traits_reuse
{
    using Traits = Traits_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Traits_reuse<fmha::Ada_hmma_fp16_traits>
{
    using Traits = fmha::Ampere_hmma_fp16_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Traits_reuse<fmha::Ada_hmma_fp32_traits>
{
    using Traits = fmha::Ampere_hmma_fp32_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Traits_reuse<fmha::Ada_imma_int8_int32_traits>
{
    using Traits = fmha::Ampere_imma_int8_int32_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Traits_p, bool FP16_EPILOGUE>
struct Traits_o_adapter
{
    using Traits = Traits_p;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool FP16_EPILOGUE>
struct Traits_o_adapter<fmha::Volta_hmma_fp16_traits, FP16_EPILOGUE>
{
    using Traits = fmha::Volta_hmma_fp16_16x16x16_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert to fp16 before the smem_o store.
template <>
struct Traits_o_adapter<fmha::Ampere_hmma_fp32_traits, true>
{
    using Traits = fmha::Ampere_hmma_fp16_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert to fp16 before the smem_o store.
template <>
struct Traits_o_adapter<fmha::Turing_hmma_fp32_traits, true>
{
    using Traits = fmha::Turing_hmma_fp16_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert to bf16 before the smem_o store.
template <>
struct Traits_o_adapter<fmha::Ampere_hmma_bf16_traits, true>
{
    using Traits = fmha::Ampere_hmma_bf16_bf16_traits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
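// Illustrative sketch of how the two helpers above compose (assuming the Ada specialization shown
// above; the static_assert is only an example, with <type_traits> available, not a check performed
// by this header):
//
//   static_assert(std::is_same_v<Traits_reuse<fmha::Ada_hmma_fp16_traits>::Traits,
//                                fmha::Ampere_hmma_fp16_traits>);
//
// Traits_reuse maps Ada instruction traits onto their Ampere equivalents, while Traits_o_adapter
// swaps in fp16/bf16 traits for the 2nd GEMM when the accumulator should be converted before the
// smem_o store.

////////////////////////////////////////////////////////////////////////////////////////////////////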
template <
    // Instruction traits.
    typename Traits_,
    // The global memory tile for Q, K and V.
    template class Gmem_tile_q_,
    template class Gmem_tile_k_,
    template class Gmem_tile_v_,
    // The global memory tile for the output.
    template class Gmem_tile_o_,
    // Sequence length.
    int S,
    // The valid hidden dimension.
    int VALID_D_,
    // The valid hidden dimension of V.
    int VALID_DV_,
    // The iteration step of the outer loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD_,
    // The flags to control the behaviour of LDGs.
    uint32_t FLAGS,
    // The version of the kernel.
    int VERSION_,
    // The mask version of the kernel.
    int MASK_VERSION_,
    // Do we use half epilogue for the 2nd GEMM (hmma_fp32).
    bool BMM2_FP16_EPILOGUE = true,
    // Non-positive means disabled.
    int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0>
struct Kernel_traits_
{

    // The instruction traits for the Q*K product.
    using Traits_p = typename Traits_reuse<Traits_>::Traits;
    // The instruction traits for the P*V product. Hack to change the traits for Volta HMMA.
    using Traits_o = typename Traits_o_adapter<Traits_p, BMM2_FP16_EPILOGUE>::Traits;
    // The instruction traits for the epilogue of the 2nd GEMM. Always use FP16.
    using Traits_e = typename Traits_o_adapter<Traits_p, true>::Traits;

    // The padded D dimension.
    enum { VALID_D = VALID_D_ };
    enum { D = Next_power_of_two<VALID_D>::VALUE };
    enum { VALID_DV = VALID_DV_ > 0 ? VALID_DV_ : VALID_D };
    enum { DV = Next_power_of_two<VALID_DV>::VALUE };

    // The sage attention block sizes.
    enum { SAGE_ATTENTION = SAGE_BLOCK_SIZE_Q_ > 0 || SAGE_BLOCK_SIZE_K_ > 0 || SAGE_BLOCK_SIZE_V_ > 0 };
    enum { SAGE_BLOCK_SIZE_Q = SAGE_BLOCK_SIZE_Q_ };
    enum { SAGE_BLOCK_SIZE_K = SAGE_BLOCK_SIZE_K_ };
    enum { SAGE_BLOCK_SIZE_V = SAGE_BLOCK_SIZE_V_ };

    // TODO: expose these tiling params to the interface.
    enum { USE_GRANULAR_TILING = (FLAGS & 0x1000) != 0u };
    // TODO ANT: check FLAGS
    using Traits_tile_size = Traits_tile_size<(bool) USE_GRANULAR_TILING, STEP, S, D, DV, Traits_o::K_PER_MMA>;
    enum { CTA_P_TILE_M = Traits_tile_size::CTA_P_TILE_M };
    enum { CTA_P_TILE_N = Traits_tile_size::CTA_P_TILE_N };
    enum { CTA_P_TILE_K = Traits_tile_size::CTA_P_TILE_K };
    enum { CTA_O_TILE_M = Traits_tile_size::CTA_O_TILE_M };
    enum { CTA_O_TILE_N = Traits_tile_size::CTA_O_TILE_N };
    enum { CTA_O_TILE_K = Traits_tile_size::CTA_O_TILE_K };

    // Do we need to reload Q due to splitting the D dimension?
    enum { RELOAD_Q = static_cast<int>(CTA_P_TILE_K) != static_cast<int>(D) };

    // The CTA description for the 1st GEMM.
    using Cta_tile_p = typename Traits_p::template Cta_tile_extd;
    // The CTA description for the 2nd GEMM.
    using Cta_tile_o = typename Traits_o::template Cta_tile_extd;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = typename Traits_p::template Mma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = typename Traits_o::template Mma_tile<Cta_tile_o>;

    // Compute the total BMM2_MMAS_K (might not be the same as Mma_tile_o::MMAS_K if the granular tiling is used).
    static_assert(S % CTA_O_TILE_K == 0, "");
    enum { TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) };

    // Constraints on the K dimension.
    static_assert(Mma_tile_p::K_PER_MMA <= static_cast<int>(D));
    static_assert(Mma_tile_o::K_PER_MMA <= S);

    // The version.
    enum { VERSION = VERSION_ };

    // The mask version: padding (2), causal (3), sliding_window_causal (4), custom_mask (5).
    enum { MASK_VERSION = MASK_VERSION_ };
    // Whether to use the causal mask or not.
    enum { CAUSAL_MASK = MASK_VERSION_ == 3 || MASK_VERSION_ == 4 };
    // Whether to use the sliding window attention or not.
    enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 };
    // Whether to use the custom mask or not.
    enum { CUSTOM_MASK = MASK_VERSION_ == 5 };
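    // Informal reading of the mask versions above (illustrative only; the actual predicates live in
    // src/mask.h): padding only masks out columns beyond the real sequence length; causal
    // additionally requires col <= row; sliding_window_causal further restricts row - col to a
    // window; custom_mask reads validity from a user-provided packed mask instead.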
    // Do we use LDGSTS for Q, K or V.
    enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u };
    enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u };
    enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u };
    // Do we use one buffer for K and V.
    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };
    // Do we use the scale max trick.
    enum { USE_SCALE_MAX = (FLAGS & 0x10u) != 0u };
    // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d.
    enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u };
    // Keep full K matrix in registers.
    enum { K_IN_REGS = (FLAGS & 0x40) == 0u };
    // Do we use only 2 fragments or full fragments for frag_q/k (only used by flash attention).
    enum { LIMIT_QK_FRAGMENTS = ((FLAGS & 0x80u) != 0u && !SHARE_SMEM_FOR_K_AND_V) };
    // Do we use only 2 fragments or full fragments for frag_v (only used by flash attention).
    enum { LIMIT_V_FRAGMENTS = ((FLAGS & 0x100u) != 0u && !SHARE_SMEM_FOR_K_AND_V) };
    // Limiting QK fragments implies that K has to reside in SMEM.
    static_assert(!(LIMIT_QK_FRAGMENTS && SHARE_SMEM_FOR_K_AND_V), "");
    // Indicates that the kernel does not loop over the Q tensor; usually the kernel name has the _nl suffix.
    enum { NO_LOOP = (FLAGS & 0x200u) != 0u };
    // Are sequences in one batch interleaved, i.e. s x b x ... or b x s x ...
    enum { SEQUENCES_INTERLEAVED = (FLAGS & 0x400) != 0u };
    // Use BMM1 softcapping scale or not.
    enum { ENABLE_BMM1_SOFTCAPPING_SCALE = (FLAGS & 0x800) != 0u };
    // Use MTP (multi-token prediction for MLA kernels) or not.
    enum { IS_MTP = (FLAGS & 0x2000) != 0u };

    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ };

    // The number of shared memory buffers to build a software pipeline for Q, K and V.
    enum { BUFFERS_PER_TILE_SMEM_Q = (USE_GRANULAR_TILING && D > 64) || (USE_LDGSTS_Q && !NO_LOOP) ? 2 : 1 };
    enum { BUFFERS_PER_TILE_SMEM_K = USE_GRANULAR_TILING ? 2 : 1 };
    enum { BUFFERS_PER_TILE_SMEM_V = USE_GRANULAR_TILING ? 2 : 1 };

    // The global memory tile to load Q.
    using Gmem_tile_q = Gmem_tile_q_;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = fmha::Smem_tile_a;
    // The global memory tile to load K.
    using Gmem_tile_k = Gmem_tile_k_;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = fmha::Smem_tile_b;
    // The global memory tile to load V.
    using Gmem_tile_v = Gmem_tile_v_;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = fmha::Smem_tile_v;
    // The global memory tile to store O.
    using Gmem_tile_o = Gmem_tile_o_;
    // The shared memory tile for O.
    using Smem_tile_o = fmha::Smem_tile_o;

    // Make sure the numbers of threads per row match.
    static_assert((int) Gmem_tile_o::THREADS_PER_ROW == (int) Smem_tile_o::THREADS_PER_ROW, "");

    // The number of threads.
    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
    // Make sure the number of threads matches both CTAs.
    static_assert((int) THREADS == (int) Cta_tile_o::THREADS_PER_CTA, "");

    // The amount of shared memory needed to load Q and K.
    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
    // The extra amount of shared memory needed to load V.
    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K and V.
    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
    // The amount of shared memory needed to load/store O.
    enum { BYTES_PER_SMEM_O = Smem_tile_o::BYTES_PER_TILE };
    // The amount of shared memory needed to load Q and store O.
    enum { BYTES_PER_SMEM_QO = NO_LOOP ? Smem_tile_o::BYTES_PER_TILE : Smem_tile_q::BYTES_PER_TILE + BYTES_PER_SMEM_O };
    // The amount of shared memory needed for Q, K, V and O.
    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };

    // Make sure we have enough shared memory.
    static_assert((NO_LOOP ? Smem_tile_o::BYTES_PER_TILE : Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE)
            <= BYTES_PER_SMEM, "");
};

////////////////////////////////////////////////////////////////////////////////////////////////////
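// A minimal sketch of how a FLAGS word for the kernel traits above might be assembled from the bits
// documented in the struct (the particular combination is hypothetical):
//
//   constexpr uint32_t kFlags = 0x1u   // USE_LDGSTS_Q
//                             | 0x2u   // USE_LDGSTS_K
//                             | 0x4u   // USE_LDGSTS_V
//                             | 0x8u;  // SHARE_SMEM_FOR_K_AND_V
//
// Leaving bit 0x20u clear keeps HEADS_INTERLEAVED enabled, and setting 0x1000u would additionally
// enable USE_GRANULAR_TILING.

////////////////////////////////////////////////////////////////////////////////////////////////////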
template <
    // Instruction traits.
    typename Traits_,
    // The global memory tile for Q, K and V.
    template class Gmem_tile_q_,
    // The global memory tile for the output.
    template class Gmem_tile_o_,
    // Sequence length for K/V.
    int S_KV,
    // The hidden dimension.
    int D,
    // The iteration step of the outer loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD_,
    // The flags to control the behaviour of LDGs.
    uint32_t FLAGS,
    // The version of the kernel.
    int VERSION_,
    // Do we use half epilogue for the 2nd GEMM (hmma_fp32).
    bool BMM2_FP16_EPILOGUE = true>
struct Kernel_traits_fmhca_
{

    // The instruction traits for the Q*K product.
    using Traits_p = typename Traits_reuse<Traits_>::Traits;
    // The instruction traits for the P*V product. Hack to change the traits for Volta HMMA.
    using Traits_o = typename Traits_o_adapter<Traits_p, BMM2_FP16_EPILOGUE>::Traits;
    // The instruction traits for the epilogue of the 2nd GEMM. Always use FP16.
    using Traits_e = typename Traits_o_adapter<Traits_p, true>::Traits;

    // The CTA description for the 1st GEMM.
    using Cta_tile_p = typename Traits_p::template Cta_tile_extd;
    // The CTA description for the 2nd GEMM.
    using Cta_tile_o = typename Traits_o::template Cta_tile_extd;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = typename Traits_p::template Mma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = typename Traits_o::template Mma_tile<Cta_tile_o>;

    // Constraints on the K dimension.
    static_assert(Mma_tile_p::K_PER_MMA <= D, "");
    static_assert(Mma_tile_o::K_PER_MMA <= S_KV, "");

    // The version.
    enum { VERSION = VERSION_ };
    // The mask version.
    enum { MASK_VERSION = VERSION_ };
    // Whether to use the causal mask or not.
    enum { CAUSAL_MASK = MASK_VERSION >= 3 };
    // Whether to use the sliding window attention or not.
    enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION == 4 };

    // Do we use LDGSTS for Q, K or V.
    enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u };
    enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u };
    enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u };
    // Do we use one buffer for K and V.
    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };
    // Do we use the scale max trick.
    enum { USE_SCALE_MAX = (FLAGS & 0x10u) != 0u };
    // Are heads in QKV interleaved, i.e. total x h x 3 x d or total x 3 x h x d.
    enum { HEADS_INTERLEAVED = (FLAGS & 0x20u) == 0u };
    // Keep full K matrix in registers.
    enum { K_IN_REGS = (FLAGS & 0x40) == 0u };
    // Use BMM1 softcapping scale or not.
    enum { ENABLE_BMM1_SOFTCAPPING_SCALE = 0 };

    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ };

    // The global memory tile to load Q.
    using Gmem_tile_q = Gmem_tile_q_;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = fmha::Smem_tile_a;
    // The global memory tile to load K.
    using Gmem_tile_k = Gmem_tile_q_;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = fmha::Smem_tile_b;
    // The global memory tile to load V.
    using Gmem_tile_v = Gmem_tile_q_;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = fmha::Smem_tile_v;
    // The global memory tile to store O.
    using Gmem_tile_o = Gmem_tile_o_;
    // The shared memory tile for O.
    using Smem_tile_o = fmha::Smem_tile_o;

    // Make sure the numbers of threads per row match.
    static_assert((int) Gmem_tile_o::THREADS_PER_ROW == (int) Smem_tile_o::THREADS_PER_ROW, "");

    // The number of threads.
    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
    // Make sure the number of threads matches both CTAs.
    static_assert((int) THREADS == (int) Cta_tile_o::THREADS_PER_CTA, "");

    // The amount of shared memory needed to load Q and K.
    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
    // The extra amount of shared memory needed to load V.
    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K and V.
    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
    // The amount of shared memory needed to load Q and store O.
    enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K, V and O.
    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };

    // Make sure we have enough shared memory.
    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};

////////////////////////////////////////////////////////////////////////////////////////////////////
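// Worked example of the shared-memory budgeting above, with hypothetical tile sizes: if the Q, K,
// V and O tiles take 16 KB, 32 KB, 32 KB and 16 KB respectively and K/V do not share a buffer,
// then BYTES_PER_SMEM_QKV = 16 + 32 + 32 = 80 KB, BYTES_PER_SMEM_QO = 16 + 16 = 32 KB, and
// BYTES_PER_SMEM = max(80, 32) = 80 KB. The total is the maximum of the two sums rather than their
// sum because the Q/K/V staging buffers and the Q/O buffers are sized against the same allocation.

////////////////////////////////////////////////////////////////////////////////////////////////////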
template <
    // The instruction traits.
    typename Traits_,
    // The sequence length.
    int S,
    // The hidden size per head.
    int VALID_D,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD_,
    // The flags.
    uint32_t FLAGS = 0x8,
    // The mask version of the kernel.
    int MASK_VERSION_ = 2>
struct Kernel_traits_interleaved_v2_
{

    // The instruction traits.
    using Traits = typename Traits_reuse<Traits_>::Traits;
    using Traits_p = Traits;
    using Traits_o = Traits;

    // The padded D dimension.
    enum { D = Next_power_of_two<VALID_D>::VALUE };

    // The CTA description for the 1st GEMM.
    using Cta_tile_p = typename Traits::template Cta_tile_extd;
    // The CTA description for the 2nd GEMM.
    using Cta_tile_o = typename Traits::template Cta_tile_extd;

    // The version.
    enum { VERSION = 2 };
    // The mask version.
    enum { MASK_VERSION = MASK_VERSION_ };
    // Whether to use the causal mask or not.
    enum { CAUSAL_MASK = MASK_VERSION_ >= 3 };
    // Whether to use the sliding window attention or not.
    enum { SLIDING_WINDOW_ATTENTION = MASK_VERSION_ == 4 };

    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    enum { CTAS_PER_HEAD = CTAS_PER_HEAD_ };

    // Do we use LDGSTS for Q, K or V.
    enum { USE_LDGSTS_Q = (FLAGS & 0x1u) != 0u };
    enum { USE_LDGSTS_K = (FLAGS & 0x2u) != 0u };
    enum { USE_LDGSTS_V = (FLAGS & 0x4u) != 0u };
    // Do we use one buffer for K and V.
    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };
    // Do we use the scale max trick.
    enum { USE_SCALE_MAX = (FLAGS & 16) != 0u };

    // The global memory tile to load Q.
    using Gmem_tile_q = fmha::v2::Gmem_tile_qkv_interleaved;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = fmha::Smem_tile_qk_interleaved_a;
    // The global memory tile to load K.
    using Gmem_tile_k = fmha::v2::Gmem_tile_qkv_interleaved;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = fmha::Smem_tile_qk_interleaved_b;
    // The global memory tile to load V.
    using Gmem_tile_v = fmha::v2::Gmem_tile_qkv_interleaved;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = fmha::Smem_tile_v_interleaved_b;
    // The global memory tile to store O.
    using Gmem_tile_o = fmha::v2::Imma_gmem_tile_o_interleaved;
    // The shared memory tile for O.
    using Smem_tile_o = fmha::Smem_tile_o_interleaved;

    // Make sure the numbers of threads per row match.
    static_assert((int) Gmem_tile_o::THREADS_PER_ROW == (int) Smem_tile_o::THREADS_PER_ROW, "");

    // The number of threads.
    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
    // Make sure the number of threads matches both CTAs.
    static_assert((int) THREADS == (int) Cta_tile_o::THREADS_PER_CTA, "");

    // The amount of shared memory needed to load Q and K.
    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
    // The extra amount of shared memory needed to load V.
    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K and V.
    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
    // The amount of shared memory needed to load Q and store O.
    enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K, V and O.
    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };

    // Make sure we have enough shared memory.
    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits_,
    // The sequence length.
    int S,
    // The hidden size per head.
    int VALID_D,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD_,
    // The flags.
    uint32_t FLAGS = 0x8,
    // The mask version of the kernel.
    int MASK_VERSION_ = 2>
using Kernel_traits_interleaved_v2 = Kernel_traits_interleaved_v2_<Traits_, S, VALID_D, STEP, WARPS_M, WARPS_N,
    CTAS_PER_HEAD_, FLAGS, MASK_VERSION_>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits,
    // The sequence length.
    int S,
    // The hidden size per head.
    int D,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8>
using Kernel_traits_v1 = Kernel_traits_;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits,
    // The sequence length.
    int S,
    // The hidden size per head.
    int D,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8>
using Kernel_traits_v1_causal_mask = Kernel_traits_; // MASK_VERSION_

////////////////////////////////////////////////////////////////////////////////////////////////////

template struct Gmem_tile_o_dispatcher
{
    template using Gmem_tile_o = fmha::v2::Gmem_tile_o;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Gmem_tile_o_dispatcher
{
    template using Gmem_tile_o = fmha::v2::Gmem_tile_o_uint16;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct Gmem_tile_o_dispatcher
{
    template using Gmem_tile_o = fmha::v2::Gmem_tile_o_bfloat16;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
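// Presumably the dispatcher above selects the O tile implementation that matches the output
// element type: the generic fmha::v2::Gmem_tile_o by default, and the uint16 / bfloat16 variants
// when the output is stored as 16-bit half or bfloat16 values. The Kernel_traits_v2* aliases below
// appear to route their OutputType parameter (defaulting to Traits::A_type) through this dispatcher.

////////////////////////////////////////////////////////////////////////////////////////////////////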
template <
    // The instruction traits.
    typename Traits,
    // The sequence length.
    int S,
    // The hidden size per head.
    int D,
    // The hidden dimension of V.
    int DV,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8,
    // The attention mask version (see src/mask.h).
    int MASK_VERSION = 2,
    // Do we use half epilogue for the 2nd GEMM (hmma_fp32)
    bool BMM2_FP16_EPILOGUE = true,
    // The output type.
    typename OutputType = typename Traits::A_type,
    // The sage attention block size for Q, K and V
    int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2 = Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD,
    FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits,
    // The sequence length.
    int S,
    // The hidden size per head.
    int D,
    // The hidden dimension of V.
    int DV,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8,
    // The attention mask version (see src/mask.h).
    int MASK_VERSION = 2,
    // Do we use half epilogue for the 2nd GEMM (hmma_fp32)
    bool BMM2_FP16_EPILOGUE = true,
    // The output type.
    typename OutputType = typename Traits::A_type,
    // The sage attention block size for Q, K and V
    int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_q_k_v = Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD,
    FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits,
    // The sequence length.
    int S,
    // The hidden size per head.
    int D,
    // The hidden dimension of V.
    int DV,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8,
    // The attention mask version (see src/mask.h).
    int MASK_VERSION = 2,
    // Do we use half epilogue for the 2nd GEMM (hmma_fp32)
    bool BMM2_FP16_EPILOGUE = true,
    // The output type.
    typename OutputType = typename Traits::A_type,
    // The sage attention block size for Q, K and V
    int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_paged_kv_cache = Kernel_traits_::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N,
    CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits,
    // The sequence length.
    int S,
    // The hidden size per head.
    int D,
    // The hidden dimension of V.
    int DV,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8,
    // The attention mask version (see src/mask.h).
    int MASK_VERSION = 2,
    // Do we use half epilogue for the 2nd GEMM (hmma_fp32)
    bool BMM2_FP16_EPILOGUE = true,
    // The output type.
    typename OutputType = typename Traits::A_type,
    // The sage attention block size for Q, K and V
    int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N,
    CTAS_PER_HEAD, FLAGS, 2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K,
    SAGE_BLOCK_SIZE_V>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    // The instruction traits.
    typename Traits,
    // The sequence length for K and V.
    int S_KV,
    // The hidden size per head.
    int D,
    // The number of timesteps per iteration of the main loop.
    int STEP,
    // The number of vertical warps.
    int WARPS_M,
    // The number of horizontal warps.
    int WARPS_N,
    // The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
    int CTAS_PER_HEAD,
    // The flags.
    uint32_t FLAGS = 0x8>
using Kernel_traits_fmhca = Kernel_traits_fmhca_;

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha
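////////////////////////////////////////////////////////////////////////////////////////////////////

// Illustrative usage sketch with hypothetical parameters: sequence length 256, head size 64, a
// 128-row step, fp32-accumulation HMMA on Ampere, LDGSTS for Q/K/V and K/V sharing one smem buffer
// (FLAGS = 0x1 | 0x2 | 0x4 | 0x8). MASK_VERSION keeps its default of 2 (padding mask).
//
//   using Traits = fmha::Kernel_traits_v2<fmha::Ampere_hmma_fp32_traits,
//                                         /*S=*/256, /*D=*/64, /*DV=*/64, /*STEP=*/128,
//                                         /*WARPS_M=*/4, /*WARPS_N=*/1, /*CTAS_PER_HEAD=*/1,
//                                         /*FLAGS=*/0xFu>;
//   static_assert(Traits::CAUSAL_MASK == 0, "the padding mask is not causal");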