[TRTLLM-10022][feat] Add hopper xqa decode support for skip softmax attention (#10264)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Pengbo Wang 2026-01-12 08:26:10 +08:00 committed by GitHub
parent c5d5af9e7f
commit c0e25e5418
18 changed files with 654 additions and 228 deletions
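The change adds a skip-softmax path to the Hopper XQA decode kernel: a KV tile's X*V work can be skipped when its local score maximum stays far enough below the running maximum. Below is a hedged host-side sketch of the decision rule implied by this diff; the helper names (skipSoftmaxThreshold, shouldSkipTile) are illustrative only and not part of the kernel API.

#include <cmath>
#include <cstdint>

// Sketch of the skip decision this diff implements in-kernel (names illustrative).
// The threshold is scaled by 1/seqLen and disabled entirely for short sequences.
inline float skipSoftmaxThreshold(float scaleFactor, uint32_t cacheSeqLen)
{
    bool const disableForShortSeq = (cacheSeqLen < scaleFactor);
    return disableForShortSeq ? 0.0f : scaleFactor / cacheSeqLen;
}

// A KV tile may be skipped when, for every query row, its local max score raises
// the running max by less than ln(threshold). threshold == 0 gives ln(0) == -inf,
// so the comparison is always false and nothing is ever skipped.
inline bool shouldSkipTile(float localMax, float prevRunningMax, float threshold)
{
    return (localMax - prevRunningMax) < std::log(threshold);
}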

View File

@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define SLIDING_WINDOW 0
#endif
#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0
#endif
#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
#endif
#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
#endif
// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
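The three new macros are compile-time switches; in-tree they are injected by the NVRTC/JIT path shown later in this diff (getMacroFlags). A hedged example of defining them directly for an offline build, with illustrative values:

// Hedged example; values illustrative. Normally derived from tllmXqaJitContext::use_skip_softmax_attn.
#define SKIP_SOFTMAX_ATTN 1                                 // enable the skip-softmax decode path
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 1                     // also count skipped/total KV tiles
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1  // extra sync that keeps thresholds > 1 correct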

View File

@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
asm volatile("trap;\n");
return 0;
}();
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),

View File

@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
#endif
#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}
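Exposing the split computation lets host code (the unit test below uses it for the reference check) query how many sub-sequences the kernel will use. A hedged usage sketch, assuming only the declaration added to mha.h in this diff:

#include <cuda_runtime.h>
#include <cstdint>

// Declared in this diff's mha.h; forward-declared here to keep the sketch self-contained.
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

// Hedged usage sketch: query the multi-block split the kernel will use.
// The XQA_NB_SUB_SEQ environment variable (if set to a positive integer) overrides the SM-count heuristic.
uint32_t querySubSeqSplit(uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, /*device=*/0);
    return computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
}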
void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbQHeads = nbKHeads * headGrpSize;
// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
#if SPEC_DEC
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);

View File

@ -90,6 +90,9 @@ struct BeamSearchParams
// match trt-llm API.
};
uint32_t computeNbSubSeqPerSeqMHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);

View File

@ -49,6 +49,10 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to
#define SWAP_AB (!SPEC_DEC)
#endif
#if SKIP_SOFTMAX_ATTN
static_assert(SWAP_AB && USE_PAGED_KV_CACHE && !SPEC_DEC && BEAM_WIDTH == 1, "SKIP_SOFTMAX_ATTN requires SWAP_AB, a paged KV cache, no SPEC_DEC and BEAM_WIDTH == 1.");
#endif
#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)
inline constexpr bool swapAB = SWAP_AB;
@ -138,26 +142,38 @@ using PaddedOutHead = PaddedInputHead;
struct alignas(128) SharedMem
{
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
using KBuffer = Array2D<LdGrain, gemm0CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>;
static constexpr uint32_t nbKBuf = 2;
KBuffer k[nbKBuf]; // as is loaded from global mem.
using XBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerXPart>, nbXParts>;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
using VBuffer = Vec<Array2D<LdGrain, gemm1CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes),
sizeof(XBuffer) % (cacheHeadPartBytes * 8) == 0>,
cacheHeadNbParts>;
#if !SWAP_AB
using VTBuffer = Array2D<LdGrain, headElems, exactDiv(gemm1CtaTileNbTokens, cacheElemsPerGrain), true>;
#endif
static constexpr uint32_t nbVBuf = 2;
#if CACHE_ELEM_ENUM == 0
using OutSwizzleBuf = Array2D<LdGrain, ctaNbQHeads, grainsPerPaddedInputHead>;
#elif CACHE_ELEM_ENUM == 2
using OutSwizzleBuf = Array2D<Vec<Vec<InputElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
#endif
#if SKIP_SOFTMAX_ATTN
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 3; // @fixme: skip softmax attn needs one extra VBuffer
static constexpr uint32_t nbXBuf
= 3 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#else
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 2;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#endif
static_assert(nbXBuf == nbVBuf);
// note: buffers used for GMMA may have additional alignment requirements
KBuffer k[nbKBuf]; // as is loaded from global mem.
QBuffer q; // For gmma math. Conversion done if needed.
union ReusedXVOutSwizzleBuf
{
struct XV
@ -196,9 +212,6 @@ struct alignas(128) SharedMem
return reusedXVOutSwizzleBuf[i].outSwizzle;
}
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
QBuffer q; // For gmma math. Conversion done if needed.
// @fixme: move these into reusedXVOutSwizzleBuf
#if SWAP_AB
ShmQWiseVec xColMax[nbXBuf];
@ -220,6 +233,11 @@ struct alignas(128) SharedMem
Vec<KVCachePageIndex, nbPagesPerTile> pages[2]; // one for K and one for V
#endif
#if SKIP_SOFTMAX_ATTN
uint32_t skipSoftmaxVotesGemm0ToV[nbXBuf]; // guarded by skipSoftmaxXBar
uint32_t skipSoftmaxVotesGemm0ToGemm1[nbXBuf]; // guarded by xBar
#endif
// mem barriers
CtaBarrierPair qBar;
@ -229,6 +247,9 @@ struct alignas(128) SharedMem
CtaBarrierPair vtBar[nbVBuf];
#endif
CtaBarrierPair xBar[nbXBuf];
#if SKIP_SOFTMAX_ATTN
CtaBarrierPair skipSoftmaxXBar[nbXBuf]; // for V to wait for X to be ready
#endif
// used internally in the gemm0 warp group
// @fixme: use separate arrive and wait for all usage
@ -425,8 +446,13 @@ __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif
#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src,
float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip);
#else
__device__ RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
#endif
__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd);
__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax);
__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
@ -675,6 +701,12 @@ CUBIN_EXPORT __global__
#endif
#if SPEC_DEC
SpecDecParams const specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* __restrict__ const semaphores
= nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
@ -753,6 +785,10 @@ CUBIN_EXPORT __global__
uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
assert(isMultiBlockMode == (nbSubSeq > 1));
#if SKIP_SOFTMAX_ATTN
bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
#endif
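Because the threshold is divided by cacheSeqLen, the per-tile skip criterion becomes stricter as the sequence grows, and sequences shorter than the scale factor disable skipping entirely. A hedged worked example with illustrative numbers:

// Hedged worked example (values illustrative):
//   skipSoftmaxThresholdScaleFactor = 500.f
//   cacheSeqLen = 4096 -> threshold = 500 / 4096 ~= 0.122, ln(threshold) ~= -2.10
//   cacheSeqLen = 300  -> 300 < 500, skipping disabled (threshold = 0, ln(0) = -inf)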
if (idxSubSeq >= nbSubSeq)
{
return;
@ -776,21 +812,34 @@ CUBIN_EXPORT __global__
assert(dynamicSmemSize() >= sizeof(SharedMem));
SharedMem& smem = *reinterpret_cast<SharedMem*>(&smemByteBuf[0]);
constexpr uint32_t nbBuffers = 2;
static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf);
if (wid < nbBuffers)
constexpr uint32_t maxNbBuffers = (SharedMem::nbXBuf > SharedMem::nbVBuf) ? SharedMem::nbXBuf : SharedMem::nbVBuf;
static_assert(
maxNbBuffers >= SharedMem::nbKBuf && maxNbBuffers >= SharedMem::nbVBuf && maxNbBuffers >= SharedMem::nbXBuf);
if (wid < maxNbBuffers)
{
if (warpElectSync())
{
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
#if !SWAP_AB
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
if (wid < SharedMem::nbKBuf)
{
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
}
if (wid < SharedMem::nbXBuf)
{
#if SKIP_SOFTMAX_ATTN
smem.skipSoftmaxXBar[wid].initialize(gemm0NbThrds + warp_size, gemm0NbThrds + warp_size);
smem.vBar[wid].initialize(gemm1NbThrds + warp_size, gemm1NbThrds + warp_size);
#else
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
#endif
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
#if !SWAP_AB
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
#endif
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
}
}
}
else if (wid == nbBuffers)
else if (wid == maxNbBuffers)
{
if (warpElectSync())
{
@ -819,6 +868,10 @@ CUBIN_EXPORT __global__
SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen};
#endif
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t localSkippedBlockCount = 0;
#endif
// QK gemm
constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM);
using Acc = GmmaAcc<gemm0CtaTileNbTokens, ctaNbQHeads>;
@ -940,10 +993,39 @@ CUBIN_EXPORT __global__
}
}
#endif
uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// update colMax in shared mem and get a register copy
#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.consumed.arrive_and_wait();
bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0;
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc,
skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip);
bool const shouldSkipSoftmaxAttn = static_cast<bool>(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]);
unused(skipSoftmaxXBar.produced.arrive());
warpGrpOnlineSoftmax(acc, colMax);
if (shouldSkipSoftmaxAttn)
{
xBar.consumed.arrive_and_wait();
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U;
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
localSkippedBlockCount++;
#endif
}
asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
unused(xBar.produced.arrive());
continue;
}
#else
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, colMax);
#endif
#else
RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, rowMax);
@ -959,8 +1041,6 @@ CUBIN_EXPORT __global__
// map 1 to fp8_max before conversion to fp8
acc = acc * kE4M3_MAX;
uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM.
#if SWAP_AB
storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc);
@ -989,13 +1069,25 @@ CUBIN_EXPORT __global__
storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax);
storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum);
#endif
#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 0;
}
#endif
__syncwarp();
// the release semantics of arrive does not work for async consumers like gmma. additional fence is
// needed.
asm volatile("fence.proxy.async.shared::cta;\n");
unused(xBar.produced.arrive());
}
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr)
{
atomicAdd(skippedBlockCount, localSkippedBlockCount);
atomicAdd(totalBlockCount, nbIters);
}
#endif
unused(smem.qBar.consumed.arrive());
}
else if (warpIdx.z == 1)
@ -1043,216 +1135,231 @@ CUBIN_EXPORT __global__
uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq;
auto const idxVBuf = idxIter % SharedMem::nbVBuf;
auto const idxXBuf = idxVBuf;
auto& vBar = smem.vBar[idxVBuf];
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
auto const& vBuf = smem.vBuf(idxVBuf);
#if !SWAP_AB
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
auto& vtBuf = smem.vtBuf(idxVBuf);
vtBar.consumed.arrive_and_wait();
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
vBar.consumed.arrive();
vtBar.produced.arrive();
#endif
auto& xBar = smem.xBar[idxXBuf];
auto& vBar = smem.vBar[idxVBuf];
auto const& vBuf = smem.vBuf(idxVBuf);
xBar.produced.arrive_and_wait();
#if SKIP_SOFTMAX_ATTN
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar
if (shouldSkipSoftmaxAttn)
{
vBar.produced.arrive_and_wait();
}
#endif
#if SKIP_SOFTMAX_ATTN
if (!shouldSkipSoftmaxAttn) // skip XVGemm
#endif
{
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
#if !SWAP_AB
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
auto& vtBuf = smem.vtBuf(idxVBuf);
vtBar.consumed.arrive_and_wait();
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
vBar.consumed.arrive();
vtBar.produced.arrive();
#endif
#if !defined(NDEBUG) && DBG_PRINT
#if SWAP_AB
if (threadIdx.x == 0)
{
printf("colMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColMax[idxXBuf][i]);
}
printf("\n");
printf("colSum:\n");
for (int n = 0; n < 4; n++)
if (threadIdx.x == 0)
{
printf("colMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
printf("%f, ", smem.xColMax[idxXBuf][i]);
}
printf("\n");
printf("colSum:\n");
for (int n = 0; n < 4; n++)
{
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
}
printf("\n");
}
printf("\n");
printf("X:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
{
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
printf("%.2f, ", float(e));
if (j % 16 == 15)
{
printf("| ");
}
}
printf("\n\n");
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#else
if (blockIdx.y == 1 && threadIdx.x == 0)
{
printf("rowMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowMax[idxXBuf][i]);
}
printf("\n");
printf("rowSum:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowSum[idxXBuf][i]);
}
printf("\n");
}
printf("\n");
printf("X:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
{
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
printf("%.2f, ", float(e));
if (j % 16 == 15)
{
printf("| ");
}
}
printf("\n\n");
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#else
if (blockIdx.y == 1 && threadIdx.x == 0)
{
printf("rowMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowMax[idxXBuf][i]);
}
printf("\n");
printf("rowSum:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowSum[idxXBuf][i]);
}
printf("\n");
}
smem.gemm1WarpGrpBar.arrive_and_wait();
smem.gemm1WarpGrpBar.arrive_and_wait();
#endif
#endif
#if SWAP_AB
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
#else
rescaleGemm1AccForNewRowMax_sync(
warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
#endif
auto& xBuf = smem.xBuf(idxXBuf);
auto& xBuf = smem.xBuf(idxXBuf);
auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
.raw();
auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
.raw();
#if CACHE_ELEM_ENUM == 0
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
.raw();
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
.raw();
#endif
#if SWAP_AB
//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed.
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
{
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
{
#if CACHE_ELEM_ENUM == 2
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
#if !defined(NDEBUG) && DBG_PRINT
if (threadIdx.x == 0)
{
printf("fragA:\nidxInstK == %u\n", idxInstK);
}
smem.gemm1WarpGrpBar.arrive_and_wait();
for (int m = 0; m < 2; m++)
{
for (int w = 0; w < 4; w++)
if (threadIdx.x == 0)
{
if (warpRank == w)
printf("fragA:\nidxInstK == %u\n", idxInstK);
}
smem.gemm1WarpGrpBar.arrive_and_wait();
for (int m = 0; m < 2; m++)
{
for (int w = 0; w < 4; w++)
{
if (laneId() == 0)
if (warpRank == w)
{
printf(" warpRank = %u\n", warpRank);
}
__syncwarp();
for (int a = 0; a < 2; a++)
{
for (int b = 0; b < 8; b++)
if (laneId() == 0)
{
for (int c = 0; c < 2; c++)
printf(" warpRank = %u\n", warpRank);
}
__syncwarp();
for (int a = 0; a < 2; a++)
{
for (int b = 0; b < 8; b++)
{
for (int d = 0; d < 4; d++)
for (int c = 0; c < 2; c++)
{
if (laneId() == b * 4 + d)
for (int d = 0; d < 4; d++)
{
for (int e = 0; e < 4; e++)
if (laneId() == b * 4 + d)
{
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
fragA[m](0, c)(a, 0));
printf("%.2f, ", float(elem4[e]));
for (int e = 0; e < 4; e++)
{
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
fragA[m](0, c)(a, 0));
printf("%.2f, ", float(elem4[e]));
}
}
__syncwarp();
}
__syncwarp();
}
if (laneId() == 0)
{
printf("\n");
}
__syncwarp();
}
if (laneId() == 0)
if (laneId() == 0 && a == 0)
{
printf("\n");
printf("----------------------\n");
}
__syncwarp();
}
if (laneId() == 0 && a == 0)
{
printf("----------------------\n");
}
__syncwarp();
}
smem.gemm1WarpGrpBar.arrive_and_wait();
}
smem.gemm1WarpGrpBar.arrive_and_wait();
}
}
#endif
#endif
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
#if CACHE_ELEM_ENUM == 2
gmma::fence();
gmma::fence();
#endif
#pragma unroll
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
{
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
{
#if CACHE_ELEM_ENUM == 0
auto const descV
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
descV, descX, true);
auto const descV
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
descV, descX, true);
#elif CACHE_ELEM_ENUM == 2
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
#endif
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
}
#else
auto const descVTBase = gmma::makeMatDesc(
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
.raw();
vtBar.produced.arrive_and_wait();
auto const descVTBase = gmma::makeMatDesc(
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
.raw();
vtBar.produced.arrive_and_wait();
// if (idxIter == 1 && threadIdx.x == 0) {
// printf("vtBuf:\n");
// dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf);
// }
#pragma unroll
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
{
#pragma unroll
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
{
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
auto const descVT = addAddr(
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
gmma::mma_async_shmA<MathElem, headElems>(
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
descVT, true);
#pragma unroll
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
{
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
auto const descVT = addAddr(
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
gmma::mma_async_shmA<MathElem, headElems>(
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
descVT, true);
}
}
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma.
gmma::wait_group<0>();
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
#endif
}
if (idxIter == nbIters - 1)
{
// gmma::wait_group should have already synchronized threads, so this may be unnecessary.
@ -1471,8 +1578,24 @@ CUBIN_EXPORT __global__
tensorMap
#endif
};
#if SKIP_SOFTMAX_ATTN
for (auto& b : smem.skipSoftmaxXBar)
{
unused(b.consumed.arrive());
}
#endif
for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++)
{
uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
uint32_t idxXBuf = idxIter % SharedMem::nbXBuf;
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.produced.arrive_and_wait();
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToV[idxXBuf];
skipSoftmaxXBar.consumed.arrive();
#endif
uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq;
vTileLoader.loadPages(idxVTile);
#if USE_INPUT_KV || ENABLE_PDL == 2
@ -1506,8 +1629,20 @@ CUBIN_EXPORT __global__
}
#endif
uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
if (shouldSkipSoftmaxAttn)
{
vBar.consumed.arrive_and_wait();
// Compared to non-skip softmax attn, we need to increase the vBar.produced count to avoid a race
// where vBar.consumed is arrived again without a wait. Without skip softmax attn, the XVGemm WG
// waits for tx_count, so its progress cannot run ahead of the vload warp. With skip softmax attn,
// the XVGemm WG may run ahead of the vload warp, because previously vBar counted only the XVGemm WG
// threads plus a tx_count (now = 0), so vBar.consumed could be arrived before it is arrive_and_wait-ed.
vBar.produced.arrive();
continue;
}
#endif
vBar.consumed.arrive_and_wait();
if (warpElectSync())
{
@ -1517,6 +1652,9 @@ CUBIN_EXPORT __global__
vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced);
}
}
#if SKIP_SOFTMAX_ATTN
vBar.produced.arrive();
#endif
__syncwarp();
}
}
@ -1992,9 +2130,23 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif // SPEC_DEC
// smemColMax is persistent across multiple iterations
#if SKIP_SOFTMAX_ATTN
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax,
Gemm0Acc const& src, float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip)
#else
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src)
#endif
{
#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
*smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote
}
float const lnThreshold
= log(skipSoftmaxThreshold); // this can be -inf, but should be safe as we only use it for comparison
#endif
auto colMax = RegColWiseVec::filled(Vec<float, 2>::filled(safeInitRowMax));
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@ -2029,6 +2181,9 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
}
uint32_t const lane = laneId();
#if SKIP_SOFTMAX_ATTN
auto prevOrCurrentMax = RegColWiseVec();
#if SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
if (lane < 4)
{
#pragma unroll
@ -2037,12 +2192,43 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
prevOrCurrentMax[n][j] = smemColMax[8 * n + 2 * lane + j];
}
}
}
warpGrpBar.arrive_and_wait();
#endif
#endif
if (lane < 4)
{
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
#if SKIP_SOFTMAX_ATTN && !SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
// prevOrCurrentMax <= the actual smemColMax (after updates from all 4 warps are done), but it is
// always >= smemColMax(Prev), the smemColMax value *before* this tile is computed.
// When determining whether to skip, it is safe to use prevOrCurrentMax: 1) if all 4 warps' local max <
// smemColMax(Prev), then prevOrCurrentMax == smemColMax(Prev) and the result is unaffected; 2) if some
// local max > smemColMax(Prev), then prevOrCurrentMax > smemColMax(Prev) and some warps may incorrectly
// vote to skip, but at least one warp whose local colMax is larger will not, so the tile is not skipped.
// This saves some synchronization and checking, but breaks when the threshold is > 1.
prevOrCurrentMax[n][j] = atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#else
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#endif
}
}
}
warpGrpBar.arrive_and_wait();
uint32_t const idxInQuad = lane % 4;
#if SKIP_SOFTMAX_ATTN
bool localShouldSkip = true;
#endif
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@ -2050,10 +2236,21 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
{
#if SKIP_SOFTMAX_ATTN
if (lane < 4 && 8 * n + 2 * idxInQuad + j < headGrpSize)
{
localShouldSkip &= (colMax[n][j] - prevOrCurrentMax[n][j]) < lnThreshold;
}
#endif
assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]);
colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j];
}
}
#if SKIP_SOFTMAX_ATTN
atomicAnd(smemSkipVote, static_cast<uint32_t>(localShouldSkip)); // this will be translated to redux and voteu
#endif
warpGrpBar.arrive_and_wait();
return colMax;
}
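For intuition, the skip vote computed by the device function above can be modeled serially on the host. The sketch below is a hedged, illustrative equivalent (not the kernel code; the helper name is made up) of the per-row comparison and the warp-group atomicAnd vote:

#include <cmath>
#include <cstdint>
#include <vector>

// Serial model of the skip vote: a tile is skipped only if *every* query row in
// *every* warp sees (localColMax - previousRunningColMax) < ln(threshold).
bool tileSkipVote(std::vector<float> const& localColMax, // per query row, this tile
    std::vector<float> const& prevColMax,                // per query row, before this tile
    float threshold, bool maybeSkip)
{
    if (!maybeSkip) // the first tile of a sub-sequence never skips
    {
        return false;
    }
    float const lnThreshold = std::log(threshold);
    bool vote = true;
    for (size_t i = 0; i < localColMax.size(); ++i)
    {
        vote &= (localColMax[i] - prevColMax[i]) < lnThreshold; // models the per-thread localShouldSkip
    }
    return vote; // models atomicAnd(smemSkipVote, ...)
}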
@ -2199,7 +2396,7 @@ __device__ inline void storeGemm0AccToShm(
uint32_t const idxOctInsideHalf = idxInHalf / 8;
uint32_t const idxRowInsideOct = lane % 8;
uint32_t const warpBaseC = 16 * warpRank;
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair<uint32_t, uint32_t>
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair<uint32_t, uint32_t>
{
uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols;
uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols;
@ -3231,6 +3428,24 @@ __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst,
}
#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 0.25f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, gemm0CtaTileNbTokens));
}
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
@ -3268,6 +3483,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
@ -3286,22 +3507,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbVHeads = nbKHeads;
uint32_t const nbQHeads = nbKHeads * headGrpSize;
uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 0.25f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, gemm0CtaTileNbTokens));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqHopperF8MHA(prop, batchSize, nbKHeads, maxSeqLen);
#if SPEC_DEC
uint32_t const qSeqLen = specDecParams.qSeqLen;
#else
@ -3371,6 +3577,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#endif
#if SPEC_DEC
specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
skippedBlockCount, totalBlockCount,
#endif
#endif
semaphores, scratch);
#else

View File

@ -1272,6 +1272,19 @@ using is_void = is_same<remove_cv_t<T>, void>;
template <typename T>
inline constexpr bool is_void_v = is_void<T>::value;
#endif
#ifndef GENERATE_CUBIN
template <typename T1, typename T2>
using pair = std::pair<T1, T2>;
#else
template <typename T1, typename T2>
struct pair
{
T1 first;
T2 second;
};
#endif
} // namespace mha
#if GENERATE_CUBIN

View File

@ -50,7 +50,8 @@ using Vector = Matrix<Type, Size, 1>;
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
{
uint32_t const nbTiles = divUp(seqLen, tileSize);
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
@ -61,6 +62,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
float const qkScale = qScale * kvScale / sqrtf(validElemsPerHead);
uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize);
uint32_t const idxTileBeg = seqBeg / tileSize;
uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
for (uint32_t i = 0; i < nbSubSeq; i++)
{
skipRowMaxs[i].fill(-INFINITY);
}
bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen;
for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
{
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> gemm0Acc;
@ -88,7 +99,22 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
}
}
Eigen::Vector<float, headGrpSize> const tileRowMax = gemm0Acc.rowwise().maxCoeff().cwiseMax(rowMax).eval();
Eigen::Vector<float, headGrpSize> const localRowMax = gemm0Acc.rowwise().maxCoeff().eval();
Eigen::Vector<float, headGrpSize> const tileRowMax = localRowMax.cwiseMax(rowMax).eval();
auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();
if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0)
{
*totalBlockCount += 1;
auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));
bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
if (skipBlock)
{
*skippedBlockCount += 1;
continue;
}
}
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> tileX
= (gemm0Acc.colwise() - tileRowMax).array().exp().eval();
@ -138,7 +164,8 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, \
float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);

View File

@ -88,7 +88,8 @@ struct CacheSeq<true, true>
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);
template <typename MathElem, bool isPaged, bool useBeamSearch>
#if SPEC_DEC

View File

@ -150,7 +150,8 @@ template <uint32_t nbKHeads>
#endif
#endif
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30,
float skipSoftmaxThresholdScaleFactor = 0.0f)
{
#if IS_MLA
if (nbKHeads != 1)
@ -224,6 +225,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head.
}
ctxLen = std::min(ctxLen, seqLen);
uint32_t skippedBlockCount = 0;
uint32_t totalBlockCount = 0;
if (skipSoftmaxThresholdScaleFactor > 0)
{
assert(useQGMMA);
}
float const kScale = cacheElemSize == 2 ? 1.f : 1 / 4.f;
float const vScale = kScale;
float const qScale = 1.f;
@ -329,6 +336,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
auto const rcpOutScale = ManagedMemBuf<float>(1);
auto const seqLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
auto const ctxLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
#if SKIP_SOFTMAX_ATTN
#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t>(1);
auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t>(1);
kernelSkippedBlockCount[0] = 0;
kernelTotalBlockCount[0] = 0;
#endif
#else
EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f)
<< "Got non-zero skipSoftmaxThresholdScaleFactor while SKIP_SOFTMAX_ATTN is not enabled.";
#endif
#if USE_PAGED_KV_CACHE
auto const pageListBuf = ManagedMemBuf<std::byte>(pageListBytes);
#if PAGED_KV_CACHE_LAYOUT == 1
@ -726,6 +744,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
maxSeqLen, &seqLenList[0][0], batchSize, kvCacheScale.get(), semaphores.get(), scratch, stream);
};
#else
auto multiBlockNum = [&]()
{
auto const calcFunc = useQGMMA ? &computeNbSubSeqPerSeqHopperF8MHA : &computeNbSubSeqPerSeqMHA;
return calcFunc(prop, batchSize, nbKHeads, maxSeqLen);
}();
auto runKernel = [&]()
{
auto const launchFunc = useQGMMA ? &launchHopperF8MHA : &launchMHA;
@ -776,6 +799,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
batchSize, kvCacheScale.get(),
#if SPEC_DEC
specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(),
#endif
#endif
semaphores.get(), scratch, stream);
checkCuda(cudaGetLastError());
@ -813,6 +842,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
checkCuda(cudaEventRecord(toc, stream));
prefetchToDevice(cudaCpuDeviceId);
checkCuda(cudaStreamSynchronize(stream));
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
kernelSkippedBlockCount[0] /= nbIters;
kernelTotalBlockCount[0] /= nbIters;
#endif
if (testPerf)
{
float ms;
@ -849,6 +882,15 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
= totalNbCacheLoadBytes + inputBytes + outputBytes; // we ignore page indices and beam search indices.
float const dramSolTime = totalTraffic / bandwidth * 1E3f;
float const dramSolRatio = dramSolTime / ms;
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
* (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
* nbLoadedCacheTokens;
float const totalTrafficWithSkip
= totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
float const dramSolTimeWithSkip = totalTrafficWithSkip / bandwidth * 1E3f;
float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms;
#endif
if (verbose)
{
printf("done\n");
@ -863,7 +905,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
}
float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
* nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops);
#else
printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops);
#endif
}
if (refCheck)
{
@ -1084,8 +1132,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
if (useQGMMA)
{
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
refAttentionSinks);
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks,
skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
}
@ -1132,6 +1180,14 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
#endif
}
}
#if SKIP_SOFTMAX_ATTN
printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount,
totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount);
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
#endif
#endif
if (saveData)
{
fout_refOutput.close();
@ -1253,6 +1309,14 @@ TEST(RefCheck, llama_V2_70b)
#if SLIDING_WINDOW
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
#endif
#if SKIP_SOFTMAX_ATTN
runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f);
runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f);
runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f);
runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f);
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f);
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f);
#endif
runTest<8>(120, 367, false, true);
runTest<8>(1792, 2048, false, true);

View File

@ -298,6 +298,11 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
// Skip softmax threshold.
xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
#ifdef SKIP_SOFTMAX_STAT
// Skip-softmax statistics: device memory pointers for the outputs.
xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks;
xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks;
#endif
// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;

View File

@ -105,7 +105,8 @@ CubinObj CompileEngine::compile() const
// scratch in this case.
/*use_input_kv=*/applyRoPEInXqaKernel,
/*rope_style=*/ropeStyle,
/*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree};
/*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree,
/*use_skip_softmax_attn=*/mXqaParams.skip_softmax_threshold_scale_factor != 0};
if (context.kernel_type == TLLM_XQA_JIT_MLA)
{
auto const& c = context;

View File

@ -232,6 +232,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key);
TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized());
bool const isSpecDec = xqaParams.multi_query_tokens;
bool const isSkipSoftmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED);
bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA);
@ -378,7 +379,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
.mask = reinterpret_cast<SpecDecParams::MaskType const*>(xqaParams.spec_decoding_packed_mask)};
};
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 16;
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 19;
uint32_t idxNextParam = 0;
void* kernelParams[kMAX_NB_KERNEL_PARAMS];
auto appendParam = [&](auto* p) mutable
@ -514,6 +515,16 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
appendParam(&specDecParams);
specDecBlocks = divUp(specDecParams.qSeqLen, 64 / num_q_heads_over_kv);
}
if (isSkipSoftmax)
{
TLLM_CHECK_WITH_INFO(isGMMAKernel, "skip softmax is only supported for GMMA kernel for now.");
TLLM_CHECK_WITH_INFO(!isSpecDec, "skip softmax is not supported with spec dec for now.");
appendParam(&xqaParams.skip_softmax_threshold_scale_factor);
#ifdef SKIP_SOFTMAX_STAT
appendParam(&xqaParams.skip_softmax_total_blocks);
appendParam(&xqaParams.skip_softmax_skipped_blocks);
#endif
}
appendParam(&launchParams.semaphores);
appendParam(&launchParams.scratch);
kernelParams[idxNextParam] = nullptr; // one extra nullptr at end as guard.

View File

@ -96,10 +96,16 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu
{
return false;
}
if (xqaParams.kv_cache_data_type != DATA_TYPE_E4M3)
if (!contains({DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_E4M3}, xqaParams.kv_cache_data_type))
{
return false;
}
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
if (!is_skip_softmax && xqaParams.kv_cache_data_type != DATA_TYPE_E4M3)
{
// Without skip softmax, the Hopper kernel requires an e4m3 KV cache; fp16/bf16 KV caches are accepted only when skip softmax is enabled.
return false;
}
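With this change the Hopper QGMMA path accepts fp16/bf16 KV caches only when skip softmax is requested. A hedged condensation of the data-type gate above, as an illustrative helper (not the real supportConfigQGMMA):

#include <cstdint>

enum class KvDtype { kFp16, kBf16, kE4M3, kOther };

// Condensed model of the data-type check in supportConfigQGMMA after this diff.
bool qgmmaKvDtypeSupported(KvDtype dtype, float skipSoftmaxThresholdScaleFactor)
{
    bool const isSkipSoftmax = (skipSoftmaxThresholdScaleFactor != 0.f);
    bool const knownDtype = (dtype == KvDtype::kFp16 || dtype == KvDtype::kBf16 || dtype == KvDtype::kE4M3);
    if (!knownDtype)
    {
        return false;
    }
    // Without skip softmax, only the e4m3 KV cache is accepted; with it, fp16/bf16 are also allowed.
    return isSkipSoftmax || dtype == KvDtype::kE4M3;
}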
if (xqaParams.beam_width != 1)
{
return false;
@ -168,6 +174,11 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug
{
return false;
}
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
if (is_skip_softmax)
{
return false;
}
return true;
}
@ -201,6 +212,11 @@ bool supportConfigMLA(XQAParams const& xqaParams, int SM, bool forConfigurePlugi
{
return false;
}
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
if (is_skip_softmax)
{
return false;
}
return true;
}

View File

@ -66,6 +66,7 @@ extern "C"
bool is_spec_dec_tree
= true; // useful only when multi_query_tokens, should be true unless using linear tree in spec-dec.
bool use_skip_softmax_attn;
} tllmXqaJitContext;
// tllmXqaJitProgram is an opaque handle for a program.

View File

@ -215,6 +215,10 @@ tllmXqaJitStatus getMacroFlags(tllmXqaJitContext const* context, std::vector<std
macros["USE_INPUT_KV"] = context->use_input_kv ? "1" : "0";
macros["ROPE_STYLE"] = std::to_string(int(context->rope_style));
macros["IS_SPEC_DEC_TREE"] = context->is_spec_dec_tree ? "1" : "0";
macros["SKIP_SOFTMAX_ATTN"] = context->use_skip_softmax_attn ? "1" : "0";
#ifdef SKIP_SOFTMAX_STAT
macros["SKIP_SOFTMAX_ATTN_BLOCK_STATS"] = context->use_skip_softmax_attn ? "1" : "0";
#endif
// Without these macros, NVRTC uses precompiled headers for cuda_fp16.h etc.
// Linking might fail due to ABI incompatibility.

View File

@ -493,6 +493,10 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forCo
{
SUPPORT_RETURN_FALSE("streaming-llm");
}
if (xqaParams.skip_softmax_threshold_scale_factor != 0)
{
SUPPORT_RETURN_FALSE("skip_softmax_threshold_scale_factor");
}
// OPTIMIZE: For the standard generation-phase MHA, there are still extra limitations.
// NOTE: Medusa mode = Multi_query_tokens > 1.

View File

@ -64,6 +64,21 @@ CUtensorMapSwizzle getSwizzleMode(uint32_t partBytes)
}
};
CUtensorMapDataType_enum getDataTypeFromXqaParams(XQAParams const& xqaParams)
{
if (xqaParams.kv_cache_data_type == DATA_TYPE_BF16)
{
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
}
else if (xqaParams.kv_cache_data_type == DATA_TYPE_FP16)
{
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
TLLM_CHECK(xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 || xqaParams.kv_cache_data_type == DATA_TYPE_E5M2
|| xqaParams.kv_cache_data_type == DATA_TYPE_INT8);
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
}
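Because the tensor-map element type now follows the KV-cache data type, the per-part element count computed below changes for 16-bit caches. A hedged worked sketch of that arithmetic (same formula as in makeTensorMapForHopperXqaKVCache below; the helper name is illustrative):

#include <algorithm>
#include <cstdint>

// partElems is capped at 128 bytes per part, so it shrinks as the element width grows.
uint32_t partElemsFor(uint32_t headElems, uint32_t elemBytes)
{
    uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
    return std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
}
// e.g. headElems = 128: e4m3 (1 byte)   -> partElems = 128
//                       fp16/bf16 (2 B) -> partElems = 64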
CUtensorMap makeTensorMapForQ(std::shared_ptr<CUDADriverWrapper> const& driver, void const* addr,
CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems, uint32_t boxHeads)
{
@ -131,24 +146,26 @@ CUtensorMap makeTensorMapForHopperXqaKVCache(
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
uint32_t const headElems = xqaParams.head_size;
uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8);
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
uint32_t const elemBytes = getElemBytes(dataType);
TLLM_CHECK(headElems <= 256);
uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8,
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size,
xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
}
else
{
static_assert(std::is_same_v<KVCacheBuffer, KVLinearBuffer>);
uint32_t const headElems = xqaParams.head_size;
uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8);
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
uint32_t const elemBytes = getElemBytes(dataType);
TLLM_CHECK(headElems <= 256);
uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, CU_TENSOR_MAP_DATA_TYPE_UINT8,
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width,
xqaParams.batch_size, partElems);
return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, dataType, xqaParams.head_size,
xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width, xqaParams.batch_size,
partElems);
}
}
@ -161,11 +178,12 @@ template <typename KVCacheBuffer>
CUtensorMap makeTensorMapForXqaMlaKVCache(std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver,
XQAParams const& xqaParams, KVCacheBuffer const& kv_cache_buffer, bool forK)
{
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
uint32_t const partElems = (forK ? 64 : 128);
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8,
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size,
xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
}
else
{
@ -183,7 +201,7 @@ CUtensorMap makeTensorMapForXqaMlaQ(
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver, XQAParams const& xqaParams, void const* q)
{
uint32_t const partElems = 64;
return makeTensorMapForQ(driver, q, CU_TENSOR_MAP_DATA_TYPE_UINT8, xqaParams.head_size,
return makeTensorMapForQ(driver, q, getDataTypeFromXqaParams(xqaParams), xqaParams.head_size,
xqaParams.num_q_heads * xqaParams.total_num_input_tokens, partElems, xqaParams.num_q_heads);
}
} // namespace kernels

View File

@ -119,7 +119,12 @@ struct XQAParams
bool use_sparse_attention = false;
// Skip softmax threshold.
float skip_softmax_threshold_scale_factor = 0.0f;
float skip_softmax_threshold_scale_factor = 0;
#ifdef SKIP_SOFTMAX_STAT
uint32_t* skip_softmax_total_blocks = nullptr;
uint32_t* skip_softmax_skipped_blocks = nullptr;
#endif
cudaStream_t stream = 0;
// layer index
@ -199,6 +204,10 @@ struct XQAParams
<< "sparse_params: " << sparse_params.toString() << std::endl
<< "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
<< "skip_softmax_threshold_scale_factor :" << skip_softmax_threshold_scale_factor << std ::endl
#ifdef SKIP_SOFTMAX_STAT
<< "skip_softmax_total_blocks :" << skip_softmax_total_blocks << std ::endl
<< "skip_softmax_skipped_blocks :" << skip_softmax_skipped_blocks << std ::endl
#endif
<< "stream :" << stream;
return ss.str();