[None][feat] Use XQA JIT impl by default and mitigate perf loss with sliding window (#10335)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
This commit is contained in:
Pengbo Wang 2026-01-15 15:47:00 +08:00 committed by GitHub
parent 71ccc07d2b
commit 683515b1bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 42 deletions

View File

@ -89,7 +89,8 @@ cpp_file_prefix_text = R"""/*
#include "tensorrt_llm/common/config.h"
TRTLLM_NAMESPACE_BEGIN
namespace tensorrt_llm
{
namespace kernels
{
// clang-format off
@ -98,7 +99,7 @@ namespace kernels
cpp_file_suffex_text = R"""
// clang-format on
} // namespace kernels
TRTLLM_NAMESPACE_END
}
"""
cubin_meta_info_struct_prefix_text = R"""
@ -438,8 +439,9 @@ if __name__ == "__main__":
CompileMacroOption('HEAD_ELEMS', 'd', [128]),
CompileMacroOption('BEAM_WIDTH', 'beam', [1]),
CompileMacroOption('CACHE_ELEM_ENUM', 'kvt', [0, 1, 2]),
CompileMacroOption('TOKENS_PER_PAGE', 'pagedKV',
[0, 64, 128]), # 0 denotes contiguous kv cache.
CompileMacroOption(
'TOKENS_PER_PAGE', 'pagedKV',
[0, 32, 64, 128]), # 0 denotes contiguous kv cache.
CompileMacroOption('HEAD_GRP_SIZE', 'nqpkv', [0]),
CompileMacroOption('M_TILESIZE', 'm', [16, 32]),
]]

View File

@ -465,13 +465,10 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
#if SPEC_DEC
#define MMAS_N_PER_MASK 2
__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
,
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
#endif
)
__device__ inline void applyMaskFromInputSlidingAndSpecDec(Warp const& warp, WarpAcc& acc, MaskType const* mask,
uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
{
uint32_t const idxInQuad = laneId() % 4;
uint32_t const idxQuad = laneId() / 4;
@ -479,7 +476,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
constexpr uint64_t fullMask = ~uint64_t{0};
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
@ -487,11 +483,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
#else
constexpr bool ctaNeedBegMask = false;
bool const ctaNeedSpecDecMask = true;
int32_t const tok0NbMaskOut = -2147483648;
#endif
bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
if (!needMask)
@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
}
#endif
__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
{
uint32_t const idxInQuad = laneId() % 4;
uint32_t const idxQuad = laneId() / 4;
// Packed mask is aligned with 32 bits (2 uint16_t).
uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
#pragma unroll
for (uint32_t m = 0; m < acc.rows; m++)
{
#pragma unroll
for (uint32_t i = 0; i < InstAcc::rows; i++)
{
uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
#pragma unroll
for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
{
uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad;
uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;
uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols
? 0u
: min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
? 0u
: min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
uint32_t packedMask = 0u;
uint32_t const maskPosStart = (maskPos0 / 16) * 16;
reinterpret_cast<uint16_t*>(&packedMask)[0]
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
reinterpret_cast<uint16_t*>(&packedMask)[1]
= uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
#pragma unroll
for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
{
#pragma unroll
for (uint32_t j = 0; j < InstAcc::cols; j++)
{
uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj);
uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j;
// bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col +
// qSeqLen - nbValidCols)];
bool const maskFlag = col + actualQSeqLen < nbValidCols
? true
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
}
}
}
}
}
}
#endif
__device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc)
{
QuadRegRowMax rowMax = rowMaxHint;
@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
#elif SLIDING_WINDOW
bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
assert(!SPEC_DEC || !rtIsReallySliding);
@ -1673,7 +1719,8 @@ CUBIN_EXPORT __global__
uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
uint32_t const nbSeqItersWithoutMask
= rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#elif SPEC_DEC
uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#endif
@ -1960,12 +2007,18 @@ CUBIN_EXPORT __global__
if (seqIter >= nbSeqItersWithoutMask)
{
uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
,
tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
if (rtIsReallySliding)
{
applyMaskFromInputSlidingAndSpecDec(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
}
else
#endif
);
{
applyMaskFromInput(
warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
}
}
#else
bool const isFirstIter = (seqIter == nbSkipLeadingTiles);

View File

@ -84,26 +84,8 @@ DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParam
}
else
{
if (xqaParams.multi_query_tokens)
{
// Some multi_query kernels are not ported to JIT yet.
auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
// Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
// now.
bool const supportedByHopperXqa
= (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
&& xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));
return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
: mPrecompiledImpl.get();
}
else
{
// regular decoding kernels uses JIT by default
return mJITImpl.get();
}
// uses JIT by default
return mJITImpl.get();
}
}