Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-03 17:52:19 +08:00)
[None][feat] Use XQA JIT impl by default and mitigate perf loss with sliding window (#10335)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
parent 71ccc07d2b
commit 683515b1bd
@@ -89,7 +89,8 @@ cpp_file_prefix_text = R"""/*

#include "tensorrt_llm/common/config.h"

TRTLLM_NAMESPACE_BEGIN
namespace tensorrt_llm
{
namespace kernels
{
// clang-format off
@@ -98,7 +99,7 @@ namespace kernels
cpp_file_suffex_text = R"""
// clang-format on
} // namespace kernels
TRTLLM_NAMESPACE_END
}
"""

cubin_meta_info_struct_prefix_text = R"""
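Note on the namespace change above: the generated wrapper now includes tensorrt_llm/common/config.h and opens/closes the namespace via TRTLLM_NAMESPACE_BEGIN / TRTLLM_NAMESPACE_END instead of spelling out namespace tensorrt_llm { ... }. The real macro definitions live in config.h and are not shown in this diff; the snippet below is only a plausible sketch of what such macros expand to, for readers unfamiliar with the pattern.

// Illustrative sketch only: a plausible shape for the namespace macros used
// by the generated code. The actual definitions come from
// tensorrt_llm/common/config.h and may differ (e.g. add an inline version
// namespace).
#ifndef TRTLLM_NAMESPACE_BEGIN
#define TRTLLM_NAMESPACE_BEGIN namespace tensorrt_llm {
#define TRTLLM_NAMESPACE_END } // namespace tensorrt_llm
#endif

TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
// Generated kernel code would be emitted here by the script above.
} // namespace kernels
TRTLLM_NAMESPACE_END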
@@ -438,8 +439,9 @@ if __name__ == "__main__":
CompileMacroOption('HEAD_ELEMS', 'd', [128]),
CompileMacroOption('BEAM_WIDTH', 'beam', [1]),
CompileMacroOption('CACHE_ELEM_ENUM', 'kvt', [0, 1, 2]),
CompileMacroOption('TOKENS_PER_PAGE', 'pagedKV',
    [0, 64, 128]), # 0 denotes contiguous kv cache.
CompileMacroOption(
    'TOKENS_PER_PAGE', 'pagedKV',
    [0, 32, 64, 128]), # 0 denotes contiguous kv cache.
CompileMacroOption('HEAD_GRP_SIZE', 'nqpkv', [0]),
CompileMacroOption('M_TILESIZE', 'm', [16, 32]),
]]
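For scale: the script presumably enumerates combinations of these option lists when emitting precompiled cubins, so extending TOKENS_PER_PAGE from {0, 64, 128} to {0, 32, 64, 128} grows that axis by one value. A rough count over just the lists shown here (other axes such as data type, head count or SM version that the script may multiply in are ignored), written as a small C++ sketch:

// Rough combination count over the option lists shown above. Illustrative
// only: the real script may prune combinations or multiply in further axes.
#include <cstdio>

int main()
{
    // HEAD_ELEMS, BEAM_WIDTH, CACHE_ELEM_ENUM, TOKENS_PER_PAGE, HEAD_GRP_SIZE, M_TILESIZE
    unsigned const before[] = {1, 1, 3, 3, 1, 2}; // TOKENS_PER_PAGE in {0, 64, 128}
    unsigned const after[] = {1, 1, 3, 4, 1, 2};  // TOKENS_PER_PAGE in {0, 32, 64, 128}
    unsigned nBefore = 1, nAfter = 1;
    for (unsigned n : before)
    {
        nBefore *= n;
    }
    for (unsigned n : after)
    {
        nAfter *= n;
    }
    std::printf("combinations per configuration: %u -> %u\n", nBefore, nAfter); // 18 -> 24
    return 0;
}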
@@ -465,13 +465,10 @@ using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
#if SPEC_DEC
#define MMAS_N_PER_MASK 2

__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
    ,
    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg
#endif
)
__device__ inline void applyMaskFromInputSlidingAndSpecDec(Warp const& warp, WarpAcc& acc, MaskType const* mask,
    uint32_t rowOffset, uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize,
    int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg)
{
    uint32_t const idxInQuad = laneId() % 4;
    uint32_t const idxQuad = laneId() / 4;
@@ -479,7 +476,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
    constexpr uint64_t fullMask = ~uint64_t{0};
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
    Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x};
    Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)};
    bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;
@@ -487,11 +483,6 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
    int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg);
    uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
    bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);
#else
    constexpr bool ctaNeedBegMask = false;
    bool const ctaNeedSpecDecMask = true;
    int32_t const tok0NbMaskOut = -2147483648;
#endif
    bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;

    if (!needMask)
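The two predicates above are what let a warp tile skip masking entirely: ctaNeedBegMask is set only if the tile starts before the furthest column that the sliding-window begin can mask out, and ctaNeedSpecDecMask only once the sequence iteration reaches the tiles overlapping the query tokens. The standalone sketch below replays that decision on the host with made-up tile sizes and sequence lengths (warpTile.x, ctaTile.x, nbValidRows, MMAS_N_PER_MASK and seqIter are stand-ins for values the kernel gets from its compile-time config and outer loop):

// Host-side replay of the needMask decision above. All sizes here are
// hypothetical stand-ins for the kernel's compile-time configuration.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct Range
{
    uint32_t beg, end;
};

int main()
{
    uint32_t const warpTileX = 64, ctaTileX = 128, nbValidRows = 16, mmasNPerMask = 2;
    uint32_t const cacheSeqLen = 500, actualQSeqLen = 4;
    int32_t const tok0WinBeg = 120;                       // window begin of the first query token
    uint32_t const warpTileTokenBeg = 128;                // first cache token covered by this warp tile
    uint32_t const seqIter = warpTileTokenBeg / ctaTileX; // approximation of the kernel's loop index

    Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTileX};
    Range const maxMaskOutRange = {0, uint32_t(std::max(0, tok0WinBeg)) + (nbValidRows / mmasNPerMask - 1)};
    bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end;

    uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTileX;
    bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask);

    // Tiles fully inside the window and away from the query rows skip the mask.
    bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask;
    std::printf("begMask=%d specDecMask=%d needMask=%d\n", ctaNeedBegMask, ctaNeedSpecDecMask, needMask);
    return 0;
}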
@@ -559,6 +550,61 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
}
#endif

__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset,
    uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize)
{
    uint32_t const idxInQuad = laneId() % 4;
    uint32_t const idxQuad = laneId() / 4;
    // Packed mask is aligned with 32 bits (2 uint16_t).
    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u;
    uint16_t const* uint16Mask = reinterpret_cast<uint16_t const*>(mask);
#pragma unroll
    for (uint32_t m = 0; m < acc.rows; m++)
    {
#pragma unroll
        for (uint32_t i = 0; i < InstAcc::rows; i++)
        {
            uint32_t const tokenRow = min((rowOffset + instM * m + idxQuad + i * 8) / headGrpSize, actualQSeqLen - 1);
#pragma unroll
            for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++)
            {
                uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad;
                uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1;
                uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols
                    ? 0u
                    : min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
                uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols
                    ? 0u
                    : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1);
                uint32_t packedMask = 0u;
                uint32_t const maskPosStart = (maskPos0 / 16) * 16;
                reinterpret_cast<uint16_t*>(&packedMask)[0]
                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)];
                reinterpret_cast<uint16_t*>(&packedMask)[1]
                    = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)];
#pragma unroll
                for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++)
                {
#pragma unroll
                    for (uint32_t j = 0; j < InstAcc::cols; j++)
                    {
                        uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj);
                        uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j;
                        // bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col +
                        // qSeqLen - nbValidCols)];
                        bool const maskFlag = col + actualQSeqLen < nbValidCols
                            ? true
                            : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
                        acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
                    }
                }
            }
        }
    }
}

#endif

__device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc)
{
    QuadRegRowMax rowMax = rowMaxHint;
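In the restored non-sliding applyMaskFromInput above, the spec-dec mask is consumed as a packed bitmask: each query row owns divUp(qSeqLen, 32) * 2 uint16 words, a thread loads the two words that can cover its columns, and individual bits are then tested relative to maskPosStart. The self-contained sketch below shows the same row-major, LSB-first packing and a single-bit lookup; the layout details (word size and bit order) are inferred from the kernel code above, so treat it as an illustration rather than a spec of the mask format.

// Standalone sketch of the packed spec-dec mask layout used above: one bit
// per (query row, column), each row padded to a multiple of 32 bits and
// stored as uint16 words, LSB-first (inferred from the kernel code).
#include <cstdint>
#include <cstdio>
#include <vector>

static uint32_t divUp(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    uint32_t const qSeqLen = 40;                                   // hypothetical number of spec-dec query tokens
    uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u; // uint16 words per row
    std::vector<uint16_t> mask(qSeqLen * nbPackedMasksPerRow, 0);

    // Mark (row 5, col 37) as attendable.
    uint32_t const row = 5, col = 37;
    mask[row * nbPackedMasksPerRow + col / 16] |= uint16_t(1u << (col % 16));

    // Lookup mirroring the kernel: load the word holding the column, test the bit.
    bool const visible = (mask[row * nbPackedMasksPerRow + col / 16] >> (col % 16)) & 1u;
    std::printf("mask(%u, %u) = %d\n", row, col, visible);
    return 0;
}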
@@ -1655,7 +1701,7 @@ CUBIN_EXPORT __global__
    uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset;
    int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
    uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);

    bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);
#elif SLIDING_WINDOW
    bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
    assert(!SPEC_DEC || !rtIsReallySliding);
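The spec-dec plus sliding-window bookkeeping above boils down to a little integer arithmetic: tok0SeqLen is roughly the sequence position of the first query token, tok0WinBeg is where that token's attention window begins (negative while the window is not yet full), nbTotalSkipTokens is how many leading cache tokens every query row can ignore, and rtIsReallySliding records whether the window actually truncates anything. A worked example with made-up lengths (idxHeadTokenInGrp taken as 0):

// Worked example of the sliding-window bookkeeping above, using made-up
// lengths; idxHeadTokenInGrp is taken as 0 for simplicity.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t const cacheSeqLen = 1000, actualQSeqLen = 4, slidingWinSize = 256;
    uint32_t const idxHeadTokenInGrp = 0;

    uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // 997
    int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);        // 741
    uint32_t const nbTotalSkipTokens = std::max(0, tok0WinBeg); // cache tokens left of the window
    bool const rtIsReallySliding = (cacheSeqLen + actualQSeqLen > slidingWinSize);

    std::printf("tok0WinBeg=%d skip=%u sliding=%d\n", tok0WinBeg, nbTotalSkipTokens, rtIsReallySliding);
    return 0;
}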
@@ -1673,7 +1719,8 @@

    uint32_t const nbSeqIters = useKVCache ? divUp(cacheSeqLen, ctaTile.x) : 0;
#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
    uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles;
    uint32_t const nbSeqItersWithoutMask
        = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#elif SPEC_DEC
    uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x;
#endif
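This hunk is where the sliding-window perf loss mentioned in the commit title appears to be mitigated: previously nbSeqItersWithoutMask was always nbSkipLeadingTiles (presumably zero when the window does not actually slide), so every sequence iteration went through the masking path; now the non-sliding case falls back to the spec-dec bound (cacheSeqLen - actualQSeqLen) / ctaTile.x, letting tiles that lie entirely before the query rows skip the mask. A small numeric illustration (ctaTile.x, the lengths and nbSkipLeadingTiles are made up):

// Numeric illustration of nbSeqItersWithoutMask before and after this change.
// ctaTileX, the sequence lengths and nbSkipLeadingTiles are made-up values.
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t const ctaTileX = 128, cacheSeqLen = 900, actualQSeqLen = 4;
    bool const rtIsReallySliding = false;  // window larger than the sequence
    uint32_t const nbSkipLeadingTiles = 0; // nothing skipped when not really sliding

    uint32_t const oldValue = nbSkipLeadingTiles;
    uint32_t const newValue = rtIsReallySliding ? nbSkipLeadingTiles : (cacheSeqLen - actualQSeqLen) / ctaTileX;

    // Iterations below this bound do not call the mask function at all.
    std::printf("mask-free leading iterations: old=%u new=%u\n", oldValue, newValue); // old=0 new=7
    return 0;
}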
@@ -1960,12 +2007,18 @@
    if (seqIter >= nbSeqItersWithoutMask)
    {
        uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U);
        applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize
#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
            ,
            tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg
        if (rtIsReallySliding)
        {
            applyMaskFromInputSlidingAndSpecDec(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen,
                actualQSeqLen, headGrpSize, tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg);
        }
        else
#endif
            );
        {
            applyMaskFromInput(
                warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize);
        }
    }
#else
    bool const isFirstIter = (seqIter == nbSkipLeadingTiles);

@@ -84,26 +84,8 @@ DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParam
    }
    else
    {
        if (xqaParams.multi_query_tokens)
        {
            // Some multi_query kernels are not ported to JIT yet.
            auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
            // Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
            // now.
            bool const supportedByHopperXqa
                = (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
            bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
                && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
            bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));

            return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
                                                                                         : mPrecompiledImpl.get();
        }
        else
        {
            // regular decoding kernels uses JIT by default
            return mJITImpl.get();
        }
        // uses JIT by default
        return mJITImpl.get();
    }
}
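The net effect of this last hunk is the "JIT by default" part of the title: getImplFromXQAParams no longer routes multi-query (spec-dec) shapes that miss the old Hopper/SM120-MLA/Ampere whitelist to the precompiled cubins; this branch now always returns the JIT implementation. A stripped-down sketch of the before/after decision, using hypothetical stand-in types for the runner and its two impls:

// Stripped-down sketch of the impl selection before and after this commit.
// Runner, Impl and the parameter flags are hypothetical stand-ins.
#include <cstdio>

struct Impl
{
    char const* name;
};

struct Runner
{
    Impl jitImpl{"JIT"};
    Impl precompiledImpl{"precompiled"};

    // Old behavior (roughly): multi-query shapes had to match a whitelist to get JIT.
    Impl* oldSelect(bool multiQueryTokens, bool whitelisted)
    {
        if (multiQueryTokens)
        {
            return whitelisted ? &jitImpl : &precompiledImpl;
        }
        return &jitImpl;
    }

    // New behavior: this branch uses JIT by default.
    Impl* newSelect()
    {
        return &jitImpl;
    }
};

int main()
{
    Runner r;
    std::printf("multi-query, not whitelisted: old=%s new=%s\n",
        r.oldSelect(/*multiQueryTokens=*/true, /*whitelisted=*/false)->name, r.newSelect()->name);
    return 0;
}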