TensorRT-LLM/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h

/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include <cstdint>
#include <cuda_runtime.h>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// The attention mask types.
enum class TrtllmGenAttentionMaskType
{
// Dense mask.
Dense = 0,
// Causal mask.
Causal,
// Sliding window or chunked causal mask.
SlidingOrChunkedCausal,
// Custom mask.
Custom
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helper functions to check the mask type.
#define ATTENTION_MASK_TYPE_FUNCTION(MaskType) \
inline bool is##MaskType##Mask(TrtllmGenAttentionMaskType maskType) \
{ \
return (maskType == TrtllmGenAttentionMaskType::MaskType); \
}
ATTENTION_MASK_TYPE_FUNCTION(Dense)
ATTENTION_MASK_TYPE_FUNCTION(Causal)
ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal)
ATTENTION_MASK_TYPE_FUNCTION(Custom)
#undef ATTENTION_MASK_TYPE_FUNCTION
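// Illustrative usage (a sketch, not part of the original helpers): the generated checks make
// dispatch code readable, e.g. treating plain causal and sliding/chunked causal masks alike.
inline bool isCausalLikeMaskSketch(TrtllmGenAttentionMaskType maskType)
{
    return isCausalMask(maskType) || isSlidingOrChunkedCausalMask(maskType);
}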
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class FmhaKernelType
{
// The context-phase kernels.
Context = 0,
// Choose the best generation kernel based on the heuristic:
// use SwapsMmaAbForGeneration kernels when numHeadsQPerKv <= 16, otherwise KeepsMmaAbForGeneration.
Generation,
// Swap tensor A and tensor B of Mma, which only supports numHeadsQPerKv <= 16.
SwapsMmaAbForGeneration,
// Keep tensor A and tensor B of Mma.
KeepsMmaAbForGeneration,
// Speculative decoding (Medusa and Eagle) generation-phase attention kernels, where seqLenQ > 1.
SpecDecodingGeneration
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helper functions to check the fmha kernel type.
#define FMHA_KERNEL_TYPE_FUNCTION(KernelType) \
inline bool is##KernelType##Kernel(FmhaKernelType kernelType) \
{ \
return (kernelType == FmhaKernelType::KernelType); \
}
FMHA_KERNEL_TYPE_FUNCTION(Context)
FMHA_KERNEL_TYPE_FUNCTION(Generation)
FMHA_KERNEL_TYPE_FUNCTION(SwapsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(KeepsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(SpecDecodingGeneration)
#undef FMHA_KERNEL_TYPE_FUNCTION
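// A minimal sketch of the heuristic documented on FmhaKernelType::Generation (the actual selection
// happens outside this header): resolve the generic Generation type to a concrete generation kernel
// type based on numHeadsQPerKv.
inline FmhaKernelType resolveGenerationKernelTypeSketch(FmhaKernelType kernelType, int numHeadsQPerKv)
{
    if (isGenerationKernel(kernelType))
    {
        // SwapsMmaAbForGeneration only supports numHeadsQPerKv <= 16.
        return numHeadsQPerKv <= 16 ? FmhaKernelType::SwapsMmaAbForGeneration
                                    : FmhaKernelType::KeepsMmaAbForGeneration;
    }
    return kernelType;
}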
////////////////////////////////////////////////////////////////////////////////////////////////////
// Note that the (batchSize, seqLen) dimensions will be packed as sumOfSeqLens, without padding, for
// variable sequence lengths.
enum class QkvLayout
{
// SeparateQkv: separate Q, K and V buffers.
// Each has the shape: [batchSize, seqLen, numHeads, headDim].
SeparateQkv = 0,
// PackedQkv: single buffer for Q, K and V.
// Shape: [batchSize, seqLen, numHeadsQ + 2*numHeadsKv, headDim].
PackedQkv,
// PagedKv: paged buffer for K and V. Its shape is [batchSize, 2, maxNumPagesPerSeq]. The 2
// corresponds to K and V. That buffer stores the logical page index of the paged-KV memory pool.
// Each "page" of that pool is a contiguous buffer of shape [numHeadsKv, pageSize, headDim].
PagedKv,
// ContiguousKv:
// Contiguous buffer for Q with shape [batchSize, seqLen, numHeads, headDim].
// Contiguous buffer for Kv with shape [batchSize, seqLen, 2 * numHeads, headDim].
ContiguousKv,
};
// Helper functions to check the QkvLayout type.
#define QKV_LAYOUT_FUNCTION(LayoutType) \
inline bool is##LayoutType(QkvLayout qkvLayout) \
{ \
return (qkvLayout == QkvLayout::LayoutType); \
}
QKV_LAYOUT_FUNCTION(SeparateQkv)
QKV_LAYOUT_FUNCTION(PackedQkv)
QKV_LAYOUT_FUNCTION(PagedKv)
QKV_LAYOUT_FUNCTION(ContiguousKv)
#undef QKV_LAYOUT_FUNCTION
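// Illustrative sketch of the PagedKv addressing described above (not part of the original header):
// compute the element offset of a token's K (kOrV = 0) or V (kOrV = 1) vector, for head 0, inside
// the paged-KV memory pool. The page table has shape [batchSize, 2, maxNumPagesPerSeq] and each
// pool page is a contiguous [numHeadsKv, numTokensPerPage, headDim] buffer.
inline int64_t pagedKvElementOffsetSketch(int const* kvPageIdxPtr, int batchIdx, int kOrV, int tokenIdx,
    int maxNumPagesPerSeq, int numHeadsKv, int numTokensPerPage, int headDim)
{
    int const pageSlot = tokenIdx / numTokensPerPage;
    int const tokenInPage = tokenIdx % numTokensPerPage;
    // Logical page index of the pool page holding this token.
    int64_t const pageIdx = kvPageIdxPtr[(batchIdx * 2 + kOrV) * maxNumPagesPerSeq + pageSlot];
    int64_t const pageSizeInElts = static_cast<int64_t>(numHeadsKv) * numTokensPerPage * headDim;
    return pageIdx * pageSizeInElts + static_cast<int64_t>(tokenInPage) * headDim;
}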
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class TileScheduler
{
// Static scheduler (Non-persistent).
Static = 0,
// Persistent scheduler.
Persistent
};
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class MultiCtasKvMode
{
// Disable the multiCtasKvMode.
Disabled = 0,
// Do the reduction through the global memory and atomic counters.
GmemReduction,
// Same as GmemReduction, but use a separate kernel for the reduction.
// It is only supported/needed for 2-CTA or 1-CTA keepsMmaAbForGeneration MLA kernels with large
// reduction tiles.
GmemReductionWithSeparateKernel,
// Do the reduction through the CGA remote shared memory.
CgaSmemReduction
};
// Helper function to check if the multiCtasKv mode is enabled.
inline bool isMultiCtasKvEnabled(MultiCtasKvMode multiCtasKvMode)
{
return multiCtasKvMode != MultiCtasKvMode::Disabled;
}
// Helper function to check the multiCtasKvMode type.
#define MULTI_CTAS_KV_MODE_FUNCTION(Type) \
inline bool is##Type(MultiCtasKvMode multiCtasKvMode) \
{ \
return (multiCtasKvMode == MultiCtasKvMode::Type); \
}
MULTI_CTAS_KV_MODE_FUNCTION(Disabled)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReduction)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReductionWithSeparateKernel)
MULTI_CTAS_KV_MODE_FUNCTION(CgaSmemReduction)
#undef MULTI_CTAS_KV_MODE_FUNCTION
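// Illustrative helper (a sketch, not part of the original API): printable name for the reduction
// mode, handy when logging kernel-selection decisions.
inline char const* multiCtasKvModeToStringSketch(MultiCtasKvMode multiCtasKvMode)
{
    switch (multiCtasKvMode)
    {
    case MultiCtasKvMode::Disabled: return "Disabled";
    case MultiCtasKvMode::GmemReduction: return "GmemReduction";
    case MultiCtasKvMode::GmemReductionWithSeparateKernel: return "GmemReductionWithSeparateKernel";
    case MultiCtasKvMode::CgaSmemReduction: return "CgaSmemReduction";
    default: return "Unknown";
    }
}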
////////////////////////////////////////////////////////////////////////////////////////////////////
struct TllmGenFmhaRunnerParams
{
// Input layout.
QkvLayout mQkvLayout;
// Attention mask type.
TrtllmGenAttentionMaskType mMaskType;
// The kernel type.
FmhaKernelType mKernelType;
// The tile scheduler.
TileScheduler mTileScheduler;
// The multiCtasKvMode (i.e. multiBlockMode).
bool mMultiCtasKvMode;
// Use block sparse attention.
bool mUseBlockSparseAttention;
// Input QKV buffers.
void const* qPtr;
void const* kPtr;
void const* vPtr;
// Packed KV buffer
void const* kvPtr;
// Packed KV scaling factor buffer
void const* kvSfPtr;
// Packed QKV buffer
void const* qkvPtr;
// The attention sinks pointer (additional value per head in the denominator of the softmax).
float const* attentionSinksPtr;
// The general packed custom mask pointer, which does not follow the specific format expected by the trtllm-gen kernels.
int32_t const* generalPackedCustoMaskPtr;
// The custom mask ptr.
uint32_t* customMaskPtr;
// The packed custom mask's offsets of each sequence.
int64_t* customMaskOffsetsPtr;
// The first sparseMask offsets in the Kv sequence dimension.
int32_t* firstSparseMaskOffsetsKvPtr;
// The counter for the multiCtasKv mode.
int32_t* multiCtasKvCounterPtr;
// The sequence length buffer for K/V.
int const* seqLensKvPtr;
// The cumulative sequence length buffers for Q and K/V.
int const* cumSeqLensQPtr;
int const* cumSeqLensKvPtr;
// The KV page index buffer for the paged-KV cache.
int const* kvPageIdxPtr;
// The device output scale for FP8 quantization.
float const* outputScalePtr;
// The device scaling factor for softmax (multiplied by log2 to use faster exp2)
float const* scaleSoftmaxLog2Ptr;
// The device scale for KV scaling factor.
float const* kvSfScalePtr;
// The device scale for O scaling factor.
float const* oSfScalePtr;
// The scratch space for each CtaKv when the multiCtasKv mode is enabled.
// PartialO, partialMax and partialSum will be stored to the scratch space.
void* multiCtasKvScratchPtr;
// The softmax stats buffer.
// The softmax max/sum values will be stored to the buffer if it is not nullptr.
float2* softmaxStatsPtr;
// The output buffer.
void* oPtr;
// The output scaling factor buffer.
void* oSfPtr;
// The sequence lengths for Q.
int const* seqlensQPtr;
// Head dimension for Q and K.
int mHeadDimQk;
// Head dimension for V.
int mHeadDimV;
// Head dimension for Q/K non-RoPE part, only used for MLA now.
int mHeadDimQkNope;
// Number of heads for Q and K/V.
int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
// The batch size.
int mBatchSize;
// The max sequence length in the contiguous Kv cache.
int mMaxSeqLenCacheKv;
// The max q sequence length.
int mMaxSeqLenQ;
// The max kv sequence length.
int mMaxSeqLenKv;
// The attention window size for sliding window attention (sliding-window-attention is enabled when seqLenKv >
// mAttentionWindowSize).
int mAttentionWindowSize;
// The chunked attention size (chunked-context is enabled when seqLenKv > mChunkedAttentionSize).
int mChunkedAttentionSize;
// The sum of sequence lengths for Q and K/V. (Only used when mSupportsVarSeqLens = true)
int mSumOfSeqLensQ;
int mSumOfSeqLensKv;
// The maximum number of pages per sequence in the paged-kv buffer.
int mMaxNumPagesPerSeqKv;
// The number of tokens per pageKv.
int mNumTokensPerPage;
// The number of pages in memory pool.
int mNumPagesInMemPool;
// The number of multiProcessor for the GPU.
int mMultiProcessorCount;
// Scaling factor for Q.
float mScaleQ;
// The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase kernel when inflight
// batching is enabled.
int mSfStartTokenIdx;
// Skip softmax threshold scale factor.
float mSkipSoftmaxThresholdScaleFactor;
// Whether to use sparse MLA.
bool mSparseMla;
// The top k value for sparse MLA.
int mSparseMlaTopK;
// The cuda stream.
cudaStream_t stream;
// The layer index.
int32_t mLayerIdx = 0;
// Whether the spec-dec tree is used.
bool mIsSpecDecTree = false;
// Set the attention mask type.
TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType)
{
// maskType is the enum of tensorrt_llm::kernels::ContextAttentionMaskType
// convert ContextAttentionMaskType to TrtllmGenAttentionMaskType
switch (maskType)
{
case 0: // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
mMaskType = TrtllmGenAttentionMaskType::Dense;
break;
case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
mMaskType = TrtllmGenAttentionMaskType::Causal;
break;
case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL
mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
break;
case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
mMaskType = TrtllmGenAttentionMaskType::Custom;
break;
default:
TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
static_cast<int>(maskType));
}
return *this;
}
};
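// Usage sketch (illustrative; the field choices and buffer arguments are assumptions, and the real
// call sites live in the attention op): populate the runner params for a paged-KV generation step
// with a causal mask. Fields not set here are left value-initialized.
inline TllmGenFmhaRunnerParams makeGenerationRunnerParamsSketch(void const* qPtr, void const* kvPoolPtr,
    int const* kvPageIdxPtr, void* oPtr, int numHeadsQ, int numHeadsKv, int headDim, int batchSize,
    int numTokensPerPage, cudaStream_t stream)
{
    TllmGenFmhaRunnerParams params{};
    params.mQkvLayout = QkvLayout::PagedKv;
    params.mKernelType = FmhaKernelType::Generation;
    params.mTileScheduler = TileScheduler::Persistent;
    // 1 maps to ContextAttentionMaskType::CAUSAL (see setAttentionMaskType above).
    params.setAttentionMaskType(1);
    params.qPtr = qPtr;
    params.kvPtr = kvPoolPtr;
    params.kvPageIdxPtr = kvPageIdxPtr;
    params.oPtr = oPtr;
    params.mNumHeadsQ = numHeadsQ;
    params.mNumHeadsKv = numHeadsKv;
    params.mNumHeadsQPerKv = numHeadsQ / numHeadsKv;
    params.mHeadDimQk = headDim;
    params.mHeadDimV = headDim;
    params.mBatchSize = batchSize;
    params.mNumTokensPerPage = numTokensPerPage;
    params.stream = stream;
    return params;
}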
////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameters that might be updated when selecting kernels.
struct TllmGenSelectKernelParams
{
// The FMHA kernel type.
FmhaKernelType mKernelType;
// The headDimV per CTA, which is only used by MLA generation kernels currently.
int mHeadDimPerCtaV;
// The multiCtasKvMode.
MultiCtasKvMode mMultiCtasKvMode;
// Force using GmemReduction for the multiCtasKvMode.
bool mForceGmemReduction;
// The mask type.
TrtllmGenAttentionMaskType mMaskType;
// The number of tokens per page.
int mNumTokensPerPage;
// Reuse smemK for V or not (only works with MLA generation kernels).
bool mReuseSmemKForV;
// Do we need to select a new kernel as the parameters have been updated.
bool mSelectNewKernel;
// The tile scheduler.
TileScheduler mTileScheduler;
// The tile size for Q.
int mTileSizeQ;
// The tile size for Kv.
int mTileSizeKv;
// Use 2 CTA MMA or not.
bool mUses2CtaMma;
// Whether to skip softmax when possible.
bool mSkipsSoftmaxWhenPossible;
// The constructor.
TllmGenSelectKernelParams(TllmGenFmhaRunnerParams params)
: mKernelType(params.mKernelType)
, mHeadDimPerCtaV(params.mHeadDimV)
// Note that CgaSmemReduction will be enabled based on the heuristic.
, mMultiCtasKvMode(params.mMultiCtasKvMode ? MultiCtasKvMode::GmemReduction : MultiCtasKvMode::Disabled)
, mForceGmemReduction(false)
, mMaskType(params.mMaskType)
, mNumTokensPerPage(params.mNumTokensPerPage)
, mReuseSmemKForV(false)
, mSelectNewKernel(false)
, mTileScheduler(params.mTileScheduler)
, mTileSizeQ(128)
, mTileSizeKv(128)
, mUses2CtaMma(false)
, mSkipsSoftmaxWhenPossible(params.mSkipSoftmaxThresholdScaleFactor != 0.0f)
{
}
};
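// Usage sketch (illustrative; kernelRunner and its selectKernel call are hypothetical stand-ins for
// the real runner interface): the selection params start as a copy of the runner's request; the
// selection heuristics may update them and set mSelectNewKernel to request another selection pass.
//
//   TllmGenSelectKernelParams selectKernelParams(runnerParams);
//   do
//   {
//       selectKernelParams.mSelectNewKernel = false;
//       kernelRunner.selectKernel(runnerParams, selectKernelParams); // hypothetical
//   } while (selectKernelParams.mSelectNewKernel);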
} // namespace kernels
TRTLLM_NAMESPACE_END