Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-06 19:21:52 +08:00)
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Co-authored-by: Bo Li <22713281+bobboli@users.noreply.github.com>
379 lines · 14 KiB · C++
/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/config.h"
#include <cstdint>
#include <cuda_runtime.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////
// The attention mask types.
enum class TrtllmGenAttentionMaskType
{
    // Dense mask.
    Dense = 0,
    // Causal mask.
    Causal,
    // Sliding window or chunked causal mask.
    SlidingOrChunkedCausal,
    // Custom mask.
    Custom
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the mask type.

#define ATTENTION_MASK_TYPE_FUNCTION(MaskType)                                                                         \
    inline bool is##MaskType##Mask(TrtllmGenAttentionMaskType maskType)                                                \
    {                                                                                                                  \
        return (maskType == TrtllmGenAttentionMaskType::MaskType);                                                     \
    }

ATTENTION_MASK_TYPE_FUNCTION(Dense)
ATTENTION_MASK_TYPE_FUNCTION(Causal)
ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal)
ATTENTION_MASK_TYPE_FUNCTION(Custom)

#undef ATTENTION_MASK_TYPE_FUNCTION
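// Illustrative usage sketch (not part of the original header): the macro above generates
// isDenseMask(), isCausalMask(), isSlidingOrChunkedCausalMask() and isCustomMask(). A hypothetical
// caller could branch on the mask type like this:
//
//   if (isCausalMask(maskType) || isSlidingOrChunkedCausalMask(maskType))
//   {
//       // Restrict each query token to keys at positions no greater than its own position.
//   }
//   else if (isCustomMask(maskType))
//   {
//       // Load the packed custom mask bits instead of deriving the mask from positions.
//   }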
////////////////////////////////////////////////////////////////////////////////////////////////////

enum class FmhaKernelType
{
    // The context-phase kernels.
    Context = 0,
    // Choose the best generation kernel based on the heuristic:
    // use SwapsMmaAbForGeneration kernels when numHeadsQPerKv <= 16, otherwise KeepsMmaAbForGeneration.
    Generation,
    // Swap tensor A and tensor B of Mma, which only supports numHeadsQPerKv <= 16.
    SwapsMmaAbForGeneration,
    // Keep tensor A and tensor B of Mma.
    KeepsMmaAbForGeneration,
    // Speculative decoding (Medusa and Eagle) generation-phase attention kernels, where seqLenQ > 1.
    SpecDecodingGeneration
};
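// Heuristic sketch (an illustration of the comment on Generation above, not the library's actual
// dispatch code): the described selection amounts to something along these lines:
//
//   FmhaKernelType selectGenerationKernelType(int numHeadsQPerKv)
//   {
//       return (numHeadsQPerKv <= 16) ? FmhaKernelType::SwapsMmaAbForGeneration
//                                     : FmhaKernelType::KeepsMmaAbForGeneration;
//   }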
////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the fmha kernel type.

#define FMHA_KERNEL_TYPE_FUNCTION(KernelType)                                                                          \
    inline bool is##KernelType##Kernel(FmhaKernelType kernelType)                                                      \
    {                                                                                                                  \
        return (kernelType == FmhaKernelType::KernelType);                                                             \
    }

FMHA_KERNEL_TYPE_FUNCTION(Context)
FMHA_KERNEL_TYPE_FUNCTION(Generation)
FMHA_KERNEL_TYPE_FUNCTION(SwapsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(KeepsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(SpecDecodingGeneration)

#undef FMHA_KERNEL_TYPE_FUNCTION
////////////////////////////////////////////////////////////////////////////////////////////////////

// Note that the (batchSize, seqLen) dimensions will be packed as sumOfSeqLens without padding for
// variable sequence lengths.
enum class QkvLayout
{
    // SeparateQkv: separate Q, K and V buffers.
    // Each has the shape: [batchSize, seqLen, numHeads, headDim].
    SeparateQkv = 0,
    // PackedQkv: single buffer for Q, K and V.
    // Shape: [batchSize, seqLen, numHeadsQ + 2*numHeadsKv, headDim].
    PackedQkv,
    // PagedKv: paged buffer for K and V. Its shape is [batchSize, 2, maxNumPagesPerSeq], where the 2
    // corresponds to K and V. That buffer stores the logical page index of the paged-KV memory pool.
    // Each "page" of that pool is a contiguous buffer of shape [numHeadsKv, pageSize, headDim].
    PagedKv,
    // ContiguousKv:
    // Contiguous buffer for Q with shape [batchSize, seqLen, numHeads, headDim].
    // Contiguous buffer for Kv with shape [batchSize, seqLen, 2 * numHeads, headDim].
    ContiguousKv,
};
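// Illustrative addressing sketch (an assumption for clarity, not part of the original header): with
// the PagedKv layout, the base pointer of the K page for a given (batchIdx, pageIdx) could be
// resolved roughly as below, where kvPoolPtr and bytesPerElt are hypothetical helper values and
// numTokensPerPage plays the role of pageSize:
//
//   int32_t logicalPageIdx = kvPageIdxPtr[(batchIdx * 2 + /* 0 = K, 1 = V */ 0) * maxNumPagesPerSeqKv + pageIdx];
//   auto const* pagePtr = static_cast<char const*>(kvPoolPtr)
//       + int64_t(logicalPageIdx) * numHeadsKv * numTokensPerPage * headDim * bytesPerElt;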
// Helper functions to check the QkvLayout type.

#define QKV_LAYOUT_FUNCTION(LayoutType)                                                                                \
    inline bool is##LayoutType(QkvLayout qkvLayout)                                                                    \
    {                                                                                                                  \
        return (qkvLayout == QkvLayout::LayoutType);                                                                   \
    }

QKV_LAYOUT_FUNCTION(SeparateQkv)
QKV_LAYOUT_FUNCTION(PackedQkv)
QKV_LAYOUT_FUNCTION(PagedKv)
QKV_LAYOUT_FUNCTION(ContiguousKv)

#undef QKV_LAYOUT_FUNCTION
////////////////////////////////////////////////////////////////////////////////////////////////////

enum class TileScheduler
{
    // Static scheduler (non-persistent).
    Static = 0,
    // Persistent scheduler.
    Persistent
};
////////////////////////////////////////////////////////////////////////////////////////////////////

enum class MultiCtasKvMode
{
    // Disable the multiCtasKvMode.
    Disabled = 0,
    // Do the reduction through the global memory and atomic counters.
    GmemReduction,
    // Same as GmemReduction, but use a separate kernel for the reduction.
    // It is only supported/needed for 2-CTA or 1-CTA keepsMmaAbForGeneration MLA kernels with large
    // reduction tiles.
    GmemReductionWithSeparateKernel,
    // Do the reduction through the CGA remote shared memory.
    CgaSmemReduction
};
// Helper function to check if the multiCtasKv mode is enabled.
inline bool isMultiCtasKvEnabled(MultiCtasKvMode multiCtasKvMode)
{
    return multiCtasKvMode != MultiCtasKvMode::Disabled;
}

// Helper functions to check the multiCtasKvMode type.

#define MULTI_CTAS_KV_MODE_FUNCTION(Type)                                                                              \
    inline bool is##Type(MultiCtasKvMode multiCtasKvMode)                                                              \
    {                                                                                                                  \
        return (multiCtasKvMode == MultiCtasKvMode::Type);                                                             \
    }

MULTI_CTAS_KV_MODE_FUNCTION(Disabled)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReduction)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReductionWithSeparateKernel)
MULTI_CTAS_KV_MODE_FUNCTION(CgaSmemReduction)

#undef MULTI_CTAS_KV_MODE_FUNCTION
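// Conceptual sketch (illustration only, not the kernels' exact arithmetic): when the multi-CTA KV
// reduction is enabled, each CTA produces a partial output together with its partial softmax
// max/sum (see multiCtasKvScratchPtr below), and partials are combined with the usual rescaling.
// For two partials (max0, sum0, o0) and (max1, sum1, o1), where o0/o1 are normalized by their own sums:
//
//   float mergedMax = fmaxf(max0, max1);
//   float w0 = sum0 * expf(max0 - mergedMax);
//   float w1 = sum1 * expf(max1 - mergedMax);
//   float mergedSum = w0 + w1;
//   // mergedO = (o0 * w0 + o1 * w1) / mergedSum;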
////////////////////////////////////////////////////////////////////////////////////////////////////

struct TllmGenFmhaRunnerParams
{
    // Input layout.
    QkvLayout mQkvLayout;
    // Attention mask type.
    TrtllmGenAttentionMaskType mMaskType;
    // The kernel type.
    FmhaKernelType mKernelType;
    // The tile scheduler.
    TileScheduler mTileScheduler;
    // The multiCtasKvMode (i.e. multiBlockMode).
    bool mMultiCtasKvMode;
    // Use block sparse attention.
    bool mUseBlockSparseAttention;

    // Input QKV buffers.
    void const* qPtr;
    void const* kPtr;
    void const* vPtr;
    // Packed KV buffer.
    void const* kvPtr;
    // Packed KV scaling-factor buffer.
    void const* kvSfPtr;
    // Packed QKV buffer.
    void const* qkvPtr;
    // The attention sinks pointer (an additional value per head in the denominator of the softmax).
    float const* attentionSinksPtr;
    // The general packed custom mask pointer, which does not follow the specific format expected by
    // the trtllm-gen kernels.
    int32_t const* generalPackedCustoMaskPtr;
    // The custom mask pointer.
    uint32_t* customMaskPtr;
    // The packed custom mask's offsets for each sequence.
    int64_t* customMaskOffsetsPtr;
    // The first sparseMask offsets in the Kv sequence dimension.
    int32_t* firstSparseMaskOffsetsKvPtr;
    // The counter for the multiCtasKv mode.
    int32_t* multiCtasKvCounterPtr;
    // The sequence length buffer for K/V.
    int const* seqLensKvPtr;
    // The cumulative sequence length buffers for Q and K/V.
    int const* cumSeqLensQPtr;
    int const* cumSeqLensKvPtr;
    // The KV page indices.
    int const* kvPageIdxPtr;
    // The device output scale for FP8 quantization.
    float const* outputScalePtr;
    // The device scaling factor for softmax (multiplied by log2 to use the faster exp2).
    float const* scaleSoftmaxLog2Ptr;
    // The device scale for the KV scaling factor.
    float const* kvSfScalePtr;
    // The device scale for the O scaling factor.
    float const* oSfScalePtr;
    // The scratch space for each CtaKv when the multiCtasKv mode is enabled.
    // PartialO, partialMax and partialSum will be stored to the scratch space.
    void* multiCtasKvScratchPtr;
    // The softmax stats buffer.
    // The softmax max/sum values will be stored to this buffer if it is not nullptr.
    float2* softmaxStatsPtr;
    // The output buffer.
    void* oPtr;
    // The output scaling-factor buffer.
    void* oSfPtr;
    // The sequence lengths for Q.
    int const* seqlensQPtr;

    // Head dimension for Q and K.
    int mHeadDimQk;
    // Head dimension for V.
    int mHeadDimV;
    // Head dimension for the Q/K non-RoPE part; only used for MLA now.
    int mHeadDimQkNope;
    // Number of heads for Q and K/V.
    int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
    // The batch size.
    int mBatchSize;
    // The max sequence length in the contiguous Kv cache.
    int mMaxSeqLenCacheKv;
    // The max Q sequence length.
    int mMaxSeqLenQ;
    // The max KV sequence length.
    int mMaxSeqLenKv;
    // The attention window size for sliding window attention (sliding-window attention is enabled
    // when seqLenKv > mAttentionWindowSize).
    int mAttentionWindowSize;
    // The chunked attention size (chunked context is enabled when seqLenKv > mChunkedAttentionSize).
    int mChunkedAttentionSize;
    // The sum of sequence lengths for Q and K/V (only used when mSupportsVarSeqLens = true).
    int mSumOfSeqLensQ;
    int mSumOfSeqLensKv;
    // The maximum number of pages per sequence in the paged-KV buffer.
    int mMaxNumPagesPerSeqKv;
    // The number of tokens per KV page.
    int mNumTokensPerPage;
    // The number of pages in the memory pool.
    int mNumPagesInMemPool;
    // The number of multiprocessors on the GPU.
    int mMultiProcessorCount;
    // Scaling factor for Q.
    float mScaleQ;
    // The start token index in the SF tensor. Used for the FP4 SF offset calculation in the
    // generation-phase kernel when inflight batching is enabled.
    int mSfStartTokenIdx;
    // Skip-softmax threshold scale factor.
    float mSkipSoftmaxThresholdScaleFactor;
    // Whether to use sparse MLA.
    bool mSparseMla;
    // The top-k value for sparse MLA.
    int mSparseMlaTopK;
    // The CUDA stream.
    cudaStream_t stream;
    // The layer index.
    int32_t mLayerIdx = 0;
    // Whether the spec-dec tree is used.
    bool mIsSpecDecTree = false;

    // Set the attention mask type.
    TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType)
    {
        // maskType is the enum of tensorrt_llm::kernels::ContextAttentionMaskType;
        // convert ContextAttentionMaskType to TrtllmGenAttentionMaskType.
        switch (maskType)
        {
        case 0: // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
            mMaskType = TrtllmGenAttentionMaskType::Dense;
            break;
        case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
            mMaskType = TrtllmGenAttentionMaskType::Causal;
            break;
        case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL
            mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
            break;
        case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
            mMaskType = TrtllmGenAttentionMaskType::Custom;
            break;
        default:
            TLLM_THROW("ContextAttentionMaskType %d cannot be mapped to TrtllmGenAttentionMaskType",
                static_cast<int>(maskType));
        }
        return *this;
    }
};
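// Usage sketch (illustrative only; the chosen field values are assumptions, not a prescribed
// configuration):
//
//   TllmGenFmhaRunnerParams runnerParams{};
//   runnerParams.mQkvLayout  = QkvLayout::PagedKv;
//   runnerParams.mKernelType = FmhaKernelType::Generation;
//   // Maps ContextAttentionMaskType::CAUSAL (1) to TrtllmGenAttentionMaskType::Causal.
//   runnerParams.setAttentionMaskType(/* maskType = */ 1);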
////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameters that might be updated when selecting kernels.
struct TllmGenSelectKernelParams
{
    // The FMHA kernel type.
    FmhaKernelType mKernelType;
    // The headDimV per CTA, which is currently only used by MLA generation kernels.
    int mHeadDimPerCtaV;
    // The multiCtasKvMode.
    MultiCtasKvMode mMultiCtasKvMode;
    // Force using GmemReduction for the multiCtasKvMode.
    bool mForceGmemReduction;
    // The mask type.
    TrtllmGenAttentionMaskType mMaskType;
    // The number of tokens per page.
    int mNumTokensPerPage;
    // Reuse smemK for V or not (only works with MLA generation kernels).
    bool mReuseSmemKForV;
    // Whether a new kernel needs to be selected because the parameters have been updated.
    bool mSelectNewKernel;
    // The tile scheduler.
    TileScheduler mTileScheduler;
    // The tile size for Q.
    int mTileSizeQ;
    // The tile size for Kv.
    int mTileSizeKv;
    // Use 2-CTA MMA or not.
    bool mUses2CtaMma;
    // Skip softmax or not.
    bool mSkipsSoftmaxWhenPossible;

    // The constructor.
    TllmGenSelectKernelParams(TllmGenFmhaRunnerParams params)
        : mKernelType(params.mKernelType)
        , mHeadDimPerCtaV(params.mHeadDimV)
        // Note the CgaSmemReduction will be enabled based on the heuristic.
        , mMultiCtasKvMode(params.mMultiCtasKvMode ? MultiCtasKvMode::GmemReduction : MultiCtasKvMode::Disabled)
        , mForceGmemReduction(false)
        , mMaskType(params.mMaskType)
        , mNumTokensPerPage(params.mNumTokensPerPage)
        , mReuseSmemKForV(false)
        , mSelectNewKernel(false)
        , mTileScheduler(params.mTileScheduler)
        , mTileSizeQ(128)
        , mTileSizeKv(128)
        , mUses2CtaMma(false)
        , mSkipsSoftmaxWhenPossible(params.mSkipSoftmaxThresholdScaleFactor != 0.0f)
    {
    }
};
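// Usage sketch (illustrative only): the kernel-selection parameters are seeded from the runner
// parameters and may later be updated by the selection heuristics.
//
//   TllmGenFmhaRunnerParams runnerParams{};
//   runnerParams.mMultiCtasKvMode = true; // starts as GmemReduction; heuristics may switch modes.
//   TllmGenSelectKernelParams selectKernelParams(runnerParams);
//   // selectKernelParams.mSelectNewKernel stays false until a heuristic changes the parameters.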
} // namespace kernels

TRTLLM_NAMESPACE_END