mirror of https://github.com/NVIDIA/TensorRT-LLM.git
* fp8 kv + bf16 ctx MLA + fp8 gen MLA
Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.
Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.
For FP8 MLA generation, the output is still in BF16.
Refine debug info for FMHA kernel metadata.
Use inputType, outputType, SM together to hash kernel list.
Add FP8 MLA generation FMHA kernel.
Special WAR of NUM_COMPUTE_GROUPS for MLA generation kernel.
Separate the implementation of fused_multihead_attention_v2.h into a .cpp file and print some debug info if checkIfKernelExist fails.
Refine debug info in fused_multihead_attention_v2.cpp
Correct FP8 MLA metadata.
New kernel provided by Yuxin, which outputs BF16.
The smem size was not set correctly, which led to illegal memory access.
Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly; some parts were written repeatedly while others were left untouched.
There are two bmm1 scales that should be set correctly.
New kernel generated by Yuxin.
Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.
Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is false.
Skip a check in fmhaDispatcher.
Modifications in fmhaRunner:
- Debug dump.
- if (!isFP8GenerationMLA) skips a lot of flag setting.
- TMA descriptor modification for qo (by Yuxin).
Cleanup debug output.
Clean up o tma descriptor modifications.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
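As a rough illustration of the dtype combination described above (a minimal sketch with hypothetical names such as MlaDtypeConfig/selectMlaDtypes, not the actual attentionOp code):

    // Sketch: dtype selection for FP8 generation MLA on Hopper (SM90).
    // Assumes the Data_type enum used elsewhere in this change.
    struct MlaDtypeConfig
    {
        Data_type input, kv, output;
    };

    inline MlaDtypeConfig selectMlaDtypes(bool fp8GenerationMLA, int sm)
    {
        if (fp8GenerationMLA && sm == 90)
        {
            // FP8 KV cache + FP8 generation MLA: input and KV are E4M3, but the
            // output stays BF16, so mFP8ContextFMHA (FP8 output) must not be
            // enabled at the same time.
            return {DATA_TYPE_E4M3, DATA_TYPE_E4M3, DATA_TYPE_BF16};
        }
        // Context MLA stays BF16 end to end.
        return {DATA_TYPE_BF16, DATA_TYPE_BF16, DATA_TYPE_BF16};
    }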
* Resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Apply the patch of FP8 FlashMLA and resolve conflicts.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compilation error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Fix compile error.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* pick blackwell support
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
* Add copyright notice to fused_multihead_attention_v2.cpp.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Add missing license.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Exclude building flashMLA kernels under sm90.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Revert "Exclude building flashMLA kernels under sm90."
This reverts commit f0c859d459.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
* Use macro to skip compiling FlashMLA for non sm90 targets.
Signed-off-by: Bo Li <bobboli0202@gmail.com>
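For reference, one common way to express such a guard (a generic sketch; the actual macro name and condition used in this change may differ):

    // Sketch: only emit the FlashMLA device code when compiling for SM90 (Hopper).
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)
    __global__ void flash_mla_kernel(/* kernel arguments */)
    {
        // SM90-only implementation.
    }
    #else
    __global__ void flash_mla_kernel(/* kernel arguments */)
    {
        // Compiled out for non-SM90 targets.
    }
    #endif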
---------
Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fmhaRunner.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/mathUtils.h"
#include <cassert>
#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>
#include <iostream>
#include <math.h>
#include <tuple>
#include <vector>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace tensorrt_llm
{
namespace kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////

union __half2_uint32_t_union
{
    half2 fp162;
    uint32_t u32;
};

union __float_uint32_t_union
{
    float fp32;
    uint32_t u32;
};

static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
{
    if (dtype == DATA_TYPE_FP16)
    {
        __half2_uint32_t_union temp;
        temp.fp162 = __float2half2_rn(norm);
        alpha = temp.u32;
    }
    else if (dtype == DATA_TYPE_FP32)
    {
        __float_uint32_t_union temp;
        temp.fp32 = norm;
        alpha = temp.u32;
    }
    else if (dtype == DATA_TYPE_INT32)
    {
        int32_t inorm = static_cast<int32_t>(norm);
        alpha = reinterpret_cast<uint32_t const&>(inorm);
    }
    else if (dtype == DATA_TYPE_BF16)
    {
        // TODO HACK!! BF16 outputs are computed in FP32 for FP8.
        // This is because cublas does not currently allow FP32 output.
        alpha = reinterpret_cast<uint32_t const&>(norm);
    }
    else
    {
        assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

FusedMHARunnerV2::FusedMHARunnerV2(MHARunnerFixedParams fixedParams)
    : mFixedParams(fixedParams)
{
    TLLM_CHECK_WITH_INFO(
        (mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_90 || mSM == kSM_100 || mSM == kSM_120),
        "Unsupported architecture");
    TLLM_CHECK_WITH_INFO((mFixedParams.dataType == DATA_TYPE_FP16 || mFixedParams.dataType == DATA_TYPE_BF16
            || mFixedParams.dataType == DATA_TYPE_E4M3),
        "Unsupported data type");
    xmmaKernel = getXMMAKernelsV2(mFixedParams.dataType, mFixedParams.dataTypeOut, mSM);

    if (mFixedParams.headSizeV == 0)
    {
        mFixedParams.headSizeV = mFixedParams.headSize;
    }
    // Get device attributes.
    int device_id;
    cudaGetDevice(&device_id);
    cudaDeviceGetAttribute(&mMultiProcessorCount, cudaDevAttrMultiProcessorCount, device_id);
    cudaDeviceGetAttribute(&mDeviceL2CacheSize, cudaDevAttrL2CacheSize, device_id);
    auto const [free_memory, total_memory] = tensorrt_llm::common::getDeviceMemoryInfo(false);
    mTotalDeviceMemory = total_memory;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// For debugging purposes.
inline void dumpFmhaParams(Fused_multihead_attention_params_v2 const& p, FILE* fd = stdout)
{
    fprintf(fd, "d = %d\ndv = %d\ns = %d\nb = %d\nh = %d\nh_kv = %d\nh_q_per_kv = %d\n", p.d, p.dv, p.s, p.b, p.h,
        p.h_kv, p.h_q_per_kv);
    auto dump_scale_bmm = [&](char const* name, uint32_t const* ptr, uint32_t val)
    {
        if (ptr)
        {
            cudaError err = cudaMemcpy(&val, ptr, sizeof(uint32_t), cudaMemcpyDeviceToHost);
            if (err != cudaSuccess)
            {
                throw std::runtime_error("failed to cudaMemcpy()");
            }
        }
        fprintf(fd, "%s = %f\n", name, reinterpret_cast<float&>(val));
    };
    dump_scale_bmm("scale_bmm1", p.scale_bmm1_d, p.scale_bmm1);
    dump_scale_bmm("scale_bmm2", p.scale_bmm2_d, p.scale_bmm2);
    fprintf(fd, "softcapping_scale_bmm1 = %f\nscale_softmax = %f\n",
        reinterpret_cast<float const&>(p.softcapping_scale_bmm1), reinterpret_cast<float const&>(p.scale_softmax));
    auto to_bool = [](bool v) -> char const* { return v ? "true" : "false"; };
    fprintf(fd, "sliding_window_size = %d\nhas_alibi = %s\nis_s_padded = %s\n", p.sliding_window_size,
        to_bool(p.has_alibi), to_bool(p.is_s_padded));
    auto dump_cu_array = [&](char const* name, int const* cu_array)
    {
        if (!cu_array)
        {
            fprintf(fd, "%s = None\n", name);
            return;
        }
        size_t sz = (p.b + 1) * sizeof(int);
        int* array = (int*) malloc(sz);
        if (!array)
        {
            throw std::runtime_error("failed to malloc()");
        }
        cudaError err = cudaMemcpy(array, cu_array, sz, cudaMemcpyDeviceToHost);
        if (err != cudaSuccess)
        {
            throw std::runtime_error("failed to cudaMemcpy()");
        }
        fprintf(fd, "%s = [", name);
        for (int i = 0; i <= p.b; i++)
        {
            fprintf(fd, i == 0 ? "%d" : ", %d", array[i]);
        }
        fprintf(fd, "]\n");
        free(array);
    };
    dump_cu_array("cu_q_seqlens", p.cu_q_seqlens);
    dump_cu_array("cu_kv_seqlens", p.cu_kv_seqlens);
    dump_cu_array("cu_mask_rows", p.cu_mask_rows);
    fprintf(fd,
        "tile_id_counter_ptr = %p\nnum_tiles = %u\nnum_tiles_per_head = %u\n"
        "use_balanced_scheduling = %s\n",
        p.tile_id_counter_ptr, p.num_tiles, p.num_tiles_per_head, to_bool(p.use_balanced_scheduling));
    fprintf(fd,
        "qkv_stride_in_bytes = %ld\nq_stride_in_bytes = %ld\nkv_stride_in_bytes = %ld\n"
        "v_stride_in_bytes = %ld\npacked_mask_stride_in_bytes = %ld\no_stride_in_bytes = %ld\n",
        p.qkv_stride_in_bytes, p.q_stride_in_bytes, p.kv_stride_in_bytes, p.v_stride_in_bytes,
        p.packed_mask_stride_in_bytes, p.o_stride_in_bytes);
    auto& kv_cache = p.paged_kv_cache;
    fprintf(fd,
        "# paged_kv_cache\nmMaxSeqs = %d\nmMaxBlocksPerSeq = %d\nmTokensPerBlock = %d\n"
        "mTokensPerBlockLog2 = %d\nmBytesPerBlock = %d\n\n",
        kv_cache.mMaxSeqs, kv_cache.mMaxBlocksPerSeq, kv_cache.mTokensPerBlock, kv_cache.mTokensPerBlockLog2,
        kv_cache.mBytesPerBlock);
}

// Shared setup function.
void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams)
{
    // Reinit kernel params.
    mKernelParams = {};

    // Set the batch size, and sequence length.
    mKernelParams.b = runnerParams.b;
    mKernelParams.s = runnerParams.qSeqLen;
    mKernelParams.sliding_window_size = runnerParams.slidingWindowSize;
    // Set the head size and number of heads.
    mKernelParams.d = mFixedParams.headSize;
    mKernelParams.dv = mFixedParams.headSizeV;
    TLLM_CHECK_WITH_INFO(mFixedParams.numQHeads % mFixedParams.numKvHeads == 0,
        "number of Query heads should be multiple of KV heads !");
    mKernelParams.h = mFixedParams.numQHeads;
    mKernelParams.h_kv = mFixedParams.numKvHeads;
    mKernelParams.h_q_per_kv = mFixedParams.numQHeads / mFixedParams.numKvHeads;
    // Are the input sequences padded ?
    mKernelParams.is_s_padded = mFixedParams.isSPadded;

    mKernelParams.softmax_stats_ptr = runnerParams.softmaxStatsPtr;
    mKernelParams.softmax_stats_stride_in_bytes = sizeof(float) * mFixedParams.numQHeads;

    // Packed QKV input layout.
    mKernelParams.qkv_stride_in_bytes = get_size_in_bytes(mFixedParams.numQHeads * mFixedParams.headSize
            + mFixedParams.numKvHeads * mFixedParams.headSize + mFixedParams.numKvHeads * mFixedParams.headSizeV,
        mFixedParams.dataType);
    // Contiguous Q input layout.
    mKernelParams.q_stride_in_bytes
        = get_size_in_bytes(mFixedParams.numQHeads * mFixedParams.headSize, mFixedParams.dataType);
    // Set the kv_stride_in_bytes when separate kv buffer is used.
    if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
    {
        // Paged kv cache layout.
        mKernelParams.kv_stride_in_bytes = get_size_in_bytes(
            runnerParams.pagedKvCache.mTokensPerBlock * mFixedParams.headSize, mFixedParams.dataType);
        // only for deepseek
        mKernelParams.v_stride_in_bytes = mKernelParams.kv_stride_in_bytes;
    }
    else if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_CONTIGUOUS_KV)
    {
        // Contiguous kv input layout.
        mKernelParams.kv_stride_in_bytes
            = get_size_in_bytes(2 * mFixedParams.numKvHeads * mFixedParams.headSize, mFixedParams.dataType);
    }
    // Set the output buffer stride in bytes.
    mKernelParams.o_stride_in_bytes
        = get_size_in_bytes(mFixedParams.numQHeads * mFixedParams.headSizeV, mFixedParams.dataTypeOut);
    // Set the packed_mask_stride_in_bytes.
    if (mFixedParams.attentionMaskType == ContextAttentionMaskType::CUSTOM_MASK)
    {
        // The packed mask col (n) dimension has to be padded to multiple of 256.
        mKernelParams.packed_mask_stride_in_bytes
            = (tensorrt_llm::common::divUp(int64_t(runnerParams.kvSeqLen), int64_t(FLASH_ATTEN_PACKED_MASK_N_ALIGNMENT))
                  * FLASH_ATTEN_PACKED_MASK_N_ALIGNMENT)
            / 8;
    }

    float const inv_sqrt_scale = (1.f / (sqrtf(mFixedParams.headSize) * mFixedParams.qScaling));
    // Note that we apply scales and bias in the order of
    // (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
    float const scale_after_alibi = mFixedParams.scaleAlibi ? inv_sqrt_scale : 1.0f;
    float scale_bmm1 = mFixedParams.scaleAlibi ? 1.0f : inv_sqrt_scale;
    // Fuse 1.0f / attn_logit_softcapping_scale into scale_bmm1.
    scale_bmm1 = mFixedParams.attnLogitSoftcappingScale != 0.f ? scale_bmm1 / mFixedParams.attnLogitSoftcappingScale
                                                               : scale_bmm1;
    // The softmax output scale (not used).
    float const scale_softmax = 1.f;
    // FP8 FMHA kernels load the scale_bmm2 from the device memory.
    float const scale_bmm2 = 1.f;

    Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mFixedParams.dataType;
    // Use exp2f optimization for warp-specialized ws kernels on Hopper.
    if (mLaunchParams.useBase2ExpTrick)
    {
        // The kernel adopts the log2f optimization.
        constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
        set_alpha(mKernelParams.scale_bmm1, scale_bmm1 * float(kLog2e), DATA_TYPE_FP32);
    }
    else
    {
        set_alpha(mKernelParams.scale_bmm1, scale_bmm1, scale_type);
    }
    set_alpha(mKernelParams.scale_softmax, scale_softmax, scale_type);
    // Host scale_bmm2 will not be used.
    set_alpha(mKernelParams.scale_bmm2, scale_bmm2, scale_type);
    // The attention logit softcapping scale after bmm1 (always float32).
    mKernelParams.softcapping_scale_bmm1 = mFixedParams.attnLogitSoftcappingScale;

    // alibi.
    if (mFixedParams.hasAlibi && mSM > kSM_70)
    {
        mKernelParams.has_alibi = true;
        mKernelParams.alibi_params = AlibiParams(
            mFixedParams.numQHeads, runnerParams.kvSeqLen, mFixedParams.tpSize, mFixedParams.tpRank, scale_after_alibi);
    }

    // Set device pointers.
    mKernelParams.qkv_ptr = runnerParams.qkvPtr;
    mKernelParams.q_ptr = runnerParams.qPtr;
    mKernelParams.kv_ptr = runnerParams.kvPtr;
    mKernelParams.o_ptr = runnerParams.outputPtr;
    if (mFixedParams.attentionMaskType == ContextAttentionMaskType::CUSTOM_MASK)
    {
        mKernelParams.packed_mask_ptr = runnerParams.packedMaskPtr;
        mKernelParams.cu_mask_rows = reinterpret_cast<int const*>(runnerParams.cuMaskRowsPtr);
    }
    mKernelParams.cu_q_seqlens = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
    mKernelParams.tile_id_counter_ptr = reinterpret_cast<uint32_t*>(runnerParams.tileCounterPtr);
    // TRT doesn't support host scales. Use device scales instead.
    // The scaleBmm1Ptr offset.
    // 2 scales prepared for scaleBmm1 in the device memory: float scale, float (scale with log2e).
    int64_t scaleBmm1PtrOffset = (mLaunchParams.useBase2ExpTrick ? kIdxScaleSoftmaxLog2Ptr : kIdxScaleSoftmaxPtr);
    // Only fp8 kernels need to load scales from the device memory.
    if (mFixedParams.dataType == DATA_TYPE_E4M3)
    {
        mKernelParams.scale_bmm1_d = reinterpret_cast<uint32_t const*>(runnerParams.scaleBmm1Ptr + scaleBmm1PtrOffset);
        mKernelParams.scale_bmm2_d = reinterpret_cast<uint32_t const*>(runnerParams.scaleBmm2Ptr);
    }

    // Separate q and kv buffers may have different q and kv sequence lengths.
    if (mFixedParams.attentionInputLayout != AttentionInputLayout::PACKED_QKV)
    {
        mKernelParams.cu_kv_seqlens = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);
    }

    // Paged kv fmha.
    if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
    {
        mKernelParams.paged_kv_cache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA();
    }

    // for sage attention
    mKernelParams.sage.q.scales = runnerParams.qScalePtr;
    mKernelParams.sage.k.scales = runnerParams.kScalePtr;
    mKernelParams.sage.v.scales = runnerParams.vScalePtr;
    mKernelParams.sage.q.max_nblock = runnerParams.qMaxNBlock;
    mKernelParams.sage.k.max_nblock = runnerParams.kMaxNBlock;
    mKernelParams.sage.v.max_nblock = runnerParams.vMaxNBlock;

    // For debugging purposes.
    // if (mFixedParams.dataType == DATA_TYPE_E4M3) {
    //     dumpFmhaParams(mKernelParams);
    // }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Set the launch params to select kernels.
void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
{

    // Determine launch parameters.
    // Reset launch params to default.
    mLaunchParams = {};

    // Device properties.
    mLaunchParams.multi_processor_count = mMultiProcessorCount;
    mLaunchParams.device_l2_cache_size = mDeviceL2CacheSize;
    mLaunchParams.total_device_memory = mTotalDeviceMemory;

    // Do we use attnLogitSoftcappingScale ?
    TLLM_CHECK_WITH_INFO(
        (mFixedParams.headSize == 128 || mFixedParams.headSize == 256) || !mFixedParams.attnLogitSoftcappingScale,
        "FMHA only supports head_size = 128 or 256 with attention logit softcapping scale currently.");
    mLaunchParams.enableAttnLogitSoftcapping = mFixedParams.attnLogitSoftcappingScale != 0.f;
    // BF16 FMHA only accumulates on FP32.
    // E4M3 FMHA only supports fp32 accumulation currently.
    mLaunchParams.force_fp32_acc = mFixedParams.dataType == DATA_TYPE_BF16 || mFixedParams.dataType == DATA_TYPE_E4M3
        || mFixedParams.forceFp32Acc || runnerParams.forceFp32Acc;
    // The attention mask type.
    mLaunchParams.attention_mask_type = mFixedParams.attentionMaskType;
    // The input layout type.
    mLaunchParams.attention_input_layout = mFixedParams.attentionInputLayout;

    // The total sequence length used to set the tma descriptors.
    mLaunchParams.total_q_seqlen
        = mFixedParams.isSPadded ? runnerParams.b * runnerParams.qSeqLen : runnerParams.totalQSeqLen;
    mLaunchParams.total_kv_seqlen
        = mFixedParams.isSPadded ? runnerParams.b * runnerParams.kvSeqLen : runnerParams.totalKvSeqLen;

    // Next power of 2 head size.
    TLLM_CHECK_WITH_INFO(mFixedParams.headSize > 0, "Head size should be greater than 0.");
    mLaunchParams.padded_d = (mFixedParams.headSize & (mFixedParams.headSize - 1)) == 0
        ? mFixedParams.headSize
        : pow(2, int(log2(mFixedParams.headSize)) + 1);

    bool const isSm70 = (mSM == kSM_70);
    bool const isSm90 = (mSM == kSM_90);
    bool const isSm8x = (mSM == kSM_86 || mSM == kSM_89);
    bool const isSm80 = (mSM == kSM_80);
    bool const isSm89 = (mSM == kSM_89);
    bool const isSm100 = (mSM == kSM_100);
    bool const isSm120 = (mSM == kSM_120);

    // Sliding_window_causal mask.
    if (runnerParams.kvSeqLen > runnerParams.slidingWindowSize
        && mLaunchParams.attention_mask_type == ContextAttentionMaskType::CAUSAL)
    {
        mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
    }

    // Is the input layout separate q + kv input ?
    bool const separateQKvInput = mFixedParams.attentionInputLayout != AttentionInputLayout::PACKED_QKV;
    // Is the mask type padding or causal mask ?
    bool const paddingOrCausalMask = mFixedParams.attentionMaskType == ContextAttentionMaskType::PADDING
        || mFixedParams.attentionMaskType == ContextAttentionMaskType::CAUSAL;

    // Only warp-specialized FMHA kernels support FP8 on Hopper.
    // Separate Q + KV input layout: enable warp-specialization kernels when s > 512, otherwise use ampere-style flash
    // attention kernels.
    if (isSm90 && (mFixedParams.dataType == DATA_TYPE_E4M3 || (separateQKvInput && runnerParams.kvSeqLen > 512)))
    {
        mLaunchParams.flash_attention = true;
        mLaunchParams.force_unroll = true;
    }
    else if (isSm70)
    {
        TLLM_CHECK_WITH_INFO(false, "Unsupported architecture");
    }
    // Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
    // Only supports packed_qkv input + padding/causal mask.
    else if (isSm90 && !separateQKvInput && paddingOrCausalMask
        && (mFixedParams.headSize == 32 || mFixedParams.headSize == 64) && runnerParams.qSeqLen <= 256
        && !common::getEnvForceDeterministicAttention())
    {
        mLaunchParams.flash_attention = false;
        // get max sequence length for non-flash-attention.
        // this doesn't support different q and kv sequence lengths.
        mLaunchParams.kernel_s = getSFromMaxSeqLen(runnerParams.qSeqLen);
    }
    else
    { // always use flash attention kernels for Ampere/Ada
        mLaunchParams.flash_attention = true;
        // flash attention kernels use s = 0 (support any seq length)
        mLaunchParams.kernel_s = 0;
        mLaunchParams.force_unroll = true;
        // enable tiled kernels on Ampere/Ada
        if (isSm89 && mFixedParams.dataType == DATA_TYPE_E4M3)
        {
            // so far Ada QMMA only supports non-tiled kernels.
            mLaunchParams.granular_tiling = false;
        }
        else if (mLaunchParams.flash_attention && runnerParams.kvSeqLen <= 64)
        {
            // flash attention tiled kernels allow larger free dim tile sizes (M, N) with flexibility
            // in the unroll dimension tile size (K). for short sequence lengths (s<=128), tiled kernels
            // can suffer from tile quantization loss, therefore use non-tiled flash attention instead
            mLaunchParams.granular_tiling = false;
        }
        else if (isSm8x && mFixedParams.headSize < 256)
        {
            // flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
            mLaunchParams.granular_tiling = false;
        }
        else if (isSm80 || isSm8x || isSm100 || isSm120)
        {
            // otherwise, choose tiled kernel for Ampere/Ada/Gb20x
            mLaunchParams.granular_tiling = true;
        }
    }

    // when flash attention is enabled on Hopper, we need to set the tma descriptors
    if (isSm90 && mLaunchParams.flash_attention)
    {
        mLaunchParams.warp_specialization = true;
        mLaunchParams.use_tma = true;
        // Enable dynamic tile scheduling for hopper ws kernel.
        mLaunchParams.dynamic_scheduler = true;
    }

    // Use specialized ws kernels on Hopper for cases without alibi.
    if (mLaunchParams.warp_specialization && !mFixedParams.hasAlibi)
    {
        // Use specialized ws kernels for cases without alibi.
        mLaunchParams.useKernelWithoutAlibi = true;
        // Enable exp2f optimization (which helps improve performance).
        // - note that this is not compatible with alibi bias due to the accuracy issues.
        // - only hopper warp-specialized kernels have this optimization.
        // - it doesn't work with attention logit softcapping.
        mLaunchParams.useBase2ExpTrick = !mLaunchParams.enableAttnLogitSoftcapping;
    }

    // TODO: Refactor these dirty hacks.
    // For Deepseek-v2(MLA), all of SM80, SM89 and SM90 kernels use tiled flash attention
    // in both context (192/128 dimensions) and generation (576/512 dimensions)
    if (mFixedParams.headSize == mFixedParams.headSizeV + 64)
    {
        mLaunchParams.flash_attention = true;
        mLaunchParams.force_unroll = true;
        mLaunchParams.kernel_s = 0;

        // Now we have SM90 generation MLA kernels. These treatments are only for context MLA and non SM90 generation
        // MLA.
        bool isFP8GenerationMLA = mFixedParams.dataType == DATA_TYPE_E4M3
            && (mFixedParams.headSize == 576 && mFixedParams.headSizeV == 512);
        if (!isFP8GenerationMLA)
        {
            mLaunchParams.granular_tiling = true;
            // Even on SM90, we use the ampere-style kernel; this will be optimized later.
            mLaunchParams.warp_specialization = false;
            mLaunchParams.useKernelWithoutAlibi = false;
            // Deepseek-V2 kernel is not Hopper style right now.
            mLaunchParams.useBase2ExpTrick = false;
            mLaunchParams.use_tma = false;
            mLaunchParams.dynamic_scheduler = false;
        }
    }

    mLaunchParams.sage_block_size_q = mFixedParams.sageBlockSizeQ;
    mLaunchParams.sage_block_size_k = mFixedParams.sageBlockSizeK;
    mLaunchParams.sage_block_size_v = mFixedParams.sageBlockSizeV;
    // for not (sm90 + warp_specialization + flash attention kernel) kernel:
    //     all kernels enable saving softmaxStatsPtr, just let softmaxStatsPtr != null
    // for (sm90 + warp_specialization + flash attention) kernel:
    //     we need to explicitly set supportReturnSoftmaxStats to true when
    //     satisfying the following constraints
    if (!isSm90)
    {
        mLaunchParams.supportReturnSoftmaxStats = true;
    }
    else
    {
        mLaunchParams.supportReturnSoftmaxStats = (runnerParams.softmaxStatsPtr != nullptr
            && mLaunchParams.flash_attention && mLaunchParams.warp_specialization
            && mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// TMA descriptors are used as grid_constant parameters (remove MemCpyH2D operations)
void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
{
    // split D into multiple groups in order to match the TMA swizzle mode (128B)
    uint32_t const d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
    uint32_t const d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;

    // separate q, k, v and o tma descriptors
    Multiple_tma_descriptor<4> qkv_tma_descriptor;

    // tensor size
    uint32_t tensor_size_qkv[4];
    if (mKernelParams.h_kv < mKernelParams.h)
    {
        // if multi-query or grouped-query
        tensor_size_qkv[2] = 1;
        tensor_size_qkv[1] = (mKernelParams.h + 2 * mKernelParams.h_kv);
        tensor_size_qkv[0] = mKernelParams.d; // mKernelParams.d;
    }
    else
    {
        tensor_size_qkv[2] = 3;
        tensor_size_qkv[1] = mKernelParams.h;
        tensor_size_qkv[0] = mKernelParams.d; // mKernelParams.d;
    }

    // O : [TOTAL, 1, h, d]
    uint32_t tensor_size_o[4];
    tensor_size_o[0] = mKernelParams.d;
    tensor_size_o[1] = mKernelParams.h;
    tensor_size_o[2] = 1;

    // box size for k and v
    uint32_t box_size[4];
    // Update this on device?
    box_size[2] = 1;
    box_size[1] = 1;
    box_size[0] = mLaunchParams.padded_d / d_groups;

    // stride size in bytes. Assumes least significant dim is 1 (?)
    uint64_t tensor_stride_qkv[3];
    tensor_stride_qkv[0] = get_size_in_bytes(tensor_size_qkv[0], mFixedParams.dataType); // d
    tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0];                    // d*h
    tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1];                    // d*h*3

    uint64_t tensor_stride_o[3];
    tensor_stride_o[0] = get_size_in_bytes(tensor_size_o[0], mFixedParams.dataTypeOut); // d
    tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0];                         // d*h
    tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1];                         // d*h*1

    // traversal stride
    uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1};
    uint32_t traversal_stride_o[4] = {1, 1, 1, 1};

    // OOB fill zeros
    uint32_t oob_fill = 0;

    // FP32 to TF32 conversion disabled
    uint32_t fp32_to_tf32 = 0;

    // gmma descriptor mode
    uint32_t const d_bytes_per_group = d_in_bytes / d_groups;
    cudaTmaDescSwizzle const swizzle_mode = (d_bytes_per_group > 64
            ? cudaTmaDescSwizzle::SWIZZLE_128B
            : (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));

    uint32_t q_step = 0, kv_step = 0;
    xmmaKernel->getStepSize(q_step, kv_step, mKernelParams, mLaunchParams);

    // QKV [TOTAL, 3, h, d]
    // NOTE: we may need to use actual seqlen to set oob_value
    auto const* qkv_ptr = static_cast<char const*>(mKernelParams.qkv_ptr);
    tensor_size_qkv[3] = mLaunchParams.total_q_seqlen;
    // O [TOTAL, 1, h, d]
    auto* o_ptr = static_cast<char*>(mKernelParams.o_ptr);
    tensor_size_o[3] = mLaunchParams.total_q_seqlen;

    // Q: STEP_Q
    box_size[3] = q_step;
    // Desc Format (data type).
    cudaTmaDescFormat const desc_format
        = (get_size_in_bytes(mFixedParams.dataType) == 1) ? cudaTmaDescFormat::U8 : cudaTmaDescFormat::F16_RN;
    qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
        swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
        traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mKernelParams.tma_desc_q);

    // K/V: STEP_KV
    box_size[3] = kv_step;
    qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
        swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
        traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mKernelParams.tma_desc_kv);

    // O: 16
    // Note: sliding window causal kernel currently has reg spill when TMA store is enabled
    box_size[3] = 16;
    if ((get_size_in_bytes(mFixedParams.dataTypeOut) == 1)
        && mLaunchParams.attention_mask_type != ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL)
    {
        qkv_tma_descriptor.set_tma_desctriptor(o_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
            swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_o, tensor_stride_o, traversal_stride_o,
            box_size, oob_fill, fp32_to_tf32, &mKernelParams.tma_desc_o);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Contiguous Q in the shape of [B, S, H, D].
// Contiguous KV in the shape of [B, S, 2, H, D].
// Paged KV has [B, 2, NumBlocksPerSequence] buffers,
// and each points to the contiguous buffer with shape [H, TokensPerBlock, D]
// TMA descriptors need cudaMemcpyAsync since we need multiple tma descriptors in device memory.
void FusedMHARunnerV2::setSeparateQKvTmaDescriptors(MHARunnerParams runnerParams)
{
    // split D into multiple groups in order to match the TMA swizzle mode (128B)
    uint32_t const d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
    uint32_t const d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;

    uint32_t q_step = 0, kv_step = 0;
    xmmaKernel->getStepSize(q_step, kv_step, mKernelParams, mLaunchParams);

    // Separate q, and paged kv tma descriptors.
    Multiple_tma_descriptor<4> qo_tma_descriptor;
    Multiple_tma_descriptor<4> kv_tma_descriptor;
    // Contiguous Q
    // query tensor size [B x S, 1, H, D]
    uint32_t tensor_size_qo[4];
    tensor_size_qo[3] = mLaunchParams.total_q_seqlen;
    tensor_size_qo[2] = 1;
    tensor_size_qo[1] = mKernelParams.h;
    tensor_size_qo[0] = mKernelParams.d;

    // box size for q and o
    uint32_t box_size_qo[4];
    box_size_qo[3] = q_step;
    box_size_qo[2] = 1;
    box_size_qo[1] = 1;
    box_size_qo[0] = mLaunchParams.padded_d / d_groups;

    // stride size in bytes.
    uint64_t tensor_stride_qo[3];
    tensor_stride_qo[0] = get_size_in_bytes(tensor_size_qo[0], mFixedParams.dataType);
    tensor_stride_qo[1] = tensor_size_qo[1] * tensor_stride_qo[0];
    tensor_stride_qo[2] = tensor_size_qo[2] * tensor_stride_qo[1];

    // traversal stride
    uint32_t traversal_stride[4] = {1, 1, 1, 1};

    // OOB fill zeros
    uint32_t oob_fill = 0;

    // FP32 to TF32 conversion disabled
    uint32_t fp32_to_tf32 = 0;

    // Desc Format (data type).
    cudaTmaDescFormat const desc_format
        = (get_size_in_bytes(mFixedParams.dataType) == 1) ? cudaTmaDescFormat::U8 : cudaTmaDescFormat::F16_RN;

    // gmma descriptor mode
    uint32_t const d_bytes_per_group = d_in_bytes / d_groups;
    cudaTmaDescSwizzle const swizzle_mode = (d_bytes_per_group > 64
            ? cudaTmaDescSwizzle::SWIZZLE_128B
            : (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));

    // Q ptr.
    auto const* q_ptr = static_cast<char const*>(mKernelParams.q_ptr);

    // Q: STEP_Q.
    qo_tma_descriptor.set_tma_desctriptor(q_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
        cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qo, tensor_stride_qo, traversal_stride, box_size_qo,
        oob_fill, fp32_to_tf32, &mKernelParams.tma_desc_q);

    // O ptr.
    auto const* o_ptr = static_cast<char const*>(mKernelParams.o_ptr);
    // Note (added by Yuxin): TMA descriptor for o here might be problematic if d and dv are different.

    // O: 16. Reuse
    box_size_qo[3] = 16;
    if ((get_size_in_bytes(mFixedParams.dataTypeOut) == 1)
        && mLaunchParams.attention_mask_type != ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL)
    {
        qo_tma_descriptor.set_tma_desctriptor(o_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
            swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qo, tensor_stride_qo, traversal_stride,
            box_size_qo, oob_fill, fp32_to_tf32, &mKernelParams.tma_desc_o);
    }

    // Contiguous KV layout [B, S, 2, H, D].
    if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_CONTIGUOUS_KV)
    {
        // Per batch tensor size.
        uint32_t tensor_size_kv[4];
        // Total kv sequence length.
        tensor_size_kv[3] = mLaunchParams.total_kv_seqlen;
        tensor_size_kv[2] = 2;
        tensor_size_kv[1] = mKernelParams.h_kv;
        tensor_size_kv[0] = mKernelParams.d;

        // Box size for k and v.
        uint32_t box_size_kv[4];
        box_size_kv[3] = kv_step;
        box_size_kv[2] = 1;
        box_size_kv[1] = 1;
        box_size_kv[0] = mLaunchParams.padded_d / d_groups;

        // Stride size in bytes.
        uint64_t tensor_stride_kv[3];
        tensor_stride_kv[0] = get_size_in_bytes(tensor_size_kv[0], mFixedParams.dataType);
        tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0];
        tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1];

        // Set the contiguous kv tma descriptor.
        kv_tma_descriptor.set_tma_desctriptor(runnerParams.kvPtr, desc_format,
            cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
            tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32,
            &mKernelParams.tma_desc_kv);
    }
    else if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
    {
        // Paged KV
        // Per batch tensor size.
        uint32_t tokens_per_block = uint32_t(mKernelParams.paged_kv_cache.mTokensPerBlock);
        uint32_t tensor_size_kv[4];
        // Maximum number of blocks in this device.
        tensor_size_kv[3] = mLaunchParams.total_device_memory / mKernelParams.paged_kv_cache.mBytesPerBlock;
        tensor_size_kv[2] = mKernelParams.h_kv;
        tensor_size_kv[1] = tokens_per_block;
        tensor_size_kv[0] = mKernelParams.d;

        // Box size for k and v.
        uint32_t box_size_kv[4];
        box_size_kv[3] = 1;
        box_size_kv[2] = 1;
        box_size_kv[1] = std::min(tokens_per_block, kv_step);
        box_size_kv[0] = mLaunchParams.padded_d / d_groups;

        TLLM_CHECK_WITH_INFO(
            tokens_per_block % 2 == 0, "FMHA with paged kv cache needs tokens_per_block to be power of 2 !");
        mKernelParams.blocks_per_tma_load = std::max(1, int32_t(kv_step / tokens_per_block));
        mKernelParams.blocks_per_tma_load_log2 = log2(mKernelParams.blocks_per_tma_load);

        // Stride size in bytes.
        uint64_t tensor_stride_kv[3];
        tensor_stride_kv[0] = get_size_in_bytes(tensor_size_kv[0], mFixedParams.dataType);
        tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0];
        tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1];

        // Set the paged_kv tma descriptor.
        kv_tma_descriptor.set_tma_desctriptor(runnerParams.pagedKvCache.mPrimaryPoolPtr, desc_format,
            cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
            tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32,
            &mKernelParams.tma_desc_kv);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void FusedMHARunnerV2::run(MHARunnerParams runnerParams)
{
    // Note that we must set the launch params first.
    // Set the launch params.
    setupLaunchParams(runnerParams);
    // Set the kernel params.
    setupKernelParams(runnerParams);
    // Need to set tma descriptors additionally.
    if (mSM == kSM_90 && mLaunchParams.use_tma)
    {
        switch (mFixedParams.attentionInputLayout)
        {
        case AttentionInputLayout::PACKED_QKV: setPackedQkvTmaDescriptors(runnerParams); break;
        case AttentionInputLayout::Q_CONTIGUOUS_KV:
        case AttentionInputLayout::Q_PAGED_KV: setSeparateQKvTmaDescriptors(runnerParams); break;
        default: TLLM_CHECK_WITH_INFO(false, "Unsupported attention input layout.");
        }
    }
    // Check if the sliding window size is valid or not.
    if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV
        && mLaunchParams.attention_mask_type == ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL)
    {
        uint32_t q_step = 0, kv_step = 0;
        xmmaKernel->getStepSize(q_step, kv_step, mKernelParams, mLaunchParams);
        // The sliding window size needs to be multiple of kv_step, so that the paged context fmha can read the cyclic
        // kv cache correctly.
        TLLM_CHECK_WITH_INFO(mKernelParams.sliding_window_size % kv_step == 0,
            "The sliding window size doesn't work with paged context fmha kv_step_size = %d.", kv_step);
    }

    // Select the kernel and run it.
    xmmaKernel->run(mKernelParams, mLaunchParams, runnerParams.stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

bool FusedMHARunnerV2::isValidS(int s) const
{
    return xmmaKernel->isValid(s);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

int FusedMHARunnerV2::getSFromMaxSeqLen(int const max_seq_len) const
{
    int S = 1024;

    if (max_seq_len <= 64)
    {
        S = 64;
    }
    else if (max_seq_len <= 128)
    {
        S = 128;
    }
    else if (max_seq_len <= 256)
    {
        S = 256;
    }
    else if (max_seq_len <= 384)
    {
        S = 384;
    }
    else if (max_seq_len <= 512)
    {
        S = 512;
    }
    // for bert and vit, use flash attention when s >= 512
    else if (max_seq_len > 512)
    {
        S = max_seq_len;
    }

    return S;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Function to check if fmha is supported when building plugins.
// If any kernel in the map meets the requirements, then return true.
bool FusedMHARunnerV2::isFmhaSupported()
{
    bool is_supported = xmmaKernel->checkIfKernelExist(mFixedParams);
    if (!is_supported)
    {
        std::string msg = "FMHA Kernel doesn't exist for mFixedParams:\n" + mFixedParams.convertToStrOutput();
        TLLM_LOG_WARNING("%s\n", msg.c_str());
    }
    return is_supported;
}

} // namespace kernels
} // namespace tensorrt_llm