[TRTLLM-10022][feat] Add hopper xqa decode support for skip softmax attention (#10264)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
commit c0e25e5418, parent c5d5af9e7f
@@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define SLIDING_WINDOW 0
#endif

#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0
#endif

#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
#endif

#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
#endif

// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
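These guards keep the feature off by default, so builds that never define the macros are unchanged. For orientation, a minimal sketch of the decision rule the rest of this diff implements (the function name and values below are illustrative, not part of the commit): a KV tile may be skipped when its local score maximum trails the running maximum by more than the log of a per-sequence threshold.

#include <cmath>
#include <cstdint>

// Sketch under assumed names: threshold = scaleFactor / seqLen; skip when
// localMax - runningMax < ln(threshold). Mirrors the kernel logic shown later.
inline bool shouldSkipTile(float localMax, float runningMax, float scaleFactor, uint32_t seqLen)
{
    if (seqLen < scaleFactor) // short sequences never skip (see disableSkipForShortSeq below)
    {
        return false;
    }
    float const threshold = scaleFactor / seqLen;         // e.g. 500 / 2048 ~= 0.244
    return (localMax - runningMax) < std::log(threshold); // ln(0.244) ~= -1.41
}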
@@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
asm volatile("trap;\n");
return 0;
}();
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
@@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
#endif

#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}
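A hypothetical call site, to show how this helper relates to the launch geometry documented later in launchMHA (the batch/head/length numbers are assumptions for illustration):

cudaDeviceProp prop{};
cudaGetDeviceProperties(&prop, /*device=*/0);
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, /*batchSize=*/4, /*nbKHeads=*/8, /*maxSeqLen=*/8192);
dim3 const grid{nbSubSeqPerSeq, /*nbKHeads=*/8, /*batchSize=*/4}; // gridDim.x/y/z as in the comment inside launchMHA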
void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,

@@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
@@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbQHeads = nbKHeads * headGrpSize;

// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
#if SPEC_DEC
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);
@@ -90,6 +90,9 @@ struct BeamSearchParams
// match trt-llm API.
};

uint32_t computeNbSubSeqPerSeqMHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,

@@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);

uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,

@@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
@@ -49,6 +49,10 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to
#define SWAP_AB (!SPEC_DEC)
#endif

#if SKIP_SOFTMAX_ATTN
static_assert(SWAP_AB && USE_PAGED_KV_CACHE && !SPEC_DEC && BEAM_WIDTH == 1, "SKIP_SOFTMAX_ATTN is not supported.");
#endif

#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)

inline constexpr bool swapAB = SWAP_AB;
@@ -138,26 +142,38 @@ using PaddedOutHead = PaddedInputHead;

struct alignas(128) SharedMem
{
using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
using KBuffer = Array2D<LdGrain, gemm0CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>;
static constexpr uint32_t nbKBuf = 2;
KBuffer k[nbKBuf]; // as is loaded from global mem.
using XBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerXPart>, nbXParts>;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
using VBuffer = Vec<Array2D<LdGrain, gemm1CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes),
sizeof(XBuffer) % (cacheHeadPartBytes * 8) == 0>,
cacheHeadNbParts>;
#if !SWAP_AB
using VTBuffer = Array2D<LdGrain, headElems, exactDiv(gemm1CtaTileNbTokens, cacheElemsPerGrain), true>;
#endif
static constexpr uint32_t nbVBuf = 2;
#if CACHE_ELEM_ENUM == 0
using OutSwizzleBuf = Array2D<LdGrain, ctaNbQHeads, grainsPerPaddedInputHead>;
#elif CACHE_ELEM_ENUM == 2
using OutSwizzleBuf = Array2D<Vec<Vec<InputElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
#endif

#if SKIP_SOFTMAX_ATTN
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 3; // @fixme: skip_softmax_attn: for skip softmax attn, an extra VBuffer is used
static constexpr uint32_t nbXBuf
= 3 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#else
static constexpr uint32_t nbKBuf = 2;
static constexpr uint32_t nbVBuf = 2;
static constexpr uint32_t nbXBuf
= 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
#endif
static_assert(nbXBuf == nbVBuf);

// note: buffers used for GMMA may have additional alignment requirements
KBuffer k[nbKBuf]; // as is loaded from global mem.
QBuffer q; // For gmma math. Conversion done if needed.

union ReusedXVOutSwizzleBuf
{
struct XV

@@ -196,9 +212,6 @@ struct alignas(128) SharedMem
return reusedXVOutSwizzleBuf[i].outSwizzle;
}

using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
QBuffer q; // For gmma math. Conversion done if needed.

// @fixme: move these into reusedXVOutSwizzleBuf
#if SWAP_AB
ShmQWiseVec xColMax[nbXBuf];

@@ -220,6 +233,11 @@ struct alignas(128) SharedMem
Vec<KVCachePageIndex, nbPagesPerTile> pages[2]; // one for K and one for V
#endif

#if SKIP_SOFTMAX_ATTN
uint32_t skipSoftmaxVotesGemm0ToV[nbXBuf]; // guarded by skipSoftmaxXBar
uint32_t skipSoftmaxVotesGemm0ToGemm1[nbXBuf]; // guarded by xBar
#endif

// mem barriers

CtaBarrierPair qBar;

@@ -229,6 +247,9 @@ struct alignas(128) SharedMem
CtaBarrierPair vtBar[nbVBuf];
#endif
CtaBarrierPair xBar[nbXBuf];
#if SKIP_SOFTMAX_ATTN
CtaBarrierPair skipSoftmaxXBar[nbXBuf]; // for V to wait for X to be ready
#endif

// used internally in the gemm0 warp group
// @fixme: use separate arrive and wait for all usage
@@ -425,8 +446,13 @@ __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif

#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src,
float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip);
#else
__device__ RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
#endif
__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd);
__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax);
__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
@@ -675,6 +701,12 @@ CUBIN_EXPORT __global__
#endif
#if SPEC_DEC
SpecDecParams const specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* __restrict__ const semaphores
= nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
@@ -753,6 +785,10 @@ CUBIN_EXPORT __global__
uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
assert(isMultiBlockMode == (nbSubSeq > 1));
#if SKIP_SOFTMAX_ATTN
bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
#endif
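Worked numbers for the two lines above, under assumed inputs (a compile-checkable sketch, not kernel code):

constexpr float scaleFactor = 500.f;   // assumed skipSoftmaxThresholdScaleFactor
constexpr uint32_t seqLen = 2048;      // assumed cacheSeqLen
static_assert(!(seqLen < scaleFactor), "skipping stays enabled at this length");
constexpr float threshold = scaleFactor / seqLen; // ~0.244; a tile must trail the running max by ~1.41 nats to be skipped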
if (idxSubSeq >= nbSubSeq)
{
return;
@@ -776,21 +812,34 @@ CUBIN_EXPORT __global__
assert(dynamicSmemSize() >= sizeof(SharedMem));
SharedMem& smem = *reinterpret_cast<SharedMem*>(&smemByteBuf[0]);

constexpr uint32_t nbBuffers = 2;
static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf);
if (wid < nbBuffers)
constexpr uint32_t maxNbBuffers = (SharedMem::nbXBuf > SharedMem::nbVBuf) ? SharedMem::nbXBuf : SharedMem::nbVBuf;
static_assert(
maxNbBuffers >= SharedMem::nbKBuf && maxNbBuffers >= SharedMem::nbVBuf && maxNbBuffers >= SharedMem::nbXBuf);
if (wid < maxNbBuffers)
{
if (warpElectSync())
{
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
#if !SWAP_AB
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
if (wid < SharedMem::nbKBuf)
{
smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
}
if (wid < SharedMem::nbXBuf)
{
#if SKIP_SOFTMAX_ATTN
smem.skipSoftmaxXBar[wid].initialize(gemm0NbThrds + warp_size, gemm0NbThrds + warp_size);
smem.vBar[wid].initialize(gemm1NbThrds + warp_size, gemm1NbThrds + warp_size);
#else
smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
#endif
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);

#if !SWAP_AB
smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
#endif
smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
}
}
}
else if (wid == nbBuffers)
else if (wid == maxNbBuffers)
{
if (warpElectSync())
{
@@ -819,6 +868,10 @@ CUBIN_EXPORT __global__
SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen};
#endif

#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t localSkippedBlockCount = 0;
#endif

// QK gemm
constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM);
using Acc = GmmaAcc<gemm0CtaTileNbTokens, ctaNbQHeads>;
@@ -940,10 +993,39 @@ CUBIN_EXPORT __global__
}
}
#endif

uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// update colMax in shared mem and get a register copy
#if SWAP_AB
#if SKIP_SOFTMAX_ATTN
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.consumed.arrive_and_wait();

bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0;
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc,
skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip);
bool const shouldSkipSoftmaxAttn = static_cast<bool>(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]);
unused(skipSoftmaxXBar.produced.arrive());
warpGrpOnlineSoftmax(acc, colMax);
if (shouldSkipSoftmaxAttn)
{
xBar.consumed.arrive_and_wait();
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U;
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
localSkippedBlockCount++;
#endif
}
asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
unused(xBar.produced.arrive());
continue;
}
#else
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, colMax);
#endif
#else
RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc);
warpGrpOnlineSoftmax(acc, rowMax);
@@ -959,8 +1041,6 @@ CUBIN_EXPORT __global__
// map 1 to fp8_max before conversion to fp8
acc = acc * kE4M3_MAX;

uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf;
auto& xBar = smem.xBar[idxXBuf];
// @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM.
#if SWAP_AB
storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc);
@@ -989,13 +1069,25 @@ CUBIN_EXPORT __global__
storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax);
storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum);
#endif

#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 0;
}
#endif
__syncwarp();
// the release semantics of arrive does not work for async consumers like gmma. additional fence is
// needed.
asm volatile("fence.proxy.async.shared::cta;\n");
unused(xBar.produced.arrive());
}
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr)
{
atomicAdd(skippedBlockCount, localSkippedBlockCount);
atomicAdd(totalBlockCount, nbIters);
}
#endif
unused(smem.qBar.consumed.arrive());
}
else if (warpIdx.z == 1)
@@ -1043,216 +1135,231 @@ CUBIN_EXPORT __global__
uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq;
auto const idxVBuf = idxIter % SharedMem::nbVBuf;
auto const idxXBuf = idxVBuf;
auto& vBar = smem.vBar[idxVBuf];
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
auto const& vBuf = smem.vBuf(idxVBuf);
#if !SWAP_AB
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
auto& vtBuf = smem.vtBuf(idxVBuf);
vtBar.consumed.arrive_and_wait();
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
vBar.consumed.arrive();
vtBar.produced.arrive();
#endif
auto& xBar = smem.xBar[idxXBuf];
auto& vBar = smem.vBar[idxVBuf];
auto const& vBuf = smem.vBuf(idxVBuf);
xBar.produced.arrive_and_wait();
#if SKIP_SOFTMAX_ATTN
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf]; // guarded by xBar
if (shouldSkipSoftmaxAttn)
{
vBar.produced.arrive_and_wait();
}
#endif

#if SKIP_SOFTMAX_ATTN
if (!shouldSkipSoftmaxAttn) // skip XVGemm
#endif
{
arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds));
#if !SWAP_AB
CtaBarrierPair& vtBar = smem.vtBar[idxVBuf];
auto& vtBuf = smem.vtBuf(idxVBuf);
vtBar.consumed.arrive_and_wait();
transposeVTile(warpRank, laneId(), vtBuf, vBuf);
vBar.consumed.arrive();
vtBar.produced.arrive();
#endif
#if !defined(NDEBUG) && DBG_PRINT
#if SWAP_AB
if (threadIdx.x == 0)
{
printf("colMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColMax[idxXBuf][i]);
}
printf("\n");
printf("colSum:\n");
for (int n = 0; n < 4; n++)
{
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xColSum[idxXBuf][n][i]);
}
printf("\n");
}
printf("\n");
printf("X:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
for (int j = 0; j < gemm0CtaTileNbTokens; j++)
{
auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart);
auto const e = reinterpret_cast<Vec<__nv_fp8_e4m3, 16>&>(
smem.xBuf(idxXBuf)[j / elemsPerXPart].template at<true>(
i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain];
printf("%.2f, ", float(e));
if (j % 16 == 15)
{
printf("| ");
}
}
printf("\n\n");
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#else
if (blockIdx.y == 1 && threadIdx.x == 0)
{
printf("rowMax:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowMax[idxXBuf][i]);
}
printf("\n");
printf("rowSum:\n");
for (int i = 0; i < ctaNbQHeads; i++)
{
printf("%f, ", smem.xRowSum[idxXBuf][i]);
}
printf("\n");
}
smem.gemm1WarpGrpBar.arrive_and_wait();
#endif
#endif

#if SWAP_AB
// @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead.
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar);
#else
rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf],
smem.gemm1AccColMax, acc, smem.gemm1AccColSum);
#endif
auto& xBuf = smem.xBuf(idxXBuf);

auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::XBuffer::Elem{}))
.raw();
#if CACHE_ELEM_ENUM == 0
auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8,
gmma::getSwizzleMode<true>(SharedMem::VBuffer::Elem{}))
.raw();
#endif
#if SWAP_AB
//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed.
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++)
{
#if CACHE_ELEM_ENUM == 2
Vec<RegMatAFrag, gemm1NbGmmaInstM> const fragA
= loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK);
#if !defined(NDEBUG) && DBG_PRINT
if (threadIdx.x == 0)
{
printf("fragA:\nidxInstK == %u\n", idxInstK);
}
smem.gemm1WarpGrpBar.arrive_and_wait();
for (int m = 0; m < 2; m++)
{
for (int w = 0; w < 4; w++)
{
if (warpRank == w)
{
if (laneId() == 0)
{
printf(" warpRank = %u\n", warpRank);
}
__syncwarp();
for (int a = 0; a < 2; a++)
{
for (int b = 0; b < 8; b++)
{
for (int c = 0; c < 2; c++)
{
for (int d = 0; d < 4; d++)
{
if (laneId() == b * 4 + d)
{
for (int e = 0; e < 4; e++)
{
auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(
fragA[m](0, c)(a, 0));
printf("%.2f, ", float(elem4[e]));
}
}
__syncwarp();
}
if (laneId() == 0)
{
printf("\n");
}
__syncwarp();
}
if (laneId() == 0 && a == 0)
{
printf("----------------------\n");
}
__syncwarp();
}
}
smem.gemm1WarpGrpBar.arrive_and_wait();
}
}
#endif
#endif
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * idxInstK};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
0, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
#if CACHE_ELEM_ENUM == 2
gmma::fence();
#endif
#pragma unroll
for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++)
{
#if CACHE_ELEM_ENUM == 0
auto const descV
= addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0));
gmma::mma_async_shmA<MathElem, ctaNbQHeads, true, false>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
descV, descX, true);
#elif CACHE_ELEM_ENUM == 2
gmma::mma_async_regA<MathElem, ctaNbQHeads>(
reinterpret_cast<float(&)[exactDiv(ctaNbQHeads, gmma::instNBase)][2][2]>(acc(idxInstM, 0)),
reinterpret_cast<uint32_t const(&)[2][2][1]>(fragA[idxInstM]), descX, true);
#endif
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of
// gmma.
gmma::wait_group<0>();
}
#else
auto const descVTBase = gmma::makeMatDesc(
nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode<true>(SharedMem::VTBuffer{}))
.raw();
vtBar.produced.arrive_and_wait();
// if (idxIter == 1 && threadIdx.x == 0) {
//     printf("vtBuf:\n");
//     dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf);
// }
#pragma unroll
for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
{
#pragma unroll
for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++)
{
BoundedVal<grainsPerInstK * gemm1NbGmmaInstK> const kOffsetInGrains{grainsPerInstK * k};
auto const descX = addAddr(descXBase,
&xBuf[kOffsetInGrains.template divBy<SharedMem::XBuffer::Elem::cols>().get()](
gmma::instM * m, kOffsetInGrains.template mod<SharedMem::XBuffer::Elem::cols>().get()));
auto const descVT = addAddr(
descVTBase, &vtBuf(0, kOffsetInGrains.template mod<SharedMem::VTBuffer::cols>().get()));
gmma::mma_async_shmA<MathElem, headElems>(
reinterpret_cast<float(&)[exactDiv(headElems, gmma::instNBase)][2][2]>(acc(m, 0)), descX,
descVT, true);
}
}
gmma::commit_group();
//@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma.
gmma::wait_group<0>();
#endif
}

if (idxIter == nbIters - 1)
{
// gmma::wait_group should have already synchronized threads, so this may be unnecessary.
@@ -1471,8 +1578,24 @@ CUBIN_EXPORT __global__
tensorMap
#endif
};
#if SKIP_SOFTMAX_ATTN
for (auto& b : smem.skipSoftmaxXBar)
{
unused(b.consumed.arrive());
}
#endif
for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++)
{
uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
uint32_t idxXBuf = idxIter % SharedMem::nbXBuf;
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
skipSoftmaxXBar.produced.arrive_and_wait();
bool shouldSkipSoftmaxAttn = smem.skipSoftmaxVotesGemm0ToV[idxXBuf];
skipSoftmaxXBar.consumed.arrive();
#endif

uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq;
vTileLoader.loadPages(idxVTile);
#if USE_INPUT_KV || ENABLE_PDL == 2
@@ -1506,8 +1629,20 @@ CUBIN_EXPORT __global__
}
#endif

uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf;
auto& vBar = smem.vBar[idxVBuf];
#if SKIP_SOFTMAX_ATTN
if (shouldSkipSoftmaxAttn)
{
vBar.consumed.arrive_and_wait();
// Compared to non-skip softmax attn, we need to increase the vBar.produced count to avoid a race
// condition where vBar.consumed is arrived again without a wait. Without skip softmax attn, XVGemm
// will wait for tx_count, so its progress won't go ahead of the vload warp. With skip softmax attn,
// the XVGemm WG may go ahead of the vload warp, as the previous vBar only has XVGemm WG threads and a
// tx_count (now = 0), so it may arrive vBar.consumed before it is arrive_and_wait-ed.
vBar.produced.arrive();
continue;
}
#endif

vBar.consumed.arrive_and_wait();
if (warpElectSync())
{
@@ -1517,6 +1652,9 @@ CUBIN_EXPORT __global__
vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced);
}
}
#if SKIP_SOFTMAX_ATTN
vBar.produced.arrive();
#endif
__syncwarp();
}
}
@@ -1992,9 +2130,23 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
#endif // SPEC_DEC

// smemColMax is persistent across multiple iterations
#if SKIP_SOFTMAX_ATTN
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax,
Gemm0Acc const& src, float skipSoftmaxThreshold, uint32_t* smemSkipVote, bool maybeSkip)
#else
__device__ inline RegColWiseVec computeWarpGrpColMax_sync(
CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src)
#endif
{
#if SKIP_SOFTMAX_ATTN
if (threadIdx.x == 0)
{
*smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote
}
float const lnThreshold
= log(skipSoftmaxThreshold); // this can be -inf, but should be safe as we only use it for comparison
#endif

auto colMax = RegColWiseVec::filled(Vec<float, 2>::filled(safeInitRowMax));
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@@ -2029,6 +2181,9 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
}

uint32_t const lane = laneId();
#if SKIP_SOFTMAX_ATTN
auto prevOrCurrentMax = RegColWiseVec();
#if SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
if (lane < 4)
{
#pragma unroll
@@ -2037,12 +2192,43 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
prevOrCurrentMax[n][j] = smemColMax[8 * n + 2 * lane + j];
}
}
}
warpGrpBar.arrive_and_wait();
#endif
#endif

if (lane < 4)
{
#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
#if SKIP_SOFTMAX_ATTN && !SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
// prevOrCurrentMax is <= the actual smemColMax (after updates from all 4 warps are done), but always >=
// smemColMax(Prev), the smemColMax value *before* this tile is computed.
// When determining whether to skip, it is safe to use prevOrCurrentMax: 1) if all 4 warps' local max <
// smemColMax(Prev), then prevOrCurrentMax == smemColMax(Prev) and the result is not affected; 2) if some
// local max > smemColMax(Prev), then prevOrCurrentMax > smemColMax(Prev) and some warps may incorrectly
// vote to skip, but at least one warp whose localColMax is larger will not skip, so the tile is not skipped.
// This saves some syncs and checks, but has an issue when the threshold > 1.
prevOrCurrentMax[n][j] = atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#else
atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]);
#endif
}
}
}
warpGrpBar.arrive_and_wait();

uint32_t const idxInQuad = lane % 4;
#if SKIP_SOFTMAX_ATTN
bool localShouldSkip = true;
#endif

#pragma unroll
for (uint32_t n = 0; n < src.cols; n++)
@@ -2050,10 +2236,21 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
{
#if SKIP_SOFTMAX_ATTN
if (lane < 4 && 8 * n + 2 * idxInQuad + j < headGrpSize)
{
localShouldSkip &= (colMax[n][j] - prevOrCurrentMax[n][j]) < lnThreshold;
}
#endif
assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]);
colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j];
}
}

#if SKIP_SOFTMAX_ATTN
atomicAnd(smemSkipVote, static_cast<uint32_t>(localShouldSkip)); // this will be translated to redux and voteu
#endif

warpGrpBar.arrive_and_wait();
return colMax;
}
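The atomicAnd above is a unanimous vote across the warp group. A host-side analogue with std::atomic, as a sketch under assumed names (the kernel's version additionally relies on the surrounding barrier for ordering):

#include <atomic>
#include <cstdint>

// The vote word starts at 1 (maybeSkip); each participant ANDs in its verdict,
// so the tile is skipped only if every warp agreed.
inline void castSkipVote(std::atomic<uint32_t>& skipVote, bool localShouldSkip)
{
    skipVote.fetch_and(localShouldSkip ? 1U : 0U, std::memory_order_relaxed);
}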
@@ -2199,7 +2396,7 @@ __device__ inline void storeGemm0AccToShm(
uint32_t const idxOctInsideHalf = idxInHalf / 8;
uint32_t const idxRowInsideOct = lane % 8;
uint32_t const warpBaseC = 16 * warpRank;
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair<uint32_t, uint32_t>
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair<uint32_t, uint32_t>
{
uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols;
uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols;
@@ -3231,6 +3428,24 @@ __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst,
}

#ifndef GENERATE_CUBIN
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 0.25f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, gemm0CtaTileNbTokens));
}
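A worked instance of the heuristic above, using assumed H100-like numbers (note the integer division happens before the float multiply, exactly as in the expression above):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

int main()
{
    uint32_t const smCount = 132, batchSize = 4, nbKHeads = 8, maxSeqLen = 8192, tileTokens = 64; // assumed
    float const factor = 0.25f;
    // 132 * 3 / 32 = 12 (integer), then 12 * 0.25f = 3.0f
    uint32_t const bySmBudget = std::max(1U, (uint32_t) std::round(smCount * 3 / (batchSize * nbKHeads) * factor));
    uint32_t const byTileCount = (maxSeqLen + tileTokens - 1) / tileTokens; // divUp -> 128
    assert(std::min(bySmBudget, byTileCount) == 3); // 3 CTAs cooperate per sequence here
    return 0;
}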
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,

@@ -3268,6 +3483,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{

@@ -3286,22 +3507,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbVHeads = nbKHeads;
uint32_t const nbQHeads = nbKHeads * headGrpSize;
uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 0.25f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, gemm0CtaTileNbTokens));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqHopperF8MHA(prop, batchSize, nbKHeads, maxSeqLen);
#if SPEC_DEC
uint32_t const qSeqLen = specDecParams.qSeqLen;
#else

@@ -3371,6 +3577,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#endif
#if SPEC_DEC
specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
skippedBlockCount, totalBlockCount,
#endif
#endif
semaphores, scratch);
#else
@@ -1272,6 +1272,19 @@ using is_void = is_same<remove_cv_t<T>, void>;
template <typename T>
inline constexpr bool is_void_v = is_void<T>::value;
#endif

#ifndef GENERATE_CUBIN
template <typename T1, typename T2>
using pair = std::pair<T1, T2>;
#else
template <typename T1, typename T2>
struct pair
{
T1 first;
T2 second;
};
#endif

} // namespace mha

#if GENERATE_CUBIN
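The shim above presumably exists because the std headers are unavailable in the GENERATE_CUBIN (NVRTC) build; both branches expose the same first/second shape. A usage sketch with assumed variable names:

mha::pair<uint32_t, uint32_t> const coords{accR, accC}; // works with or without GENERATE_CUBIN
uint32_t const row = coords.first;
uint32_t const col = coords.second;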
@@ -50,7 +50,8 @@ using Vector = Matrix<Type, Size, 1>;
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
{
uint32_t const nbTiles = divUp(seqLen, tileSize);
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();

@@ -61,6 +62,16 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
float const qkScale = qScale * kvScale / sqrtf(validElemsPerHead);
uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize);
uint32_t const idxTileBeg = seqBeg / tileSize;

uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
for (uint32_t i = 0; i < nbSubSeq; i++)
{
skipRowMaxs[i].fill(-INFINITY);
}
bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor);
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen;

for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
{
Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> gemm0Acc;
@@ -88,7 +99,22 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
}
}

Eigen::Vector<float, headGrpSize> const tileRowMax = gemm0Acc.rowwise().maxCoeff().cwiseMax(rowMax).eval();
Eigen::Vector<float, headGrpSize> const localRowMax = gemm0Acc.rowwise().maxCoeff().eval();
Eigen::Vector<float, headGrpSize> const tileRowMax = localRowMax.cwiseMax(rowMax).eval();
auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();

if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0)
{
*totalBlockCount += 1;
auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));
bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
if (skipBlock)
{
*skippedBlockCount += 1;
continue;
}
}

Eigen::Matrix<float, headGrpSize, tileSize, Eigen::RowMajor> tileX
= (gemm0Acc.colwise() - tileRowMax).array().exp().eval();
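A single-row scalar illustration of the skip test above, with assumed values:

#include <cmath>

float const prevSkipMax = 3.0f;         // running row max for this sub-sequence
float const localMax = 1.2f;            // this tile's row max
float const threshold = 500.f / 2048.f; // scaleFactor / seqLen ~= 0.244
bool const skip = (localMax - prevSkipMax) < std::log(threshold); // -1.8 < -1.41 -> skipped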
@@ -138,7 +164,8 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
template Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> \
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks)
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, \
float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)

INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);

@@ -88,7 +88,8 @@ struct CacheSeq<true, true>
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks);
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);

template <typename MathElem, bool isPaged, bool useBeamSearch>
#if SPEC_DEC

@@ -150,7 +150,8 @@ template <uint32_t nbKHeads>
#endif
#endif
void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false,
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30)
bool saveData = false, bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30,
float skipSoftmaxThresholdScaleFactor = 0.0f)
{
#if IS_MLA
if (nbKHeads != 1)

@@ -224,6 +225,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head.
}
ctxLen = std::min(ctxLen, seqLen);
uint32_t skippedBlockCount = 0;
uint32_t totalBlockCount = 0;
if (skipSoftmaxThresholdScaleFactor > 0)
{
assert(useQGMMA);
}
float const kScale = cacheElemSize == 2 ? 1.f : 1 / 4.f;
float const vScale = kScale;
float const qScale = 1.f;
@@ -329,6 +336,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
auto const rcpOutScale = ManagedMemBuf<float>(1);
auto const seqLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
auto const ctxLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
#if SKIP_SOFTMAX_ATTN
#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t>(1);
auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t>(1);
kernelSkippedBlockCount[0] = 0;
kernelTotalBlockCount[0] = 0;
#endif
#else
EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f)
<< "Got non-zero skipSoftmaxThresholdScaleFactor while SKIP_SOFTMAX_ATTN is not enabled.";
#endif
#if USE_PAGED_KV_CACHE
auto const pageListBuf = ManagedMemBuf<std::byte>(pageListBytes);
#if PAGED_KV_CACHE_LAYOUT == 1
@@ -726,6 +744,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
maxSeqLen, &seqLenList[0][0], batchSize, kvCacheScale.get(), semaphores.get(), scratch, stream);
};
#else
auto multiBlockNum = [&]()
{
auto const calcFunc = useQGMMA ? &computeNbSubSeqPerSeqHopperF8MHA : &computeNbSubSeqPerSeqMHA;
return calcFunc(prop, batchSize, nbKHeads, maxSeqLen);
}();
auto runKernel = [&]()
{
auto const launchFunc = useQGMMA ? &launchHopperF8MHA : &launchMHA;

@@ -776,6 +799,12 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
batchSize, kvCacheScale.get(),
#if SPEC_DEC
specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(),
#endif
#endif
semaphores.get(), scratch, stream);
checkCuda(cudaGetLastError());
@@ -813,6 +842,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
checkCuda(cudaEventRecord(toc, stream));
prefetchToDevice(cudaCpuDeviceId);
checkCuda(cudaStreamSynchronize(stream));
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
kernelSkippedBlockCount[0] /= nbIters;
kernelTotalBlockCount[0] /= nbIters;
#endif
if (testPerf)
{
float ms;
@@ -849,6 +882,15 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
= totalNbCacheLoadBytes + inputBytes + outputBytes; // we ignore page indices and beam search indices.
float const dramSolTime = totalTraffic / bandwidth * 1E3f;
float const dramSolRatio = dramSolTime / ms;
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
* (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
* nbLoadedCacheTokens;
float const totalTrafficWithSkip
= totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
float const dramSolTimeWithSkip = totalTrafficWithSkip / bandwidth * 1E3f;
float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms;
#endif
if (verbose)
{
printf("done\n");
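A worked instance of the WithSkip formula, under assumed numbers (only V-cache traffic shrinks, because K is always loaded to compute the vote):

constexpr float kSkipRatio = 0.5f; // assumed kernelSkippedBlockCount / kernelTotalBlockCount
constexpr float kHeadsOfTraffic = 1.f + 1.f * (1.f - kSkipRatio); // nbKHeads + nbVHeads * (1 - ratio), one head each
static_assert(kHeadsOfTraffic == 1.5f, "half the V traffic is saved");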
@@ -863,7 +905,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
}
float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
* nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops);
#else
printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops);
#endif
}
if (refCheck)
{

@@ -1084,8 +1132,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
if (useQGMMA)
{
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize,
refAttentionSinks);
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks,
skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
//     vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
}

@@ -1132,6 +1180,14 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
#endif
}
}
#if SKIP_SOFTMAX_ATTN
printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount,
totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount);
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
#endif
#endif
if (saveData)
{
fout_refOutput.close();
@@ -1253,6 +1309,14 @@ TEST(RefCheck, llama_V2_70b)
#if SLIDING_WINDOW
runTest<2>(2, 4096, false, true, false, false, false, ~0, 256);
runTest<2>(2, 400, false, true, false, false, false, ~0U, 256);
#endif
#if SKIP_SOFTMAX_ATTN
runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f);
runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f);
runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f);
runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f);
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f);
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f);
#endif
runTest<8>(120, 367, false, true);
runTest<8>(1792, 2048, false, true);
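The last two cases appear to straddle the short-sequence cutoff deliberately; a sketch of the comparison they exercise (seqLen = 500 in both):

static_assert(500 < 501.0f, "scaleFactor 501 disables skipping for seqLen 500");
static_assert(!(500 < 500.f), "scaleFactor 500 keeps skipping enabled for seqLen 500");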
@@ -298,6 +298,11 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
// Skip softmax threshold.
xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
#ifdef SKIP_SOFTMAX_STAT
// Statistics of skip-softmax, pointers of device memory for output
xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks;
xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks;
#endif
// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;

@@ -105,7 +105,8 @@ CubinObj CompileEngine::compile() const
// scratch in this case.
/*use_input_kv=*/applyRoPEInXqaKernel,
/*rope_style=*/ropeStyle,
/*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree};
/*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree,
/*use_skip_softmax_attn=*/mXqaParams.skip_softmax_threshold_scale_factor != 0};
if (context.kernel_type == TLLM_XQA_JIT_MLA)
{
auto const& c = context;
@@ -232,6 +232,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key);
TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized());
bool const isSpecDec = xqaParams.multi_query_tokens;
bool const isSkipSoftmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED);
bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA);

@@ -378,7 +379,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
.mask = reinterpret_cast<SpecDecParams::MaskType const*>(xqaParams.spec_decoding_packed_mask)};
};

constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 16;
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 19;
uint32_t idxNextParam = 0;
void* kernelParams[kMAX_NB_KERNEL_PARAMS];
auto appendParam = [&](auto* p) mutable

@@ -514,6 +515,16 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
appendParam(&specDecParams);
specDecBlocks = divUp(specDecParams.qSeqLen, 64 / num_q_heads_over_kv);
}
if (isSkipSoftmax)
{
TLLM_CHECK_WITH_INFO(isGMMAKernel, "skip softmax is only supported for GMMA kernel for now.");
TLLM_CHECK_WITH_INFO(!isSpecDec, "skip softmax is not supported with spec dec for now.");
appendParam(&xqaParams.skip_softmax_threshold_scale_factor);
#ifdef SKIP_SOFTMAX_STAT
appendParam(&xqaParams.skip_softmax_total_blocks);
appendParam(&xqaParams.skip_softmax_skipped_blocks);
#endif
}
appendParam(&launchParams.semaphores);
appendParam(&launchParams.scratch);
kernelParams[idxNextParam] = nullptr; // one extra nullptr at end as guard.
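A sketch of where the new bound likely comes from (an inference from the parameters appended above, not stated in the commit): the skip-softmax path adds up to three extra kernel arguments.

static_assert(16 + 3 == 19, "threshold scale factor + two SKIP_SOFTMAX_STAT counters");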
@ -96,10 +96,16 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu
{
return false;
}
if (xqaParams.kv_cache_data_type != DATA_TYPE_E4M3)
if (!contains({DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_E4M3}, xqaParams.kv_cache_data_type))
{
return false;
}
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
if (!is_skip_softmax && xqaParams.kv_cache_data_type != DATA_TYPE_E4M3)
{
// The Hopper kernel accepts fp16/bf16 KV cache data types only when skip softmax is enabled.
return false;
}
if (xqaParams.beam_width != 1)
{
return false;

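Restated as a standalone predicate for clarity (a sketch assuming the same semantics as the two checks above; KvDtype is a stand-in enum for the real KV-cache data type): fp8 (E4M3) KV cache is always eligible for the Hopper QGMMA kernel, while fp16/bf16 KV cache is eligible only when skip-softmax is requested.

// Sketch: the KV-cache dtype gate above, collapsed into one predicate.
enum class KvDtype { kFP16, kBF16, kE4M3, kOther };

bool qgmmaKvDtypeOk(KvDtype kvDtype, float skipSoftmaxThresholdScaleFactor)
{
    bool const isSkipSoftmax = skipSoftmaxThresholdScaleFactor != 0;
    if (kvDtype == KvDtype::kE4M3)
    {
        return true; // fp8 KV cache: eligible with or without skip-softmax
    }
    // fp16/bf16 KV cache rides the Hopper QGMMA kernel only when skip-softmax is on.
    return isSkipSoftmax && (kvDtype == KvDtype::kFP16 || kvDtype == KvDtype::kBF16);
}
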
@ -168,6 +174,11 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug
|
||||
{
|
||||
return false;
|
||||
}
|
||||
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
|
||||
if (is_skip_softmax)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -201,6 +212,11 @@ bool supportConfigMLA(XQAParams const& xqaParams, int SM, bool forConfigurePlugi
{
return false;
}
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
if (is_skip_softmax)
{
return false;
}
return true;
}

@ -66,6 +66,7 @@ extern "C"

bool is_spec_dec_tree
= true; // only used when multi_query_tokens is set; should be true unless using a linear tree in spec-dec.
bool use_skip_softmax_attn;
} tllmXqaJitContext;

// tllmXqaJitProgram is an opaque handle for a program.

@ -215,6 +215,10 @@ tllmXqaJitStatus getMacroFlags(tllmXqaJitContext const* context, std::vector<std
macros["USE_INPUT_KV"] = context->use_input_kv ? "1" : "0";
macros["ROPE_STYLE"] = std::to_string(int(context->rope_style));
macros["IS_SPEC_DEC_TREE"] = context->is_spec_dec_tree ? "1" : "0";
macros["SKIP_SOFTMAX_ATTN"] = context->use_skip_softmax_attn ? "1" : "0";
#ifdef SKIP_SOFTMAX_STAT
macros["SKIP_SOFTMAX_ATTN_BLOCK_STATS"] = context->use_skip_softmax_attn ? "1" : "0";
#endif

// Without these macros, NVRTC uses precompiled headers for cuda_fp16.h etc.
// Linking might fail due to ABI incompatibility.

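Downstream, a macros map like the one populated above is typically flattened into -D compile options for NVRTC. A generic sketch; the container type and the real plumbing in this codebase are assumptions:

#include <map>
#include <string>
#include <vector>

// Sketch: turn the macro map into NVRTC-style "-DNAME=VALUE" compile options.
std::vector<std::string> toCompileOptions(std::map<std::string, std::string> const& macros)
{
    std::vector<std::string> opts;
    opts.reserve(macros.size());
    for (auto const& [name, value] : macros)
    {
        opts.push_back("-D" + name + "=" + value);
    }
    return opts;
}
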
@ -493,6 +493,10 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forCo
{
SUPPORT_RETURN_FALSE("streaming-llm");
}
if (xqaParams.skip_softmax_threshold_scale_factor != 0)
{
SUPPORT_RETURN_FALSE("skip_softmax_threshold_scale_factor");
}

// OPTIMIZE: For the standard generation-phase MHA, there are still extra limitations.
// NOTE: Medusa mode = Multi_query_tokens > 1.

@ -64,6 +64,21 @@ CUtensorMapSwizzle getSwizzleMode(uint32_t partBytes)
}
};

CUtensorMapDataType_enum getDataTypeFromXqaParams(XQAParams const& xqaParams)
{
if (xqaParams.kv_cache_data_type == DATA_TYPE_BF16)
{
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
}
else if (xqaParams.kv_cache_data_type == DATA_TYPE_FP16)
{
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
TLLM_CHECK(xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 || xqaParams.kv_cache_data_type == DATA_TYPE_E5M2
|| xqaParams.kv_cache_data_type == DATA_TYPE_INT8);
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
}

CUtensorMap makeTensorMapForQ(std::shared_ptr<CUDADriverWrapper> const& driver, void const* addr,
CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems, uint32_t boxHeads)
{

@ -131,24 +146,26 @@ CUtensorMap makeTensorMapForHopperXqaKVCache(
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
uint32_t const headElems = xqaParams.head_size;
uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8);
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
uint32_t const elemBytes = getElemBytes(dataType);
TLLM_CHECK(headElems <= 256);
uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8,
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size,
xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
}
else
{
static_assert(std::is_same_v<KVCacheBuffer, KVLinearBuffer>);
uint32_t const headElems = xqaParams.head_size;
uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8);
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
uint32_t const elemBytes = getElemBytes(dataType);
TLLM_CHECK(headElems <= 256);
uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, CU_TENSOR_MAP_DATA_TYPE_UINT8,
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width,
xqaParams.batch_size, partElems);
return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, dataType, xqaParams.head_size,
xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width, xqaParams.batch_size,
partElems);
}
}

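As a worked example of the partElems arithmetic above: with an fp16 KV cache (elemBytes = 2) and headElems = 128, paddedHeadElems = 128 and partElems = min(2 * 128, 128) / 2 = 64, i.e. each 128-byte TMA part covers 64 fp16 elements; with an fp8 KV cache (elemBytes = 1) the same head size yields partElems = 128. The same computation as a standalone helper (a sketch restating the code above, not part of the commit):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Sketch: the per-part element count used by the Hopper XQA tensor maps.
uint32_t computePartElems(uint32_t headElems, uint32_t elemBytes)
{
    assert(headElems <= 256);
    uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
    return std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
}
// e.g. computePartElems(128, 2) == 64 for fp16, computePartElems(128, 1) == 128 for fp8.
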
@ -161,11 +178,12 @@ template <typename KVCacheBuffer>
CUtensorMap makeTensorMapForXqaMlaKVCache(std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver,
XQAParams const& xqaParams, KVCacheBuffer const& kv_cache_buffer, bool forK)
{
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
uint32_t const partElems = (forK ? 64 : 128);
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8,
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size,
xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
}
else
{

@ -183,7 +201,7 @@ CUtensorMap makeTensorMapForXqaMlaQ(
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver, XQAParams const& xqaParams, void const* q)
{
uint32_t const partElems = 64;
return makeTensorMapForQ(driver, q, CU_TENSOR_MAP_DATA_TYPE_UINT8, xqaParams.head_size,
return makeTensorMapForQ(driver, q, getDataTypeFromXqaParams(xqaParams), xqaParams.head_size,
xqaParams.num_q_heads * xqaParams.total_num_input_tokens, partElems, xqaParams.num_q_heads);
}
} // namespace kernels

@ -119,7 +119,12 @@ struct XQAParams
bool use_sparse_attention = false;

// Skip softmax threshold.
float skip_softmax_threshold_scale_factor = 0.0f;
float skip_softmax_threshold_scale_factor = 0;

#ifdef SKIP_SOFTMAX_STAT
uint32_t* skip_softmax_total_blocks = nullptr;
uint32_t* skip_softmax_skipped_blocks = nullptr;
#endif

cudaStream_t stream = 0;
// layer index

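A zero scale factor is the off switch throughout this change: every isSkipSoftmax / is_skip_softmax check above tests skip_softmax_threshold_scale_factor != 0. A minimal sketch of enabling the feature; the helper name and the concrete threshold value are illustrative assumptions, not recommendations:

// Sketch: enabling skip-softmax attention on the decode path.
void configureSkipSoftmax(XQAParams& xqaParams, uint32_t* dTotalBlocks, uint32_t* dSkippedBlocks)
{
    // Any nonzero value enables the path; 0 (the default) keeps it disabled.
    xqaParams.skip_softmax_threshold_scale_factor = 1.0f; // illustrative value only
#ifdef SKIP_SOFTMAX_STAT
    // Optional block-statistics output: single uint32_t counters in device memory.
    xqaParams.skip_softmax_total_blocks = dTotalBlocks;
    xqaParams.skip_softmax_skipped_blocks = dSkippedBlocks;
#endif
}
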
@ -199,6 +204,10 @@ struct XQAParams
<< "sparse_params: " << sparse_params.toString() << std::endl
<< "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
<< "skip_softmax_threshold_scale_factor :" << skip_softmax_threshold_scale_factor << std ::endl
#ifdef SKIP_SOFTMAX_STAT
<< "skip_softmax_total_blocks :" << skip_softmax_total_blocks << std ::endl
<< "skip_softmax_skipped_blocks :" << skip_softmax_skipped_blocks << std ::endl
#endif
<< "stream :" << stream;

return ss.str();