From 683515b1bdcf4c54f50f0599982c6347cd9f0617 Mon Sep 17 00:00:00 2001
From: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Date: Thu, 15 Jan 2026 15:47:00 +0800
Subject: [PATCH] [None][feat] Use XQA JIT impl by default and mitigate perf
 loss with sliding window (#10335)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
---
 cpp/kernels/xqa/gen_cubins.py                 | 10 ++-
 cpp/kernels/xqa/mha.cu                        | 89 +++++++++++++++----
 .../decoderXQARunner.cpp                      | 22 +----
 3 files changed, 79 insertions(+), 42 deletions(-)

diff --git a/cpp/kernels/xqa/gen_cubins.py b/cpp/kernels/xqa/gen_cubins.py
index a345861fb7..9230cf94cf 100755
--- a/cpp/kernels/xqa/gen_cubins.py
+++ b/cpp/kernels/xqa/gen_cubins.py
@@ -89,7 +89,8 @@ cpp_file_prefix_text = R"""/*
 
 #include "tensorrt_llm/common/config.h"
 
-TRTLLM_NAMESPACE_BEGIN
+namespace tensorrt_llm
+{
 namespace kernels
 {
 // clang-format off
@@ -98,7 +99,7 @@ namespace kernels
 cpp_file_suffex_text = R"""
 // clang-format on
 } // namespace kernels
-TRTLLM_NAMESPACE_END
+}
 """
 
 cubin_meta_info_struct_prefix_text = R"""
@@ -438,8 +439,9 @@ if __name__ == "__main__":
             CompileMacroOption('HEAD_ELEMS', 'd', [128]),
             CompileMacroOption('BEAM_WIDTH', 'beam', [1]),
             CompileMacroOption('CACHE_ELEM_ENUM', 'kvt', [0, 1, 2]),
-            CompileMacroOption('TOKENS_PER_PAGE', 'pagedKV',
-                               [0, 64, 128]),  # 0 denotes contiguous kv cache.
+            CompileMacroOption(
+                'TOKENS_PER_PAGE', 'pagedKV',
+                [0, 32, 64, 128]),  # 0 denotes contiguous kv cache.
             CompileMacroOption('HEAD_GRP_SIZE', 'nqpkv', [0]),
             CompileMacroOption('M_TILESIZE', 'm', [16, 32]),
         ]]
diff --git a/cpp/kernels/xqa/mha.cu b/cpp/kernels/xqa/mha.cu
index 5881a93fc4..e215cb172d 100644
--- a/cpp/kernels/xqa/mha.cu
+++ b/cpp/kernels/xqa/mha.cu
@@ -465,13 +465,10 @@ using WarpAcc = WarpAccT;
 
 #if SPEC_DEC
 #define MMAS_N_PER_MASK 2
-__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
-    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
 #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-    ,
-    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
-#endif
-)
+__device__ inline void applyMaskFromInputSlidingAndSpecDec(Warp const& warp, WarpAcc& acc, MaskType const* mask,
+    uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
+    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
 {
     uint32_t const idxInQuad = laneId() % 4;
     uint32_t const idxQuad = laneId() / 4;
@@ -479,7 +476,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
     uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
     uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
     constexpr uint64_t fullMask = ~uint64_t{0};
-#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
     Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
     Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
     bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
@@ -487,11 +483,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
     int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
     uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
     bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
-#else
-    constexpr bool ctaNeedBegMask = false;
-    bool const ctaNeedSpecDecMask = true;
-    int32_t const tok0NbMaskOut = -2147483648;
-#endif
 
     bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
     if (!needMask)
@@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
 }
 #endif
 
+__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
+    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
+{
+    uint32_t const idxInQuad = laneId() % 4;
+    uint32_t const idxQuad = laneId() / 4;
+    // Packed mask is aligned with 32 bits (2 uint16_t).
+    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
+    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
+#pragma unroll
+    for (uint32_t m = 0; m < acc.rows; m++)
+    {
+#pragma unroll
+        for (uint32_t i = 0; i < InstAcc::rows; i++)
+        {
+            uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
+#pragma unroll
+            for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
+            {
+                uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad;
+                uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;
+                uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols
+                    ? 0u
+                    : min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
+                    ? 0u
+                    : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
+                uint32_t packedMask = 0u;
+                uint32_t const maskPosStart = (maskPos0 / 16) * 16;
+                reinterpret_cast<uint16_t*>(&packedMask)[0]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
+                reinterpret_cast<uint16_t*>(&packedMask)[1]
+                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
+#pragma unroll
+                for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
+                {
+#pragma unroll
+                    for (uint32_t j = 0; j < InstAcc::cols; j++)
+                    {
+                        uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj);
+                        uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j;
+                        // bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col +
+                        // qSeqLen - nbValidCols)];
+                        bool const maskFlag = col + actualQSeqLen < nbValidCols
+                            ? true
+                            : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
+                        acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
+                    }
+                }
+            }
+        }
+    }
+}
+
+#endif
+
 __device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc)
 {
     QuadRegRowMax rowMax = rowMaxHint;
@@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
     uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
     int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
     uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
-
+    bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
 #elif SLIDING_WINDOW
     bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
     assert(!SPEC_DEC || !rtIsReallySliding);
@@ -1673,7 +1719,8 @@ CUBIN_EXPORT __global__
 
     uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
 #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
-    uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
+    uint32_t const nbSeqItersWithoutMask
+        = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #elif SPEC_DEC
     uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
 #endif
@@ -1960,12 +2007,18 @@ CUBIN_EXPORT __global__
         if (seqIter >= nbSeqItersWithoutMask)
         {
             uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
-            applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
 #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
-                ,
-                tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
+            if (rtIsReallySliding)
+            {
+                applyMaskFromInputSlidingAndSpecDec(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
+                    actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
+            }
+            else
 #endif
-                );
+            {
+                applyMaskFromInput(
+                    warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
+            }
         }
 #else
         bool const isFirstIter = (seqIter == nbSkipLeadingTiles);
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
index 165ffc2848..d9dc79662f 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp
@@ -84,26 +84,8 @@ DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParam
     }
     else
     {
-        if (xqaParams.multi_query_tokens)
-        {
-            // Some multi_query kernels are not ported to JIT yet.
-            auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
-            // Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
-            // now.
-            bool const supportedByHopperXqa
-                = (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
-            bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
-                && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
-            bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));
-
-            return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
-                                                                                         : mPrecompiledImpl.get();
-        }
-        else
-        {
-            // regular decoding kernels uses JIT by default
-            return mJITImpl.get();
-        }
+        // uses JIT by default
+        return mJITImpl.get();
     }
 }
 
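
Reviewer note (not part of the patch): the perf mitigation in mha.cu comes down to two scalars. rtIsReallySliding detects whether the sliding window can actually clip tokens at this step, and nbSeqItersWithoutMask falls back to the plain spec-dec bound when it cannot, so leading tiles bypass the masking path entirely instead of running it on every tile. The host-side C++ sketch below replays that arithmetic; the sample lengths, ctaTileX, and the floor definition of nbSkipLeadingTiles are illustrative assumptions, not values taken from the patch.

// Host-side sketch (illustrative, not part of the patch) of the tile-skip
// arithmetic in mha.cu under SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE.
// ctaTileX stands in for ctaTile.x; nbSkipLeadingTiles is assumed to be the
// floor of nbTotalSkipTokens / ctaTile.x, which the patch context does not show.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t divUp(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    uint32_t const ctaTileX = 128;      // tile size along the kv sequence
    uint32_t const cacheSeqLen = 4096;  // tokens already in the kv cache
    uint32_t const actualQSeqLen = 4;   // spec-dec query tokens this step
    uint32_t const slidingWinSize = 8192;

    // Patched check: the window only "really" slides when it can clip tokens.
    bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);

    // Leading tokens that fall entirely outside the window, for the first
    // query token (i.e. idxHeadTokenInGrp == 0).
    uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1;
    int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
    uint32_t const nbTotalSkipTokens = uint32_t(std::max(0, tok0WinBeg));
    uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / ctaTileX;

    // Patched expression: without real sliding, fall back to the plain
    // spec-dec bound so earlier tiles skip the masking path.
    uint32_t const nbSeqItersWithoutMask
        = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTileX;

    std::printf("reallySliding=%d, %u of %u tiles run unmasked\n", int(rtIsReallySliding),
        nbSeqItersWithoutMask, divUp(cacheSeqLen, ctaTileX));
    return 0;
}

With these sample values the window is larger than the current sequence, so 31 of 32 tiles run unmasked; before the patch, nbSeqItersWithoutMask was always nbSkipLeadingTiles (0 here), which forced the masked path on every tile and caused the perf loss the subject line refers to.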
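Reviewer note (not part of the patch): the new non-sliding applyMaskFromInput reads the spec-dec mask as pairs of uint16_t words, nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u per query row, and tests one bit per accumulator column. The sketch below is a minimal single-bit version of that lookup in host C++ to make the addressing concrete; maskBit and the causal-mask main() are hypothetical illustrations, and the kernel itself fetches two adjacent words (maskPos0 and maskPos1) because one thread's column range can straddle a 16-bit boundary.

// Host-side sketch (illustrative, not part of the patch) of the packed-mask
// addressing used by the new applyMaskFromInput in mha.cu. Assumes a
// little-endian layout, matching the kernel's reinterpret_cast of the mask.
#include <cstdint>
#include <cstdio>
#include <vector>

static uint32_t divUp(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

// Hypothetical helper: fetch one spec-dec mask bit for (tokenRow, maskPos).
// Each row holds qSeqLen bits padded to 32-bit words, read as uint16_t pairs,
// so nbPackedMasksPerRow counts uint16_t words exactly as in the kernel.
static bool maskBit(uint32_t const* packedMask, uint32_t qSeqLen, uint32_t tokenRow, uint32_t maskPos)
{
    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(packedMask);
    uint16_t const word = uint16Mask[tokenRow * nbPackedMasksPerRow + maskPos / 16];
    return (word >> (maskPos % 16)) & 1u;
}

int main()
{
    // Build a 4-token causal mask: row r attends to columns 0..r.
    uint32_t const qSeqLen = 4;
    std::vector<uint32_t> mask(qSeqLen * divUp(qSeqLen, 32u), 0u);
    for (uint32_t row = 0; row < qSeqLen; ++row)
    {
        for (uint32_t col = 0; col <= row; ++col)
        {
            mask[row] |= 1u << col;
        }
    }
    // Prints a lower-triangular pattern of ones.
    for (uint32_t row = 0; row < qSeqLen; ++row)
    {
        for (uint32_t col = 0; col < qSeqLen; ++col)
        {
            std::printf("%d%c", int(maskBit(mask.data(), qSeqLen, row, col)), col + 1 == qSeqLen ? '\n' : ' ');
        }
    }
    return 0;
}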