diff --git a/cpp/kernels/xqa/defines.h b/cpp/kernels/xqa/defines.h index efc5c4ec65..b369b43045 100644 --- a/cpp/kernels/xqa/defines.h +++ b/cpp/kernels/xqa/defines.h @@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena #define SLIDING_WINDOW 0 #endif +#ifndef SKIP_SOFTMAX_ATTN +#define SKIP_SOFTMAX_ATTN 0 +#endif + +#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS +#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0 +#endif + +#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE +#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1 +#endif + // 0 - no PDL // 1 - naive PDL // 2 - aggressive PDL (implemented only in mha_sm90.cu for now) diff --git a/cpp/kernels/xqa/gmma.cuh b/cpp/kernels/xqa/gmma.cuh index f5f29c73e7..7f5a843865 100644 --- a/cpp/kernels/xqa/gmma.cuh +++ b/cpp/kernels/xqa/gmma.cuh @@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, asm volatile("trap;\n"); return 0; }(); + assert(__cvta_generic_to_shared(data) % baseAlign == 0); uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7)); return MatDesc{ /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)), diff --git a/cpp/kernels/xqa/mha.cu b/cpp/kernels/xqa/mha.cu index 330364ee88..5881a93fc4 100644 --- a/cpp/kernels/xqa/mha.cu +++ b/cpp/kernels/xqa/mha.cu @@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl; #endif #ifndef GENERATE_CUBIN +uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen) +{ + if (!allowMultiBlockMode) + { + return 1; + } + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) + { + int32_t const val = std::stoi(env); + if (val > 0) + { + return val; + } + } + return std::min( + std::max(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x)); +} + void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SLIDING_WINDOW uint32_t slidingWinSize, @@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, // int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only + uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only +#endif #endif uint32_t* semaphores, void* scratch, cudaStream_t stream) { @@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, uint32_t const nbQHeads = nbKHeads * headGrpSize; // const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1; - uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t - { - if (!allowMultiBlockMode) - { - return 1; - } - auto const env = std::getenv("XQA_NB_SUB_SEQ"); - if (env != nullptr) - { - int32_t const val = std::stoi(env); - if (val > 0) - { - return val; - } - } - return std::min( - std::max(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x)); - }(); + uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen); // gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq #if SPEC_DEC const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock); diff --git a/cpp/kernels/xqa/mha.h b/cpp/kernels/xqa/mha.h index a40a5e6c0d..2c7ef50a83 100644 --- a/cpp/kernels/xqa/mha.h +++ b/cpp/kernels/xqa/mha.h @@ -90,6 +90,9 @@ struct BeamSearchParams // match trt-llm API. }; +uint32_t computeNbSubSeqPerSeqMHA( + cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen); + void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads, #if SLIDING_WINDOW uint32_t slidingWinSize, @@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads, // int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + float const skipSoftmaxThresholdScaleFactor, +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount, +#endif #endif uint32_t* semaphores, void* scratch, cudaStream_t stream); +uint32_t computeNbSubSeqPerSeqHopperF8MHA( + cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen); + void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SLIDING_WINDOW uint32_t slidingWinSize, @@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, // int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + float const skipSoftmaxThresholdScaleFactor, +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount, +#endif #endif uint32_t* semaphores, void* scratch, cudaStream_t stream); diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu index 457b106891..dc21872ac6 100644 --- a/cpp/kernels/xqa/mha_sm90.cu +++ b/cpp/kernels/xqa/mha_sm90.cu @@ -49,6 +49,10 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to #define SWAP_AB (!SPEC_DEC) #endif +#if SKIP_SOFTMAX_ATTN +static_assert(SWAP_AB && USE_PAGED_KV_CACHE && !SPEC_DEC && BEAM_WIDTH == 1, "SKIP_SOFTMAX_ATTN is not supported."); +#endif + #define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT) inline constexpr bool swapAB = SWAP_AB; @@ -138,26 +142,38 @@ using PaddedOutHead = PaddedInputHead; struct alignas(128) SharedMem { + using QBuffer = Vec, nbQParts>; using KBuffer = Array2D; - static constexpr uint32_t nbKBuf = 2; - KBuffer k[nbKBuf]; // as is loaded from global mem. using XBuffer = Vec, nbXParts>; - static constexpr uint32_t nbXBuf - = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); using VBuffer = Vec, cacheHeadNbParts>; #if !SWAP_AB using VTBuffer = Array2D; #endif - static constexpr uint32_t nbVBuf = 2; #if CACHE_ELEM_ENUM == 0 using OutSwizzleBuf = Array2D; #elif CACHE_ELEM_ENUM == 2 using OutSwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; #endif + +#if SKIP_SOFTMAX_ATTN + static constexpr uint32_t nbKBuf = 2; + static constexpr uint32_t nbVBuf = 3; // @fixme: skip_softmax_attn: for skip softmax attn, an extra VBuffer is used + static constexpr uint32_t nbXBuf + = 3 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); +#else + static constexpr uint32_t nbKBuf = 2; + static constexpr uint32_t nbVBuf = 2; + static constexpr uint32_t nbXBuf + = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); +#endif static_assert(nbXBuf == nbVBuf); + // note: buffers used for GMMA may have additional alignment requirements + KBuffer k[nbKBuf]; // as is loaded from global mem. + QBuffer q; // For gmma math. Conversion done if needed. + union ReusedXVOutSwizzleBuf { struct XV @@ -196,9 +212,6 @@ struct alignas(128) SharedMem return reusedXVOutSwizzleBuf[i].outSwizzle; } - using QBuffer = Vec, nbQParts>; - QBuffer q; // For gmma math. Conversion done if needed. - // @fixme: move these into reusedXVOutSwizzleBuf #if SWAP_AB ShmQWiseVec xColMax[nbXBuf]; @@ -220,6 +233,11 @@ struct alignas(128) SharedMem Vec pages[2]; // one for K and one for V #endif +#if SKIP_SOFTMAX_ATTN + uint32_t skipSoftmaxVotesGemm0ToV[nbXBuf]; // guarded by skipSoftmaxXBar + uint32_t skipSoftmaxVotesGemm0ToGemm1[nbXBuf]; // guarded by xBar +#endif + // mem barriers CtaBarrierPair qBar; @@ -229,6 +247,9 @@ struct alignas(128) SharedMem CtaBarrierPair vtBar[nbVBuf]; #endif CtaBarrierPair xBar[nbXBuf]; +#if SKIP_SOFTMAX_ATTN + CtaBarrierPair skipSoftmaxXBar[nbXBuf]; // for V to wait for X to be ready +#endif // used internally in the gemm0 warp group // @fixme: use separate arrive and wait for all usage @@ -425,8 +446,13 @@ __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, #endif #if SWAP_AB +#if SKIP_SOFTMAX_ATTN +__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src, + float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip); +#else __device__ RegColWiseVec computeWarpGrpColMax_sync( CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src); +#endif __device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd); __device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax); __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); @@ -675,6 +701,12 @@ CUBIN_EXPORT __global__ #endif #if SPEC_DEC SpecDecParams const specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + float const skipSoftmaxThresholdScaleFactor, +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount, +#endif #endif uint32_t* __restrict__ const semaphores = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] @@ -753,6 +785,10 @@ CUBIN_EXPORT __global__ uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); assert(isMultiBlockMode == (nbSubSeq > 1)); +#if SKIP_SOFTMAX_ATTN + bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor); + float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen; +#endif if (idxSubSeq >= nbSubSeq) { return; @@ -776,21 +812,34 @@ CUBIN_EXPORT __global__ assert(dynamicSmemSize() >= sizeof(SharedMem)); SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); - constexpr uint32_t nbBuffers = 2; - static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf); - if (wid < nbBuffers) + constexpr uint32_t maxNbBuffers = (SharedMem::nbXBuf > SharedMem::nbVBuf) ? SharedMem::nbXBuf : SharedMem::nbVBuf; + static_assert( + maxNbBuffers >= SharedMem::nbKBuf && maxNbBuffers >= SharedMem::nbVBuf && maxNbBuffers >= SharedMem::nbXBuf); + if (wid < maxNbBuffers) { if (warpElectSync()) { - smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); - smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); -#if !SWAP_AB - smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); + if (wid < SharedMem::nbKBuf) + { + smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); + } + if (wid < SharedMem::nbXBuf) + { +#if SKIP_SOFTMAX_ATTN + smem.skipSoftmaxXBar[wid].initialize(gemm0NbThrds + warp_size, gemm0NbThrds + warp_size); + smem.vBar[wid].initialize(gemm1NbThrds + warp_size, gemm1NbThrds + warp_size); +#else + smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); #endif - smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); + +#if !SWAP_AB + smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); +#endif + smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); + } } } - else if (wid == nbBuffers) + else if (wid == maxNbBuffers) { if (warpElectSync()) { @@ -819,6 +868,10 @@ CUBIN_EXPORT __global__ SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen}; #endif +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t localSkippedBlockCount = 0; +#endif + // QK gemm constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM); using Acc = GmmaAcc; @@ -940,10 +993,39 @@ CUBIN_EXPORT __global__ } } #endif + + uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; + auto& xBar = smem.xBar[idxXBuf]; // update colMax in shared mem and get a register copy #if SWAP_AB +#if SKIP_SOFTMAX_ATTN + auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf]; + skipSoftmaxXBar.consumed.arrive_and_wait(); + + bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0; + RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc, + skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip); + bool const shouldSkipSoftmaxAttn = static_cast(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]); + unused(skipSoftmaxXBar.produced.arrive()); + warpGrpOnlineSoftmax(acc, colMax); + if (shouldSkipSoftmaxAttn) + { + xBar.consumed.arrive_and_wait(); + if (threadIdx.x == 0) + { + smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U; +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + localSkippedBlockCount++; +#endif + } + asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used + unused(xBar.produced.arrive()); + continue; + } +#else RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); warpGrpOnlineSoftmax(acc, colMax); +#endif #else RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); warpGrpOnlineSoftmax(acc, rowMax); @@ -959,8 +1041,6 @@ CUBIN_EXPORT __global__ // map 1 to fp8_max before conversion to fp8 acc = acc * kE4M3_MAX; - uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; - auto& xBar = smem.xBar[idxXBuf]; // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. #if SWAP_AB storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); @@ -989,13 +1069,25 @@ CUBIN_EXPORT __global__ storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); #endif - +#if SKIP_SOFTMAX_ATTN + if (threadIdx.x == 0) + { + smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 0; + } +#endif __syncwarp(); // the release semantics of arrive does not work for async consumers like gmma. additional fence is // needed. asm volatile("fence.proxy.async.shared::cta;\n"); unused(xBar.produced.arrive()); } +#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS + if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr) + { + atomicAdd(skippedBlockCount, localSkippedBlockCount); + atomicAdd(totalBlockCount, nbIters); + } +#endif unused(smem.qBar.consumed.arrive()); } else if (warpIdx.z == 1) @@ -1043,216 +1135,231 @@ CUBIN_EXPORT __global__ uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; auto const idxVBuf = idxIter % SharedMem::nbVBuf; auto const idxXBuf = idxVBuf; - auto& vBar = smem.vBar[idxVBuf]; - arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); - auto const& vBuf = smem.vBuf(idxVBuf); -#if !SWAP_AB - CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; - auto& vtBuf = smem.vtBuf(idxVBuf); - vtBar.consumed.arrive_and_wait(); - transposeVTile(warpRank, laneId(), vtBuf, vBuf); - vBar.consumed.arrive(); - vtBar.produced.arrive(); -#endif auto& xBar = smem.xBar[idxXBuf]; + auto& vBar = smem.vBar[idxVBuf]; + auto const& vBuf = smem.vBuf(idxVBuf); xBar.produced.arrive_and_wait(); +#if SKIP_SOFTMAX_ATTN + bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar + if (shouldSkipSoftmaxAttn) + { + vBar.produced.arrive_and_wait(); + } +#endif + +#if SKIP_SOFTMAX_ATTN + if (!shouldSkipSoftmaxAttn) // skip XVGemm +#endif + { + arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); +#if !SWAP_AB + CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; + auto& vtBuf = smem.vtBuf(idxVBuf); + vtBar.consumed.arrive_and_wait(); + transposeVTile(warpRank, laneId(), vtBuf, vBuf); + vBar.consumed.arrive(); + vtBar.produced.arrive(); +#endif #if !defined(NDEBUG) && DBG_PRINT #if SWAP_AB - if (threadIdx.x == 0) - { - printf("colMax:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xColMax[idxXBuf][i]); - } - printf("\n"); - printf("colSum:\n"); - for (int n = 0; n < 4; n++) + if (threadIdx.x == 0) { + printf("colMax:\n"); for (int i = 0; i < ctaNbQHeads; i++) { - printf("%f, ", smem.xColSum[idxXBuf][n][i]); + printf("%f, ", smem.xColMax[idxXBuf][i]); + } + printf("\n"); + printf("colSum:\n"); + for (int n = 0; n < 4; n++) + { + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xColSum[idxXBuf][n][i]); + } + printf("\n"); + } + printf("\n"); + printf("X:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + for (int j = 0; j < gemm0CtaTileNbTokens; j++) + { + auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); + auto const e = reinterpret_cast&>( + smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( + i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; + printf("%.2f, ", float(e)); + if (j % 16 == 15) + { + printf("| "); + } + } + printf("\n\n"); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#else + if (blockIdx.y == 1 && threadIdx.x == 0) + { + printf("rowMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xRowMax[idxXBuf][i]); + } + printf("\n"); + printf("rowSum:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xRowSum[idxXBuf][i]); } printf("\n"); } - printf("\n"); - printf("X:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - for (int j = 0; j < gemm0CtaTileNbTokens; j++) - { - auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); - auto const e = reinterpret_cast&>( - smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( - i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; - printf("%.2f, ", float(e)); - if (j % 16 == 15) - { - printf("| "); - } - } - printf("\n\n"); - } - } - smem.gemm1WarpGrpBar.arrive_and_wait(); -#else - if (blockIdx.y == 1 && threadIdx.x == 0) - { - printf("rowMax:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xRowMax[idxXBuf][i]); - } - printf("\n"); - printf("rowSum:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xRowSum[idxXBuf][i]); - } - printf("\n"); - } - smem.gemm1WarpGrpBar.arrive_and_wait(); + smem.gemm1WarpGrpBar.arrive_and_wait(); #endif #endif #if SWAP_AB - // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead. - rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], - smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar); + // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead. + rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar); #else - rescaleGemm1AccForNewRowMax_sync( - warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum); + rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum); #endif - auto& xBuf = smem.xBuf(idxXBuf); + auto& xBuf = smem.xBuf(idxXBuf); - auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, - gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) - .raw(); + auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) + .raw(); #if CACHE_ELEM_ENUM == 0 - auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, - gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) - .raw(); + auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) + .raw(); #endif #if SWAP_AB //@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed. #pragma unroll - for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) - { + for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) + { #if CACHE_ELEM_ENUM == 2 - Vec const fragA - = loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); + Vec const fragA + = loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); #if !defined(NDEBUG) && DBG_PRINT - if (threadIdx.x == 0) - { - printf("fragA:\nidxInstK == %u\n", idxInstK); - } - smem.gemm1WarpGrpBar.arrive_and_wait(); - for (int m = 0; m < 2; m++) - { - for (int w = 0; w < 4; w++) + if (threadIdx.x == 0) { - if (warpRank == w) + printf("fragA:\nidxInstK == %u\n", idxInstK); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + for (int m = 0; m < 2; m++) + { + for (int w = 0; w < 4; w++) { - if (laneId() == 0) + if (warpRank == w) { - printf(" warpRank = %u\n", warpRank); - } - __syncwarp(); - for (int a = 0; a < 2; a++) - { - for (int b = 0; b < 8; b++) + if (laneId() == 0) { - for (int c = 0; c < 2; c++) + printf(" warpRank = %u\n", warpRank); + } + __syncwarp(); + for (int a = 0; a < 2; a++) + { + for (int b = 0; b < 8; b++) { - for (int d = 0; d < 4; d++) + for (int c = 0; c < 2; c++) { - if (laneId() == b * 4 + d) + for (int d = 0; d < 4; d++) { - for (int e = 0; e < 4; e++) + if (laneId() == b * 4 + d) { - auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>( - fragA[m](0, c)(a, 0)); - printf("%.2f, ", float(elem4[e])); + for (int e = 0; e < 4; e++) + { + auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>( + fragA[m](0, c)(a, 0)); + printf("%.2f, ", float(elem4[e])); + } } + __syncwarp(); } - __syncwarp(); } + if (laneId() == 0) + { + printf("\n"); + } + __syncwarp(); } - if (laneId() == 0) + if (laneId() == 0 && a == 0) { - printf("\n"); + printf("----------------------\n"); } __syncwarp(); } - if (laneId() == 0 && a == 0) - { - printf("----------------------\n"); - } - __syncwarp(); } + smem.gemm1WarpGrpBar.arrive_and_wait(); } - smem.gemm1WarpGrpBar.arrive_and_wait(); } - } #endif #endif - BoundedVal const kOffsetInGrains{grainsPerInstK * idxInstK}; - auto const descX = addAddr(descXBase, - &xBuf[kOffsetInGrains.template divBy().get()]( - 0, kOffsetInGrains.template mod().get())); + BoundedVal const kOffsetInGrains{grainsPerInstK * idxInstK}; + auto const descX = addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + 0, kOffsetInGrains.template mod().get())); #if CACHE_ELEM_ENUM == 2 - gmma::fence(); + gmma::fence(); #endif #pragma unroll - for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) - { + for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) + { #if CACHE_ELEM_ENUM == 0 - auto const descV - = addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); - gmma::mma_async_shmA( - reinterpret_cast(acc(idxInstM, 0)), - descV, descX, true); + auto const descV + = addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); + gmma::mma_async_shmA( + reinterpret_cast(acc(idxInstM, 0)), + descV, descX, true); #elif CACHE_ELEM_ENUM == 2 - gmma::mma_async_regA( - reinterpret_cast(acc(idxInstM, 0)), - reinterpret_cast(fragA[idxInstM]), descX, true); + gmma::mma_async_regA( + reinterpret_cast(acc(idxInstM, 0)), + reinterpret_cast(fragA[idxInstM]), descX, true); #endif + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of + // gmma. + gmma::wait_group<0>(); } - gmma::commit_group(); - //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of - // gmma. - gmma::wait_group<0>(); - } #else - auto const descVTBase = gmma::makeMatDesc( - nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::VTBuffer{})) - .raw(); - vtBar.produced.arrive_and_wait(); + auto const descVTBase = gmma::makeMatDesc( + nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::VTBuffer{})) + .raw(); + vtBar.produced.arrive_and_wait(); // if (idxIter == 1 && threadIdx.x == 0) { // printf("vtBuf:\n"); // dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf); // } #pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { -#pragma unroll - for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { - BoundedVal const kOffsetInGrains{grainsPerInstK * k}; - auto const descX = addAddr(descXBase, - &xBuf[kOffsetInGrains.template divBy().get()]( - gmma::instM * m, kOffsetInGrains.template mod().get())); - auto const descVT = addAddr( - descVTBase, &vtBuf(0, kOffsetInGrains.template mod().get())); - gmma::mma_async_shmA( - reinterpret_cast(acc(m, 0)), descX, - descVT, true); +#pragma unroll + for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) + { + BoundedVal const kOffsetInGrains{grainsPerInstK * k}; + auto const descX = addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + gmma::instM * m, kOffsetInGrains.template mod().get())); + auto const descVT = addAddr( + descVTBase, &vtBuf(0, kOffsetInGrains.template mod().get())); + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), descX, + descVT, true); + } } - } - gmma::commit_group(); - //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma. - gmma::wait_group<0>(); + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of + // gmma. + gmma::wait_group<0>(); #endif + } + if (idxIter == nbIters - 1) { // gmma::wait_group should have already synchronized threads, so this may be unnecessary. @@ -1471,8 +1578,24 @@ CUBIN_EXPORT __global__ tensorMap #endif }; +#if SKIP_SOFTMAX_ATTN + for (auto& b : smem.skipSoftmaxXBar) + { + unused(b.consumed.arrive()); + } +#endif for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; + auto& vBar = smem.vBar[idxVBuf]; +#if SKIP_SOFTMAX_ATTN + uint32_t idxXBuf = idxIter % SharedMem::nbXBuf; + auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf]; + skipSoftmaxXBar.produced.arrive_and_wait(); + bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToV[idxXBuf]; + skipSoftmaxXBar.consumed.arrive(); +#endif + uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; vTileLoader.loadPages(idxVTile); #if USE_INPUT_KV || ENABLE_PDL == 2 @@ -1506,8 +1629,20 @@ CUBIN_EXPORT __global__ } #endif - uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; - auto& vBar = smem.vBar[idxVBuf]; +#if SKIP_SOFTMAX_ATTN + if (shouldSkipSoftmaxAttn) + { + vBar.consumed.arrive_and_wait(); + // compared to non-skip softmax attn, we need to increase vBar.produced count to avoid race + // condition where vBar.consumed is arrived again without wait without skip softmax attn, XVGemm + // will wait for tx_count, so its progress won't go ahead of vload warp with skip softmax attn, + // XVGemm WG may go ahead of vload warp, as previous vBar only have XVGemm WG threads and a tx_count + // (now = 0). Then it may arrive vBar.consumed before it is arrive_and_wait-ed + vBar.produced.arrive(); + continue; + } +#endif + vBar.consumed.arrive_and_wait(); if (warpElectSync()) { @@ -1517,6 +1652,9 @@ CUBIN_EXPORT __global__ vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); } } +#if SKIP_SOFTMAX_ATTN + vBar.produced.arrive(); +#endif __syncwarp(); } } @@ -1992,9 +2130,23 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, #endif // SPEC_DEC // smemColMax is persistent across multiple iterations +#if SKIP_SOFTMAX_ATTN +__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, + Gemm0Acc const& src, float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip) +#else __device__ inline RegColWiseVec computeWarpGrpColMax_sync( CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src) +#endif { +#if SKIP_SOFTMAX_ATTN + if (threadIdx.x == 0) + { + *smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote + } + float const lnThreshold + = log(skipSoftmaxThreshold); // this can be -inf, but should be safe as we only use it for comparison +#endif + auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); #pragma unroll for (uint32_t n = 0; n < src.cols; n++) @@ -2029,6 +2181,9 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync( } uint32_t const lane = laneId(); +#if SKIP_SOFTMAX_ATTN + auto prevOrCurrentMax = RegColWiseVec(); +#if SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE if (lane < 4) { #pragma unroll @@ -2037,12 +2192,43 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync( #pragma unroll for (uint32_t j = 0; j < 2; j++) { - atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); + prevOrCurrentMax[n][j] = smemColMax[8 * n + 2 * lane + j]; } } } warpGrpBar.arrive_and_wait(); +#endif +#endif + + if (lane < 4) + { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) + { +#if SKIP_SOFTMAX_ATTN && !SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE + // prevOrCurrentMax <= actual smemColMax (after updates from all 4 warps done), but always >= + // smemColMax(Prev), the smemColMax value *before* this tile is computed. + // When determine whether to skip, it is safe to use prevOrCurrentMax: 1) all 4 warps' localmax < + // smemColMax(Prev), then prevOrCurrentMax == smemColMax(Prev), result not affected; 2) if some localmax + // > smemColMax(Prev), prevOrCurrentMax > smemColMax(Prev), some warps may incorrectly vote skip, but + // at least one warp whose localColMax is larger will not skip, then the tile is not skipped. + // This reduces some sync and check, but has issue when threshold > 1. + prevOrCurrentMax[n][j] = atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); +#else + atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); +#endif + } + } + } + warpGrpBar.arrive_and_wait(); + uint32_t const idxInQuad = lane % 4; +#if SKIP_SOFTMAX_ATTN + bool localShouldSkip = true; +#endif #pragma unroll for (uint32_t n = 0; n < src.cols; n++) @@ -2050,10 +2236,21 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync( #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { +#if SKIP_SOFTMAX_ATTN + if (lane < 4 && 8 * n + 2 * idxInQuad + j < headGrpSize) + { + localShouldSkip &= (colMax[n][j] - prevOrCurrentMax[n][j]) < lnThreshold; + } +#endif assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; } } + +#if SKIP_SOFTMAX_ATTN + atomicAnd(smemSkipVote, static_cast(localShouldSkip)); // this will be translated to redux and voteu +#endif + warpGrpBar.arrive_and_wait(); return colMax; } @@ -2199,7 +2396,7 @@ __device__ inline void storeGemm0AccToShm( uint32_t const idxOctInsideHalf = idxInHalf / 8; uint32_t const idxRowInsideOct = lane % 8; uint32_t const warpBaseC = 16 * warpRank; - auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair + auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair { uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; @@ -3231,6 +3428,24 @@ __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst, } #ifndef GENERATE_CUBIN +uint32_t computeNbSubSeqPerSeqHopperF8MHA( + cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen) +{ + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) + { + int32_t const val = std::stoi(env); + if (val > 0) + { + return val; + } + } + float const factor = 0.25f; + return mha::min( + mha::max(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); +} + void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SLIDING_WINDOW uint32_t slidingWinSize, @@ -3268,6 +3483,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, // int8/fp8 KV cache. #if SPEC_DEC SpecDecParams const& specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + float const skipSoftmaxThresholdScaleFactor, +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount, +#endif #endif uint32_t* semaphores, void* scratch, cudaStream_t stream) { @@ -3286,22 +3507,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, uint32_t const nbVHeads = nbKHeads; uint32_t const nbQHeads = nbKHeads * headGrpSize; uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; - uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t - { - auto const env = std::getenv("XQA_NB_SUB_SEQ"); - if (env != nullptr) - { - int32_t const val = std::stoi(env); - if (val > 0) - { - return val; - } - } - float const factor = 0.25f; - return mha::min( - mha::max(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), - divUp(maxSeqLen, gemm0CtaTileNbTokens)); - }(); + uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqHopperF8MHA(prop, batchSize, nbKHeads, maxSeqLen); #if SPEC_DEC uint32_t const qSeqLen = specDecParams.qSeqLen; #else @@ -3371,6 +3577,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #endif #if SPEC_DEC specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + skipSoftmaxThresholdScaleFactor, +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + skippedBlockCount, totalBlockCount, +#endif #endif semaphores, scratch); #else diff --git a/cpp/kernels/xqa/mha_stdheaders.cuh b/cpp/kernels/xqa/mha_stdheaders.cuh index 5d22d2e018..8f4c252c62 100644 --- a/cpp/kernels/xqa/mha_stdheaders.cuh +++ b/cpp/kernels/xqa/mha_stdheaders.cuh @@ -1272,6 +1272,19 @@ using is_void = is_same, void>; template inline constexpr bool is_void_v = is_void::value; #endif + +#ifndef GENERATE_CUBIN +template +using pair = std::pair; +#else +template +struct pair +{ + T1 first; + T2 second; +}; +#endif + } // namespace mha #if GENERATE_CUBIN diff --git a/cpp/kernels/xqa/test/refAttention.cpp b/cpp/kernels/xqa/test/refAttention.cpp index 303678518f..cc218f4cbd 100644 --- a/cpp/kernels/xqa/test/refAttention.cpp +++ b/cpp/kernels/xqa/test/refAttention.cpp @@ -50,7 +50,8 @@ using Vector = Matrix; template Eigen::Matrix refFlashAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor, + uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum) { uint32_t const nbTiles = divUp(seqLen, tileSize); auto gemm1Acc = Eigen::Matrix::Zero().eval(); @@ -61,6 +62,16 @@ Eigen::Matrix refFlashAt float const qkScale = qScale * kvScale / sqrtf(validElemsPerHead); uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize); uint32_t const idxTileBeg = seqBeg / tileSize; + + uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1; + std::vector> skipRowMaxs(nbSubSeq); + for (uint32_t i = 0; i < nbSubSeq; i++) + { + skipRowMaxs[i].fill(-INFINITY); + } + bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor); + float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen; + for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++) { Eigen::Matrix gemm0Acc; @@ -88,7 +99,22 @@ Eigen::Matrix refFlashAt } } - Eigen::Vector const tileRowMax = gemm0Acc.rowwise().maxCoeff().cwiseMax(rowMax).eval(); + Eigen::Vector const localRowMax = gemm0Acc.rowwise().maxCoeff().eval(); + Eigen::Vector const tileRowMax = localRowMax.cwiseMax(rowMax).eval(); + auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq]; + skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval(); + + if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0) + { + *totalBlockCount += 1; + auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold)); + bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq); + if (skipBlock) + { + *skippedBlockCount += 1; + continue; + } + } Eigen::Matrix tileX = (gemm0Acc.colwise() - tileRowMax).array().exp().eval(); @@ -138,7 +164,8 @@ Eigen::Matrix refFlashAt template Eigen::Matrix \ refFlashAttention(IOHead const* q, \ CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, \ - float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) + float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, \ + float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum) INSTANTIATE_refFlashAttention(CacheElem, 64, false, false); INSTANTIATE_refFlashAttention(CacheElem, 64, false, true); diff --git a/cpp/kernels/xqa/test/refAttention.h b/cpp/kernels/xqa/test/refAttention.h index 4f1e673ada..a8dd32bab6 100644 --- a/cpp/kernels/xqa/test/refAttention.h +++ b/cpp/kernels/xqa/test/refAttention.h @@ -88,7 +88,8 @@ struct CacheSeq template Eigen::Matrix refFlashAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor, + uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum); template #if SPEC_DEC diff --git a/cpp/kernels/xqa/test/test.cpp b/cpp/kernels/xqa/test/test.cpp index 76e94616ce..9702d4bf61 100644 --- a/cpp/kernels/xqa/test/test.cpp +++ b/cpp/kernels/xqa/test/test.cpp @@ -150,7 +150,8 @@ template #endif #endif void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false, - bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) + bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30, + float skipSoftmaxThresholdScaleFactor = 0.0f) { #if IS_MLA if (nbKHeads != 1) @@ -224,6 +225,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head. } ctxLen = std::min(ctxLen, seqLen); + uint32_t skippedBlockCount = 0; + uint32_t totalBlockCount = 0; + if (skipSoftmaxThresholdScaleFactor > 0) + { + assert(useQGMMA); + } float const kScale = cacheElemSize == 2 ? 1.f : 1 / 4.f; float const vScale = kScale; float const qScale = 1.f; @@ -329,6 +336,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, auto const rcpOutScale = ManagedMemBuf(1); auto const seqLenList = ManagedMemBuf(batchSize); auto const ctxLenList = ManagedMemBuf(batchSize); +#if SKIP_SOFTMAX_ATTN +#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS + auto const kernelSkippedBlockCount = ManagedMemBuf(1); + auto const kernelTotalBlockCount = ManagedMemBuf(1); + kernelSkippedBlockCount[0] = 0; + kernelTotalBlockCount[0] = 0; +#endif +#else + EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f) + << "Got non-zero skipSoftmaxThresholdScaleFactor while SKIP_SOFTMAX_ATTN is not enabled."; +#endif #if USE_PAGED_KV_CACHE auto const pageListBuf = ManagedMemBuf(pageListBytes); #if PAGED_KV_CACHE_LAYOUT == 1 @@ -726,6 +744,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, maxSeqLen, &seqLenList[0][0], batchSize, kvCacheScale.get(), semaphores.get(), scratch, stream); }; #else + auto multiBlockNum = [&]() + { + auto const calcFunc = useQGMMA ? &computeNbSubSeqPerSeqHopperF8MHA : &computeNbSubSeqPerSeqMHA; + return calcFunc(prop, batchSize, nbKHeads, maxSeqLen); + }(); auto runKernel = [&]() { auto const launchFunc = useQGMMA ? &launchHopperF8MHA : &launchMHA; @@ -776,6 +799,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, batchSize, kvCacheScale.get(), #if SPEC_DEC specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + skipSoftmaxThresholdScaleFactor, +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(), +#endif #endif semaphores.get(), scratch, stream); checkCuda(cudaGetLastError()); @@ -813,6 +842,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, checkCuda(cudaEventRecord(toc, stream)); prefetchToDevice(cudaCpuDeviceId); checkCuda(cudaStreamSynchronize(stream)); +#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS + kernelSkippedBlockCount[0] /= nbIters; + kernelTotalBlockCount[0] /= nbIters; +#endif if (testPerf) { float ms; @@ -849,6 +882,15 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, = totalNbCacheLoadBytes + inputBytes + outputBytes; // we ignore page indices and beam search indices. float const dramSolTime = totalTraffic / bandwidth * 1E3f; float const dramSolRatio = dramSolTime / ms; +#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS + size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes + * (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0])) + * nbLoadedCacheTokens; + float const totalTrafficWithSkip + = totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices. + float const dramSolTimeWithSkip = totalTrafficWithSkip / bandwidth * 1E3f; + float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms; +#endif if (verbose) { printf("done\n"); @@ -863,7 +905,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, } float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2 * nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F; +#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS + printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0], + kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]); + printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops); +#else printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops); +#endif } if (refCheck) { @@ -1084,8 +1132,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, if (useQGMMA) { refOutput = refFlashAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, - refAttentionSinks); + vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks, + skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum); // refOutput = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, // vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); } @@ -1132,6 +1180,14 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, #endif } } +#if SKIP_SOFTMAX_ATTN + printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount, + totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount); +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0], + kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]); +#endif +#endif if (saveData) { fout_refOutput.close(); @@ -1253,6 +1309,14 @@ TEST(RefCheck, llama_V2_70b) #if SLIDING_WINDOW runTest<2>(2, 4096, false, true, false, false, false, ~0, 256); runTest<2>(2, 400, false, true, false, false, false, ~0U, 256); +#endif +#if SKIP_SOFTMAX_ATTN + runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f); + runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f); + runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f); + runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f); + runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f); + runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f); #endif runTest<8>(120, 367, false, true); runTest<8>(1792, 2048, false, true); diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index 339da7c527..32a9332a01 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -298,6 +298,11 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.use_sparse_attention = useTllmGenSparseAttention(); // Skip softmax threshold. xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode; +#ifdef SKIP_SOFTMAX_STAT + // Statistics of skip-softmax, pointers of device memory for output + xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks; + xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks; +#endif // Cross attention parameters. xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp index 33587d7961..9571737f04 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp @@ -105,7 +105,8 @@ CubinObj CompileEngine::compile() const // scratch in this case. /*use_input_kv=*/applyRoPEInXqaKernel, /*rope_style=*/ropeStyle, - /*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree}; + /*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree, + /*use_skip_softmax_attn=*/mXqaParams.skip_softmax_threshold_scale_factor != 0}; if (context.kernel_type == TLLM_XQA_JIT_MLA) { auto const& c = context; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp index 90dda051a0..877a780072 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp @@ -232,6 +232,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key); TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized()); bool const isSpecDec = xqaParams.multi_query_tokens; + bool const isSkipSoftmax = xqaParams.skip_softmax_threshold_scale_factor != 0; bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED); bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED); bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA); @@ -378,7 +379,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& .mask = reinterpret_cast(xqaParams.spec_decoding_packed_mask)}; }; - constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 16; + constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 19; uint32_t idxNextParam = 0; void* kernelParams[kMAX_NB_KERNEL_PARAMS]; auto appendParam = [&](auto* p) mutable @@ -514,6 +515,16 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& appendParam(&specDecParams); specDecBlocks = divUp(specDecParams.qSeqLen, 64 / num_q_heads_over_kv); } + if (isSkipSoftmax) + { + TLLM_CHECK_WITH_INFO(isGMMAKernel, "skip softmax is only supported for GMMA kernel for now."); + TLLM_CHECK_WITH_INFO(!isSpecDec, "skip softmax is not supported with spec dec for now."); + appendParam(&xqaParams.skip_softmax_threshold_scale_factor); +#ifdef SKIP_SOFTMAX_STAT + appendParam(&xqaParams.skip_softmax_total_blocks); + appendParam(&xqaParams.skip_softmax_skipped_blocks); +#endif + } appendParam(&launchParams.semaphores); appendParam(&launchParams.scratch); kernelParams[idxNextParam] = nullptr; // one extra nullptr at end as guard. diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp index 26fadd21cc..f6f73dab2e 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp @@ -96,10 +96,16 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu { return false; } - if (xqaParams.kv_cache_data_type != DATA_TYPE_E4M3) + if (!contains({DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_E4M3}, xqaParams.kv_cache_data_type)) { return false; } + bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0; + if (!is_skip_softmax && xqaParams.kv_cache_data_type != DATA_TYPE_E4M3) + { + // Only use hopper kernel with fp16/bf16 kv cache data type when skip softmax is enabled + return false; + } if (xqaParams.beam_width != 1) { return false; @@ -168,6 +174,11 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug { return false; } + bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0; + if (is_skip_softmax) + { + return false; + } return true; } @@ -201,6 +212,11 @@ bool supportConfigMLA(XQAParams const& xqaParams, int SM, bool forConfigurePlugi { return false; } + bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0; + if (is_skip_softmax) + { + return false; + } return true; } diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h index ab9e93f0d4..b132e76918 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h @@ -66,6 +66,7 @@ extern "C" bool is_spec_dec_tree = true; // useful only when multi_query_tokens, should be true unless using linear tree in spec-dec. + bool use_skip_softmax_attn; } tllmXqaJitContext; // tllmXqaJitProgram is an opaque handle for a program. diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp index 96481d8474..384b29e313 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp @@ -215,6 +215,10 @@ tllmXqaJitStatus getMacroFlags(tllmXqaJitContext const* context, std::vectoruse_input_kv ? "1" : "0"; macros["ROPE_STYLE"] = std::to_string(int(context->rope_style)); macros["IS_SPEC_DEC_TREE"] = context->is_spec_dec_tree ? "1" : "0"; + macros["SKIP_SOFTMAX_ATTN"] = context->use_skip_softmax_attn ? "1" : "0"; +#ifdef SKIP_SOFTMAX_STAT + macros["SKIP_SOFTMAX_ATTN_BLOCK_STATS"] = context->use_skip_softmax_attn ? "1" : "0"; +#endif // Without these macros, NVRTC uses precompiled headers for cuda_fp16.h etc. // Linking might fail due to ABI incompatibility. diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp index 7bd7c32e5e..822483560c 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp @@ -493,6 +493,10 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forCo { SUPPORT_RETURN_FALSE("streaming-llm"); } + if (xqaParams.skip_softmax_threshold_scale_factor != 0) + { + SUPPORT_RETURN_FALSE("skip_softmax_threshold_scale_factor"); + } // OPTIMIZE: For the standard generation-phase MHA, there are still extra limitations. // NOTE: Medusa mode = Multi_query_tokens > 1. diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp index e4b642a11e..250e850a3a 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp @@ -64,6 +64,21 @@ CUtensorMapSwizzle getSwizzleMode(uint32_t partBytes) } }; +CUtensorMapDataType_enum getDataTypeFromXqaParams(XQAParams const& xqaParams) +{ + if (xqaParams.kv_cache_data_type == DATA_TYPE_BF16) + { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } + else if (xqaParams.kv_cache_data_type == DATA_TYPE_FP16) + { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } + TLLM_CHECK(xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 || xqaParams.kv_cache_data_type == DATA_TYPE_E5M2 + || xqaParams.kv_cache_data_type == DATA_TYPE_INT8); + return CU_TENSOR_MAP_DATA_TYPE_UINT8; +} + CUtensorMap makeTensorMapForQ(std::shared_ptr const& driver, void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems, uint32_t boxHeads) { @@ -131,24 +146,26 @@ CUtensorMap makeTensorMapForHopperXqaKVCache( if constexpr (std::is_same_v) { uint32_t const headElems = xqaParams.head_size; - uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8); + CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams); + uint32_t const elemBytes = getElemBytes(dataType); TLLM_CHECK(headElems <= 256); uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256); uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes; - return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8, - xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems); + return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size, + xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems); } else { static_assert(std::is_same_v); uint32_t const headElems = xqaParams.head_size; - uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8); + CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams); + uint32_t const elemBytes = getElemBytes(dataType); TLLM_CHECK(headElems <= 256); uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256); uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes; - return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, CU_TENSOR_MAP_DATA_TYPE_UINT8, - xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width, - xqaParams.batch_size, partElems); + return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, dataType, xqaParams.head_size, + xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width, xqaParams.batch_size, + partElems); } } @@ -161,11 +178,12 @@ template CUtensorMap makeTensorMapForXqaMlaKVCache(std::shared_ptr const& driver, XQAParams const& xqaParams, KVCacheBuffer const& kv_cache_buffer, bool forK) { + CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams); uint32_t const partElems = (forK ? 64 : 128); if constexpr (std::is_same_v) { - return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8, - xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems); + return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size, + xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems); } else { @@ -183,7 +201,7 @@ CUtensorMap makeTensorMapForXqaMlaQ( std::shared_ptr const& driver, XQAParams const& xqaParams, void const* q) { uint32_t const partElems = 64; - return makeTensorMapForQ(driver, q, CU_TENSOR_MAP_DATA_TYPE_UINT8, xqaParams.head_size, + return makeTensorMapForQ(driver, q, getDataTypeFromXqaParams(xqaParams), xqaParams.head_size, xqaParams.num_q_heads * xqaParams.total_num_input_tokens, partElems, xqaParams.num_q_heads); } } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h index 406bf54b1f..ce2b77aa92 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h @@ -119,7 +119,12 @@ struct XQAParams bool use_sparse_attention = false; // Skip softmax threshold. - float skip_softmax_threshold_scale_factor = 0.0f; + float skip_softmax_threshold_scale_factor = 0; + +#ifdef SKIP_SOFTMAX_STAT + uint32_t* skip_softmax_total_blocks = nullptr; + uint32_t* skip_softmax_skipped_blocks = nullptr; +#endif cudaStream_t stream = 0; // layer index @@ -199,6 +204,10 @@ struct XQAParams << "sparse_params: " << sparse_params.toString() << std::endl << "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl << "skip_softmax_threshold_scale_factor :" << skip_softmax_threshold_scale_factor << std ::endl +#ifdef SKIP_SOFTMAX_STAT + << "skip_softmax_total_blocks :" << skip_softmax_total_blocks << std ::endl + << "skip_softmax_skipped_blocks :" << skip_softmax_skipped_blocks << std ::endl +#endif << "stream :" << stream; return ss.str();