/* * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/config.h" #include #include TRTLLM_NAMESPACE_BEGIN namespace kernels { //////////////////////////////////////////////////////////////////////////////////////////////////// // The attention mask types. enum class TrtllmGenAttentionMaskType { // Dense mask. Dense = 0, // Causal mask. Causal, // Sliding window or chunked causal mask. SlidingOrChunkedCausal, // Custom mask. Custom }; //////////////////////////////////////////////////////////////////////////////////////////////////// // Helper functions to check the mask type. #define ATTENTION_MASK_TYPE_FUNCTION(MaskType) \ inline bool is##MaskType##Mask(TrtllmGenAttentionMaskType maskType) \ { \ return (maskType == TrtllmGenAttentionMaskType::MaskType); \ } ATTENTION_MASK_TYPE_FUNCTION(Dense) ATTENTION_MASK_TYPE_FUNCTION(Causal) ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal) ATTENTION_MASK_TYPE_FUNCTION(Custom) #undef ATTENTION_MASK_TYPE_FUNCTION //////////////////////////////////////////////////////////////////////////////////////////////////// enum class FmhaKernelType { // The context-phase kernels. Context = 0, // Choose the best generation kernel based on the heuristic: // use SwapsMmaAbForGeneration kernels when numHeadsQPerKv <= 16, otherwise KeepsMmaAbForGeneration. Generation, // Swap tensor A and tensor B of Mma, which only supports numHeadsQPerKv <= 16. SwapsMmaAbForGeneration, // Keep tensor A and tensor B of Mma. KeepsMmaAbForGeneration, // Speculative decoding (Medusa and Eagle) generation-phase attention kernels, where seqLenQ > 1. SpecDecodingGeneration }; //////////////////////////////////////////////////////////////////////////////////////////////////// // Helper functions to check the fmha kernel type. #define FMHA_KERNEL_TYPE_FUNCTION(KernelType) \ inline bool is##KernelType##Kernel(FmhaKernelType kernelType) \ { \ return (kernelType == FmhaKernelType::KernelType); \ } FMHA_KERNEL_TYPE_FUNCTION(Context) FMHA_KERNEL_TYPE_FUNCTION(Generation) FMHA_KERNEL_TYPE_FUNCTION(SwapsMmaAbForGeneration) FMHA_KERNEL_TYPE_FUNCTION(KeepsMmaAbForGeneration) FMHA_KERNEL_TYPE_FUNCTION(SpecDecodingGeneration) #undef QKV_LAYOUT_FUNCTION //////////////////////////////////////////////////////////////////////////////////////////////////// // Note that (batchSize, seqLen) dimensions will be packed as sumOfSeqLens without paddings for // variable sequence lengths. enum class QkvLayout { // SeparateQkv: separate Q, K and V buffers. // Each has the shape: [batchSize, seqLen, numHeads, headDim]. SeparateQkv = 0, // PackedQkv: single buffer for Q, K and V. // Shape: [batchSize, seqLen, numHeadsQ + 2*numHeadsKv, headDim]. PackedQkv, // Paged buffer for K and V. Its shape is [batchSize, 2, maxNumPagesPerSeq]. The 2 corresponds to // K // and V. That buffer stores the logical page index of the paged-KV memory pool. Each "page" of // that // pool is a contiguous buffer of shape [numHeadsKv, pageSize, headDim]. PagedKv, // ContiguousKv: // Contiguous buffer for Q with shape [batchSize, seqLen, numHeads, headDim]. // Contiguous buffer for Kv with shape [batchSize, seqLen, 2 * numHeads, headDim]. ContiguousKv, }; // Helper functions to check the QkvLayout type. #define QKV_LAYOUT_FUNCTION(LayoutType) \ inline bool is##LayoutType(QkvLayout qkvLayout) \ { \ return (qkvLayout == QkvLayout::LayoutType); \ } QKV_LAYOUT_FUNCTION(SeparateQkv) QKV_LAYOUT_FUNCTION(PackedQkv) QKV_LAYOUT_FUNCTION(PagedKv) QKV_LAYOUT_FUNCTION(ContiguousKv) #undef QKV_LAYOUT_FUNCTION //////////////////////////////////////////////////////////////////////////////////////////////////// enum class TileScheduler { // Static scheduler (Non-persistent). Static = 0, // Persistent scheduler. Persistent }; //////////////////////////////////////////////////////////////////////////////////////////////////// enum class MultiCtasKvMode { // Disable the multiCtasKvMode. Disabled = 0, // Do the reduction through the global memory and atomic counters. GmemReduction, // Same as GmemReduction, but use a separate kernel for the reduction. // It is only supported/needed for 2-CTA or 1-CTA keepsMmaAbForGeneration MLA kernels with large // reduction tiles. GmemReductionWithSeparateKernel, // Do the reduction through the CGA remote shared memory. CgaSmemReduction }; // Helper function to check if the multiCtasKv is enabled. inline bool isMultiCtasKvEnabled(MultiCtasKvMode multiCtasKvMode) { return multiCtasKvMode != MultiCtasKvMode::Disabled; } // Helper function to check the multiCtasKvMode type. #define MULTI_CTAS_KV_MODE_FUNCTION(Type) \ inline bool is##Type(MultiCtasKvMode multiCtasKvMode) \ { \ return (multiCtasKvMode == MultiCtasKvMode::Type); \ } MULTI_CTAS_KV_MODE_FUNCTION(Disabled) MULTI_CTAS_KV_MODE_FUNCTION(GmemReduction) MULTI_CTAS_KV_MODE_FUNCTION(GmemReductionWithSeparateKernel) MULTI_CTAS_KV_MODE_FUNCTION(CgaSmemReduction) #undef MULTI_CTAS_KV_MODE_FUNCTION //////////////////////////////////////////////////////////////////////////////////////////////////// struct TllmGenFmhaRunnerParams { // Input layout. QkvLayout mQkvLayout; // Attention mask type. TrtllmGenAttentionMaskType mMaskType; // The kernel type. FmhaKernelType mKernelType; // The tile scheduler. TileScheduler mTileScheduler; // The multiCtasKvMode (i.e. multiBlockMode). bool mMultiCtasKvMode; // Use block sparse attention. bool mUseBlockSparseAttention; // Input QKV buffers. void const* qPtr; void const* kPtr; void const* vPtr; // Packed KV buffer void const* kvPtr; // Packed KV scaling factor buffer void const* kvSfPtr; // Packed QKV buffer void const* qkvPtr; // The attention sinks pointer (additional value per head in the denominator of the softmax). float const* attentionSinksPtr; // The general packed custom mask ptr which does not meet specific format for trtllm gen kernels. int32_t const* generalPackedCustoMaskPtr; // The custom mask ptr. uint32_t* customMaskPtr; // The packed custom mask's offsets of each sequence. int64_t* customMaskOffsetsPtr; // The first sparseMask offsets in the Kv sequence dimension. int32_t* firstSparseMaskOffsetsKvPtr; // The counter for the multiCtasKv mode. int32_t* multiCtasKvCounterPtr; // The sequence length buffer for K/V. int const* seqLensKvPtr; // The cumulative sequence length buffer for Q and K/V int const* cumSeqLensQPtr; int const* cumSeqLensKvPtr; // The kv page idx int const* kvPageIdxPtr; // The device output scale for FP8 quantization. float const* outputScalePtr; // The device scaling factor for softmax (multiplied by log2 to use faster exp2) float const* scaleSoftmaxLog2Ptr; // The device scale for KV scaling factor. float const* kvSfScalePtr; // The device scale for O scaling factor. float const* oSfScalePtr; // The scratch space for each CtaKv when the multiCtasKv mode is enabled. // PartialO, partialMax and partialSum will be stored to the scratch space. void* multiCtasKvScratchPtr; // The softmax stats buffer. // The softmax max/sum values will be stored to the buffer if it is not nullptr. float2* softmaxStatsPtr; // The output buffer. void* oPtr; // The output scaling factor buffer. void* oSfPtr; // The sequence lengths for Q. int const* seqlensQPtr; // Head dimension for Q and K. int mHeadDimQk; // Head dimension for V. int mHeadDimV; // Head dimension for Q/K non-RoPE part, only used for MLA now. int mHeadDimQkNope; // Number of heads for Q and K/V. int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv; // The batch size. int mBatchSize; // The max sequence length in the contiguous Kv cache. int mMaxSeqLenCacheKv; // The max q sequence length. int mMaxSeqLenQ; // The max kv sequence length. int mMaxSeqLenKv; // The attention window size for sliding window attention (sliding-window-attention is enabled when seqLenKv > // mAttentionWindowSize). int mAttentionWindowSize; // The chunked attention size (chunked-context is enabled when seqLenKv > mChunkedAttentionSize). int mChunkedAttentionSize; // The sum of sequence lengths for Q and K/V. (Only used when mSupportsVarSeqLens = true) int mSumOfSeqLensQ; int mSumOfSeqLensKv; // The maximum number of pages per sequence in the paged-kv buffer. int mMaxNumPagesPerSeqKv; // The number of tokens per pageKv. int mNumTokensPerPage; // The number of pages in memory pool. int mNumPagesInMemPool; // The number of multiProcessor for the GPU. int mMultiProcessorCount; // Scaling factor for Q. float mScaleQ; // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase kernel when inflight // batching is enabled. int mSfStartTokenIdx; // Skip softmax threshold scale factor. float mSkipSoftmaxThresholdScaleFactor; // Whether to use sparse MLA. bool mSparseMla; // The top k value for sparse MLA. int mSparseMlaTopK; // The cuda stream. cudaStream_t stream; // The layer index. int32_t mLayerIdx = 0; // Whether the spec-dec tree is used. bool mIsSpecDecTree = false; // set the attention mask type TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType) { // maskType is the enum of tensorrt_llm::kernels::ContextAttentionMaskType // convert ContextAttentionMaskType to TrtllmGenAttentionMaskType switch (maskType) { case 0: // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING mMaskType = TrtllmGenAttentionMaskType::Dense; break; case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL mMaskType = TrtllmGenAttentionMaskType::Causal; break; case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal; break; case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK mMaskType = TrtllmGenAttentionMaskType::Custom; break; default: TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType", static_cast(maskType)); } return *this; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // Parameters that might be updated when selecting kernels. struct TllmGenSelectKernelParams { // The FMHA kernel type. FmhaKernelType mKernelType; // The headDimV per CTA, which is only used by MLA generation kernels currently. int mHeadDimPerCtaV; // The multiCtasKvMode. MultiCtasKvMode mMultiCtasKvMode; // Force using GmemRedution for the multiCtasKvMode. bool mForceGmemReduction; // The mask type. TrtllmGenAttentionMaskType mMaskType; // Reuse smemK for V or not (only work with MLA generation kernels). bool mReuseSmemKForV; // Do we need to select a new kernel as the parameters have been updated. bool mSelectNewKernel; // The tile scheduler. TileScheduler mTileScheduler; // The tile size for Kv. int mTileSizeKv; // Use 2 CTA MMA or not. bool mUses2CtaMma; // Skips softmax or not. bool mSkipsSoftmaxWhenPossible; // The constructor. TllmGenSelectKernelParams(TllmGenFmhaRunnerParams params) : mKernelType(params.mKernelType) , mHeadDimPerCtaV(params.mHeadDimV) // Note the CgaSmemReduction will be enabled based on the heuristic. , mMultiCtasKvMode(params.mMultiCtasKvMode ? MultiCtasKvMode::GmemReduction : MultiCtasKvMode::Disabled) , mForceGmemReduction(false) , mMaskType(params.mMaskType) , mReuseSmemKForV(false) , mSelectNewKernel(false) , mTileScheduler(params.mTileScheduler) , mTileSizeKv(128) , mUses2CtaMma(false) , mSkipsSoftmaxWhenPossible(false){}; }; } // namespace kernels TRTLLM_NAMESPACE_END