/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include "barriers.cuh"
#include "cuda_hint.cuh"
#include "defines.h"
#include "utils.cuh"
#include "utils.h"

#if SPEC_DEC
#define Q_HEADS_PER_CTA 64
#include "specDec.h"
#endif

#define SWAP_AB (!SPEC_DEC)
#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)

inline constexpr bool swapAB = SWAP_AB;

#ifndef GENERATE_CUBIN
#include "hostUtils.h"
#include
#endif

#include "gmma.cuh"
#include "mha.h"
#include "mhaUtils.cuh"
#include "mha_stdheaders.cuh"
#include "tma.h"

#define DBG_PRINT 0

#pragma region Config

static_assert(
    (inputElemSize == cacheElemSize && mha::is_same_v) || inputElemSize > cacheElemSize);
using MathElem
    = mha::conditional_t<(inputElemSize > cacheElemSize && mha::is_same_v), InputElem, CacheElem>;

constexpr uint32_t gmmaWarpsPerGrp = 4;
constexpr uint32_t gmmaWarpGrpSize = warp_size * gmmaWarpsPerGrp;
constexpr uint32_t gemm0NbGmmaGrps = 1;
constexpr uint32_t gemm0NbThrds = gmmaWarpGrpSize * gemm0NbGmmaGrps;
constexpr uint32_t gemm0NbWarps = gmmaWarpsPerGrp * gemm0NbGmmaGrps;

#if SPEC_DEC
inline constexpr uint32_t ctaNbQHeads = Q_HEADS_PER_CTA;
inline constexpr uint32_t inputTokensPerCta = exactDiv(ctaNbQHeads, headGrpSize);
constexpr uint32_t ctaNbValidQHeads = ctaNbQHeads;
#else
inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * beamWidth;
inline constexpr uint32_t ctaNbQHeads = roundUp(ctaNbValidQHeads, swapAB ?
8U : 64U);
inline constexpr uint32_t inputTokensPerCta = 1;
#endif

constexpr uint32_t gemm0WarpGrpTileNbTokens = 64;
inline constexpr uint32_t gemm0CtaTileNbTokens = gemm0WarpGrpTileNbTokens * gemm0NbGmmaGrps;
constexpr uint32_t gemm1NbGmmaGrps = 1;
constexpr uint32_t gemm1NbThrds = gmmaWarpGrpSize * gemm1NbGmmaGrps;
constexpr uint32_t gemm1NbWarps = gmmaWarpsPerGrp * gemm1NbGmmaGrps;
constexpr uint32_t gemm1CtaTileNbTokens = gemm0CtaTileNbTokens;
constexpr uint32_t mathHeadBytes = sizeof(Vec);
constexpr uint32_t nbIOWarps = 4;
constexpr uint32_t nbIOThrds = warp_size * nbIOWarps;
constexpr uint32_t multiBlockMinNbTilesPerCta = 1; // 3; // @fixme: need tuning
constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2;
constexpr uint32_t nbWarps = gemm0NbWarps + gemm1NbWarps + nbIOWarps;

constexpr uint32_t cacheHeadPartBytes = mha::min(paddedCacheHeadBytes, 128U);
constexpr uint32_t cacheHeadNbParts = exactDiv(paddedCacheHeadBytes, cacheHeadPartBytes); // @fixme: support divUp in the future
constexpr uint32_t cacheHeadPartElems = exactDiv(headElems, cacheHeadNbParts);
constexpr uint32_t swizzleBytes = cacheHeadPartBytes;
static_assert(swizzleBytes == 128 || swizzleBytes == 64 || swizzleBytes == 32);

constexpr bool needInputCvt = inputElemSize > cacheElemSize && mha::is_same_v;
constexpr bool needCacheCvt = inputElemSize > cacheElemSize && mha::is_same_v;
static_assert(needInputCvt || needCacheCvt || mha::is_same_v);

using ShmQWiseVec = Vec;

constexpr uint32_t qPartBytes = mha::min(mathHeadBytes, 128U);
constexpr uint32_t nbQParts = exactDiv(mathHeadBytes, qPartBytes);
constexpr uint32_t grainsPerQPart = exactDiv(qPartBytes, grainBytes);
constexpr uint32_t xPartBytes = mha::min(cacheElemSize * gemm0CtaTileNbTokens, 128U);
constexpr uint32_t nbXParts = exactDiv(cacheElemSize * gemm0CtaTileNbTokens, xPartBytes);
constexpr uint32_t grainsPerXPart = exactDiv(xPartBytes, grainBytes);
constexpr uint32_t cacheElemsPerGrain = exactDiv(grainBytes, cacheElemSize);
constexpr uint32_t grainsPerIOHead = exactDiv(ioHeadBytes, grainBytes);
constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes);

#if USE_BEAM_SEARCH
constexpr uint32_t beamSearchGemm0CtaTileNbTokens = exactDiv(gemm0CtaTileNbTokens, beamWidth);
#endif

using PaddedOutHead = PaddedInputHead;

#pragma endregion Config

struct alignas(128) SharedMem
{
    using KBuffer = Array2D;
    static constexpr uint32_t nbKBuf = 2;
    KBuffer k[nbKBuf]; // as is loaded from global mem.
    using XBuffer = Vec, nbXParts>;
    static constexpr uint32_t nbXBuf
        = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ?
1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); using VBuffer = Vec, cacheHeadNbParts>; #if !SWAP_AB using VTBuffer = Array2D; #endif static constexpr uint32_t nbVBuf = 2; #if CACHE_ELEM_ENUM == 0 using OutSwizzleBuf = Array2D; #elif CACHE_ELEM_ENUM == 2 using OutSwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; #endif static_assert(nbXBuf == nbVBuf); union ReusedXVOutSwizzleBuf { struct XV { XBuffer x; VBuffer v; #if !SWAP_AB VTBuffer vt; #endif // @fixme: also put xColMax and xColSum here } xv; OutSwizzleBuf outSwizzle; } reusedXVOutSwizzleBuf[nbXBuf]; static_assert(sizeof(OutSwizzleBuf) <= sizeof(SharedMem::ReusedXVOutSwizzleBuf::XV), "need to use split output to avoid excessive shared memory usage"); __device__ inline XBuffer& xBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.x; } __device__ inline VBuffer& vBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.v; } #if !SWAP_AB __device__ inline VTBuffer& vtBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.vt; } #endif __device__ inline OutSwizzleBuf& outSwizzleBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].outSwizzle; } using QBuffer = Vec, nbQParts>; QBuffer q; // For gmma math. Conversion done if needed. // @fixme: move these into reusedXVOutSwizzleBuf #if SWAP_AB ShmQWiseVec xColMax[nbXBuf]; ShmQWiseVec xColSum[nbXBuf][gemm0NbWarps]; #else ShmQWiseVec xRowMax[nbXBuf]; ShmQWiseVec xRowSum[nbXBuf]; #endif ShmQWiseVec gemm0CurrentSeqMax; // col sum and max for the current gemm1 acc. Use shared memory to save some registers. register storage will be 8x // duplicate for swapAB and 4x duplicate for non-swapAB. ShmQWiseVec gemm1AccColMax; ShmQWiseVec gemm1AccColSum; #if USE_PAGED_KV_CACHE static constexpr uint32_t nbPagesPerTile = gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1; Vec pages[2]; // one for K and one for V #endif // mem barriers CtaBarrierPair qBar; CtaBarrierPair kBar[nbKBuf]; CtaBarrierPair vBar[nbVBuf]; #if !SWAP_AB CtaBarrierPair vtBar[nbVBuf]; #endif CtaBarrierPair xBar[nbXBuf]; // used internally in the gemm0 warp group // @fixme: use separate arrive and wait for all usage CtaBarrier gemm0WarpGrpBar; // used internally in the gemm1 warp group // @fixme: use separate arrive and wait for all usage CtaBarrier gemm1WarpGrpBar; bool isLastCta; }; CUBIN_EXPORT __device__ constexpr uint32_t smemSize = sizeof(SharedMem); #ifdef __CUDA_ARCH__ static_assert(smemSize < kMAX_SMEM_SIZE); #endif constexpr uint32_t nbQLdWarps = needInputCvt ? 
nbIOWarps - 2 : 1; constexpr uint32_t nbQLdThrds = warp_size * nbQLdWarps; #if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 template struct F16QToF8Converter { static_assert(inputElemSize == 2); using F16Vec = Vec; #if CACHE_ELEM_ENUM == 0 using ShmVec = F16Vec; #elif CACHE_ELEM_ENUM == 2 using F8Vec = Vec; using ShmVec = F8Vec; #endif static constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); static constexpr uint32_t grainsPerPaddedInputQHeadGrp = grainsPerPaddedInputHead * headGrpSize; #if !(SPEC_DEC) static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * beamWidth; #else static_assert(beamWidth == 1); static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * inputTokensPerCta; #endif static constexpr uint32_t nbIters = divUp(totalGrains, nbThrds); using RegData = Vec; static __device__ RegData load(uint32_t tid, TinyPtr const& src, uint32_t const nbKHeads /*for beam search and spec dec*/, uint32_t nbTokens); static __device__ void store(uint32_t tid, SharedMem::QBuffer& dst, RegData const& data); }; #endif // CACHE_ELEM_ENUM struct KVTilePartLoader { static constexpr uint32_t nbParts = cacheHeadNbParts; static constexpr uint32_t partElems = exactDiv(headElems, nbParts); #if USE_PAGED_KV_CACHE static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || tokensPerPage % gemm0CtaTileNbTokens == 0); static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile; #endif uint32_t const nbKHeads; KVCacheList const& cacheList; uint32_t const idxReq; uint32_t const idxHeadGrp; CUtensorMap const& tensorMap; #if USE_PAGED_KV_CACHE uint32_t const nbPages; // for bound check Vec& pages; uint32_t idxTileRef; // idxTile used to load the pages #endif uint32_t const baseOffset; __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, KVCacheList const& cacheList, uint32_t idxReq, uint32_t idxHeadGrp, CUtensorMap const& tensorMap #if USE_PAGED_KV_CACHE , uint32_t nbPages, Vec& pageBuf #endif ); // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ void loadData( Array2D& dst, uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar); __device__ void loadPages(uint32_t idxTile); __device__ GMemKVCacheHead& getHead(uint32_t pos); }; using GmmaAccCoreMat = Array2D; template using GmmaAcc = Array2D; inline constexpr uint32_t gemm0M = (swapAB ? gemm0CtaTileNbTokens : ctaNbQHeads); inline constexpr uint32_t gemm0N = (swapAB ? ctaNbQHeads : gemm0CtaTileNbTokens); using Gemm0Acc = GmmaAcc; #if SWAP_AB using RegColWiseVec = Vec, Gemm0Acc::cols>; using UniformNeedRescaleMask = Vec; using RegSeqWiseVec = RegColWiseVec; #else using RegRowWiseVec = Vec, Gemm0Acc::rows>; using UniformNeedRescaleMask = Vec; using RegSeqWiseVec = RegRowWiseVec; #endif #if SPEC_DEC __device__ inline uint32_t getInputSeqLen(SpecDecParams const& params, uint32_t idxReq) { return (params.qCuSeqLens == nullptr) ? params.qSeqLen : params.qCuSeqLens[idxReq + 1] - params.qCuSeqLens[idxReq]; } __device__ inline uint32_t getInputTokOffset(SpecDecParams const& params, uint32_t idxReq) { return (params.qCuSeqLens == nullptr) ? 
params.qSeqLen * idxReq : params.qCuSeqLens[idxReq]; } static_assert(!allowSlidingWindow, "SpecDec is not supported for sliding window"); struct SpecDec { static inline constexpr uint32_t tileSize = gemm0CtaTileNbTokens; static inline constexpr uint32_t ctaMaxQSeqLen = exactDiv(ctaNbQHeads, headGrpSize); using TileMaskRow = Vec; __device__ inline SpecDec(SpecDecParams const& params, uint32_t idxReq, uint32_t ctaIdxY, uint32_t seqLen) : params(params) , ctaIdxY(ctaIdxY) , seqLen(seqLen) { inputSeqLen = getInputSeqLen(params, idxReq); baseOffset = divUp(params.qSeqLen, 32U) * (getInputTokOffset(params, idxReq) + ctaMaxQSeqLen * ctaIdxY); } __device__ inline uint32_t unmaskedSeqLen() const { return seqLen - inputSeqLen; } __device__ inline bool needMask(uint32_t idxTile, uint32_t idxQTokInCta) const { return tileSize * (idxTile + 1) > unmaskedSeqLen() && ctaMaxQSeqLen * ctaIdxY + idxQTokInCta < inputSeqLen && params.mask != nullptr; } __device__ inline int32_t maskColBeg(uint32_t idxTile) const { int32_t const convergedSeqLen = int32_t(unmaskedSeqLen()); return static_cast(exactDiv(tileSize, 32) * idxTile) - static_cast(divUp(convergedSeqLen, 32)); } __device__ inline TileMaskRow loadTileMaskRow(uint32_t idxTile, uint32_t idxQTokInCta) const { assert(needMask(idxTile, idxQTokInCta)); constexpr uint32_t nbOrigElems = TileMaskRow::size + 1; Vec orig; int32_t const cols = divUp(params.qSeqLen, 32); uint32_t const rowOffset = baseOffset + idxQTokInCta * cols; int32_t const colBeg = maskColBeg(idxTile); #pragma unroll for (int32_t i = 0; i < int32_t(nbOrigElems); i++) { int32_t const idx = colBeg + i; orig[i] = inRange(idx, 0, cols) ? params.mask[rowOffset + idx] : (idx < 0 ? ~0U : 0U); } TileMaskRow mask; uint32_t const shift = (32 - unmaskedSeqLen() % 32) % 32; #pragma unroll for (uint32_t i = 0; i < TileMaskRow::size; i++) { asm("shf.r.clamp.b32 %0, %1, %2, %3;\n" : "=r"(mask[i]) : "r"(orig[i]), "r"(orig[i + 1]), "r"(shift)); } return mask; } SpecDecParams const& params; uint32_t const ctaIdxY; uint32_t const seqLen; uint32_t inputSeqLen; uint32_t baseOffset; }; __device__ void warpGrpApplyMask( Gemm0Acc& acc, SpecDec const& specDec, uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank); #endif #if SWAP_AB __device__ RegColWiseVec computeWarpGrpColMax_sync( CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src); __device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd); __device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax); __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); __device__ void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); #else __device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src); __device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); __device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& colMax); __device__ RegRowWiseVec computeWarpRowSum(Gemm0Acc& src); __device__ void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec); __device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, RegRowWiseVec const& regVec); 
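// Note on the two declaration sets above: with SWAP_AB, gemm0 is issued as K * Q^T (see gemm0M/gemm0N
// above), so the small ctaNbQHeads dimension sits on the GMMA N axis and the softmax statistics are
// tracked column-wise; without SWAP_AB it is Q * K^T and the same statistics are tracked row-wise.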
#endif using RegMatAFrag = Array2D, 1, 2>; constexpr uint32_t gemm1NbGmmaInstK = exactDiv(gemm1CtaTileNbTokens, gmma::instK); #if SWAP_AB constexpr uint32_t gemm1NbGmmaInstM = exactDiv(headElems, gmma::instM); __device__ Vec loadVTileTransposed( uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK); using Gemm1Acc = GmmaAcc; __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, CtaBarrier& gemm1WarpGrpBar); template __device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum); #else __device__ void transposeVTile( uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src); using Gemm1Acc = GmmaAcc; __device__ void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, ShmQWiseVec const(&shmXRowSum), ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, ShmQWiseVec& shmAccRowSum); template __device__ void finalizeAndWriteOut_sync(uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, ShmQWiseVec const& accColSum, uint32_t nbKHeads /* only for final result in spec dec. set to 1 for workspace*/, uint32_t ctaNbValidTokens); #endif inline constexpr uint32_t ropeNbPairsPerThrdImpl(uint32_t nbThrds) { auto const val = divUp(exactDiv(validElemsPerHead, 2), nbThrds); assert(val <= 32); return val <= 2 ? val : (val <= 4 ? 4 : (val <= 8 ? 8 : (val <= 16 ? 16 : 32))); } template inline constexpr uint32_t ropeNbPairsPerThrd = ropeNbPairsPerThrdImpl(nbThrds); template __device__ Vec, ropeNbPairsPerThrd> loadHead( Vec const& head, uint32_t tid); template __device__ mha::conditional_t, 2>, Vec, nbPairsPerThrd>> applyRoPE(Vec, nbPairsPerThrd> const& data, Vec, nbPairsPerThrd> const& ropeCosSin); template __device__ void storeRotatedPairsForKV(GMemCacheHead& dst, mha::conditional_t>, 2>, Vec, ropeNbPairsPerThrd>> const& src, uint32_t tid); template __device__ void storeRotatedPairsForQ(SharedMem::QBuffer& dst, mha::conditional_t>, 2>, Vec, ropeNbPairsPerThrd>> const& src, uint32_t row, uint32_t tid); class ScratchMem { public: struct alignas(8) SumMax { float sum; float max; }; using ColWiseVec = Vec; HOST_DEVICE_FUNC ScratchMem(void* scratch, uint32_t maxTotalNbSubSeq, uint32_t nbInputSeqSplit) : mScratch{static_cast(scratch)} { uint32_t const nbChunks = maxTotalNbSubSeq * nbInputSeqSplit; Segmenter segmenter; constexpr uint32_t alignment = sizeof(Vec); mRowSumMax = segmenter.template newSeg(nbChunks, alignment); mTokens = segmenter.template newSeg>(nbChunks, alignment); } HOST_DEVICE_FUNC TinyPtr rowSumMax() const { return makePtr(mRowSumMax); } HOST_DEVICE_FUNC TinyPtr> tokens() const { return makePtr>(mTokens); } private: template HOST_DEVICE_FUNC TinyPtr makePtr(uint32_t offset) const { return TinyPtr{mScratch, offset}.template cast(); } private: mha::byte* mScratch; // offsets uint32_t mRowSumMax; uint32_t mTokens; }; struct MultiBlockSMem { using ColWiseVec = ScratchMem::ColWiseVec; static constexpr uint32_t nbBuf = useSpecDec ? 
2 : 4; static constexpr uint32_t nbIOWarps = nbBuf; using Elem = InputElem; using Head = Vec; Vec, nbBuf> tokens; Vec rowSumMax; Vec barriers; }; #ifndef NDEBUG namespace dbg { template __device__ void printAcc( CtaBarrier& warpGrpBar, uint32_t warpRank, Array2D const& acc) { for (int m = 0; m < nbGmmaInstM; m++) { for (int w = 0; w < 4; w++) { if (warpRank == w) { for (int a = 0; a < 2; a++) { for (int b = 0; b < 8; b++) { for (int n = 0; n < nbGmmaInstNBase; n++) { for (uint32_t i = 0; i < 4; i++) { if (laneId() == b * 4 + i) { printf("%f, %f, ", acc(m, n)(a, 0), acc(m, n)(a, 1)); } __syncwarp(); } } if (laneId() == 0) { printf("\n"); } __syncwarp(); } if (laneId() == 0) { printf("\n"); } __syncwarp(); } } warpGrpBar.arrive_and_wait(); } } } __device__ void printShmColWiseVec(ShmQWiseVec const& vec) { for (uint32_t i = 0; i < vec.size; i++) { printf("%f, ", vec[i]); } printf("\n"); } template __device__ void printArray2D(Array2D const& src) { for (uint32_t i = 0; i < rows; i++) { for (uint32_t j = 0; j < cols; j++) { T const val = src.template at(i, j); for (uint32_t k = 0; k < exactDiv(sizeof(T), sizeof(Elem)); k++) { printf("%f, ", float(reinterpret_cast(&val)[k])); } } printf("\n"); } } } // namespace dbg #endif CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = XQAKernelType::kHOPPER_WARP_SPECIALIZED; CUBIN_EXPORT __global__ #ifdef NDEBUG #if !OPTIMIZE_FOR_LATENCY __launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 3 : 2) #endif #else __launch_bounds__(128 * 3, 1) #endif void kernel_mha(uint32_t const nbKHeads, #if SLIDING_WINDOW uint32_t const slidingWinSize, #endif float const qScale, OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT float const* const rcpOutScale, #endif #if USE_INPUT_KV IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads], #if ROPE_STYLE != 0 Vec const* __restrict__ const ropeCosSin, // [maxNbPosEmb] #endif #else IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], #endif KVCacheList const cacheList, #if USE_BEAM_SEARCH BeamSearchParams const beamSearchParams, #endif uint32_t const batchSize, float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used // only for int8/fp8 KV cache. 
__grid_constant__ CUtensorMap const tensorMap, #if SPEC_DEC SpecDecParams const specDecParams, #endif uint32_t* __restrict__ const semaphores = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] void* __restrict__ const scratch = nullptr) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) \ && (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 uint32_t const idxReq = blockIdx.z / nbKHeads; #if SPEC_DEC uint32_t const reqInputTokBeg = getInputTokOffset(specDecParams, idxReq); uint32_t const reqInputTokEnd = getInputTokOffset(specDecParams, idxReq + 1); uint32_t const nbInputSeqSplit = gridDim.x; assert(nbInputSeqSplit == divUp(specDecParams.qSeqLen, inputTokensPerCta)); #else uint32_t const reqInputTokBeg = idxReq; uint32_t const reqInputTokEnd = idxReq + 1; constexpr uint32_t nbInputSeqSplit = 1; assert(gridDim.x == nbInputSeqSplit); #endif uint32_t const idxHeadGrp = blockIdx.z % nbKHeads; // inside one request assert(gridDim.z == nbKHeads * batchSize); uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq); static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); constexpr uint32_t tileSize = gemm0CtaTileNbTokens; static_assert(!(allowSlidingWindow && useSpecDec), "Sliding window is not yet supported in spec-dec mode"); #if SLIDING_WINDOW bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize); uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0; #else constexpr bool rtIsReallySliding = false; constexpr uint32_t nbTotalSkipTokens = 0; #endif uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / tileSize; uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % tileSize; #if USE_BEAM_SEARCH uint32_t const ctxCacheSeqLen = getCtxCacheSeqLen(beamSearchParams, idxReq); uint32_t const nbCtxKTiles = useKVCache ? ctxCacheSeqLen / gemm0CtaTileNbTokens : 0; uint32_t const nbDivergentKTiles = useKVCache ? divUp(cacheSeqLen - gemm0CtaTileNbTokens * nbCtxKTiles, beamSearchGemm0CtaTileNbTokens) : 0; uint32_t const nbKTiles = nbCtxKTiles + nbDivergentKTiles; uint32_t const nbVTiles = nbKTiles; #else uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tileSize) : 0; // uint32_t const nbKTiles = nbTiles; // uint32_t const nbVTiles = nbTiles; uint32_t const nbTilesInUse = nbTiles - nbSkipLeadingTiles; #endif uint32_t const maxNbSubSeq = gridDim.y; uint32_t const idxSubSeq = blockIdx.y; bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTilesInUse >= multiBlockMinNbTiles); uint32_t const idxKTileInit = nbSkipLeadingTiles + idxSubSeq; uint32_t const idxVTileInit = idxKTileInit; uint32_t const nbSubSeq = isMultiBlockMode ? 
mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); assert(isMultiBlockMode == (nbSubSeq > 1)); if (idxSubSeq >= nbSubSeq) { return; } #if SPEC_DEC uint32_t const idxInputSubSeq = blockIdx.x; uint32_t const inputSeqLen = reqInputTokEnd - reqInputTokBeg; uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; uint32_t const ctaNbValidTokens = mha::min(uint32_t{inputTokensPerCta}, inputSeqLen - ctaTokOffset); if (ctaTokOffset >= inputSeqLen) { return; } #else uint32_t const idxInputSubSeq = 0; uint32_t const inputSeqLen = 1; uint32_t const ctaTokOffset = 0; uint32_t const ctaNbValidTokens = 1; #endif uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset; auto const warpIdx = getWarpIdx(uint3{128, 1, 3}); auto const wid = warpIdx.z * 4 + warpIdx.x; if (wid == 0 && warpElectSync()) { tma::prefetchTensorMap(tensorMap); } extern __shared__ char smemByteBuf[]; assert(dynamicSmemSize() >= sizeof(SharedMem)); SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); constexpr uint32_t nbBuffers = 2; static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf); if (wid < nbBuffers) { if (warpElectSync()) { smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); #if !SWAP_AB smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); #endif smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); } } else if (wid == nbBuffers) { if (warpElectSync()) { smem.qBar.initialize(gemm0NbThrds + nbQLdThrds, gemm0NbThrds + nbQLdThrds); init(&smem.gemm0WarpGrpBar, gemm0NbThrds); init(&smem.gemm1WarpGrpBar, gemm1NbThrds); } } __syncthreads(); #if USE_PAGED_KV_CACHE uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); #endif constexpr bool isKVCacheQuantized = (cacheElemSize < 2); assert(idxKTileInit < nbTiles); uint32_t const nbIters = divUp(nbTiles - idxKTileInit, nbSubSeq); assert(nbIters >= 1); constexpr uint32_t gmmaInstK = gmma::instK; constexpr uint32_t grainsPerInstK = exactDiv(sizeof(MathElem) * gmmaInstK, grainBytes); if (warpIdx.z == 0) { #if SPEC_DEC SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen}; #endif // QK gemm constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM); using Acc = GmmaAcc; unused(smem.qBar.consumed.arrive()); for (auto& b : smem.kBar) { unused(b.consumed.arrive()); } float const qkScale = qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. uint32_t const warpRank = warpIdx.x; // init once per sequence. It also works as global colMax across iterations. if (threadIdx.x < ctaNbQHeads) { smem.gemm0CurrentSeqMax[threadIdx.x] = safeInitRowMax; } smem.gemm0WarpGrpBar.arrive_and_wait(); smem.qBar.produced.arrive_and_wait(); #if DBG_PRINT if (threadIdx.x == 0) { printf("q:\n"); dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[0]); } #endif auto const matDescQBase = gmma::makeMatDesc( nullptr, 0, SharedMem::QBuffer::Elem::rowBytes * 8, gmma::getSwizzleMode(SharedMem::QBuffer::Elem{})) .raw(); for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; assert(idxKTile < nbTiles); Acc acc; // no need to initialize. GMMA allows us to ignore acc initial values. 
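// Sketch of the gemm0 inner loop below: each K tile is consumed in cacheHeadNbParts column parts that
// stream through the two K shared-memory buffers; for every part the warp group issues one GMMA per
// (m, k) sub-tile, and accHasVal is false only for the very first (idxPart, k) pair so that the first
// MMA overwrites the uninitialized accumulator instead of accumulating into it.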
gmma::fence(); static_assert(cacheHeadNbParts == nbQParts); #pragma unroll for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; auto& kBuf = smem.k[idxKBuf]; auto& kBar = smem.kBar[idxKBuf]; static_assert(SharedMem::KBuffer::rows % 8 == 0); auto const matDescKBase = gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, &smem.k[0], gmma::getSwizzleMode(SharedMem::KBuffer{})) .raw(); assert(matDescKBase == gmma::makeMatDesc( nullptr, 0, SharedMem::KBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::KBuffer{})) .raw()); arrive_tx_and_wait(kBar.produced, exactDiv(sizeof(SharedMem::KBuffer), gemm0NbThrds)); // if (threadIdx.x == 0) { // printf("************* part %u *******\n", idxPart); // printf("q:\n"); // dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[idxPart]); // printf("k:\n"); // dbg::printArray2D<__nv_fp8_e4m3, true>(kBuf); // } constexpr uint32_t nbGmmaInstK = exactDiv(cacheHeadPartElems, gmmaInstK); #pragma unroll for (uint32_t k = 0; k < nbGmmaInstK; k++) { bool const accHasVal = (idxPart != 0 || k != 0); auto const matDescQ = addAddr(matDescQBase, &smem.q[idxPart](0, grainsPerInstK * k)); #pragma unroll for (uint32_t m = 0; m < nbGmmaInstM; m++) { auto const matDescK = addAddr(matDescKBase, &kBuf(64 * m, grainsPerInstK * k)); #if SWAP_AB gmma::mma_async_shmA( reinterpret_cast(acc(m, 0)), matDescK, matDescQ, accHasVal); #else gmma::mma_async_shmA( reinterpret_cast(acc(m, 0)), matDescQ, matDescK, accHasVal); #endif } } gmma::commit_group(); //@fixme: use two sets of acc and let gmma_async overlap with softmax. But this will let tile0_softmax // wait for // k loading of tile1 and may harm perf for short-seq cases. gmma::wait_group<0>(); unused(kBar.consumed.arrive()); } #if !defined(NDEBUG) && DBG_PRINT dbg::printAcc(smem.gemm0WarpGrpBar, warpRank, acc); #endif // apply qkScale acc = acc * qkScale; // apply mask #if SPEC_DEC warpGrpApplyMask(acc, specDec, cacheSeqLen, idxKTile, warpRank); #else bool const isFirstTile = (idxKTile == nbSkipLeadingTiles); bool const needMaskLeading = (rtIsReallySliding && isFirstTile && tile0NbSkipTokens > 0); bool const isLastTile = (idxKTile + 1 == nbTiles); bool const needMaskTrailing = isLastTile && cacheSeqLen % tileSize != 0; if (needMaskLeading || needMaskTrailing) { uint32_t const validTokenBeg = needMaskLeading ? tile0NbSkipTokens : 0; uint32_t const validTokenEnd = (needMaskTrailing ? cacheSeqLen % tileSize : tileSize); if (validTokenBeg > 0 || validTokenEnd < tileSize) { #if SWAP_AB warpGrpApplyMask(warpRank, acc, validTokenBeg, validTokenEnd); #else warpGrpApplyMask(acc, validTokenBeg, validTokenEnd); #endif } } #endif // update colMax in shared mem and get a register copy #if SWAP_AB RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); warpGrpOnlineSoftmax(acc, colMax); #else RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); warpGrpOnlineSoftmax(acc, rowMax); #endif // @fixme: may need fp32->fp8->fp32 before doing sum. #if SWAP_AB RegColWiseVec const warpColSum = computeWarpColSum(acc); #else RegRowWiseVec const rowSum = computeWarpRowSum(acc); #endif // map 1 to fp8_max before conversion to fp8 acc = acc * kE4M3_MAX; uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; auto& xBar = smem.xBar[idxXBuf]; // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. 
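// Numerics of the store below: acc now holds exp(score - runningMax) scaled by kE4M3_MAX (so 1.0 maps
// to the e4m3 maximum before fp8 conversion), and the per-tile max plus per-warp partial sums are
// published to shared memory next to X so that the gemm1 warp group can rescale its accumulator in
// flash-attention style.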
#if SWAP_AB storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); // store colMax and warpColSum auto const lane = laneId(); if (lane < 4) { auto& xColMax = smem.xColMax[idxXBuf]; auto& xColSum = smem.xColSum[idxXBuf][warpRank]; #pragma unroll for (uint32_t n = 0; n < colMax.size; n++) { #pragma unroll for (uint32_t j = 0; j < 2; j++) { if (warpRank == 0) { xColMax[8 * n + 2 * lane + j] = colMax[n][j]; } xColSum[8 * n + 2 * lane + j] = warpColSum[n][j]; } } } #else storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); #endif __syncwarp(); // the release semantics of arrive does not work for async consumers like gmma/utcmma. additional fence is // needed. asm volatile("fence.proxy.async.shared::cta;\n"); unused(xBar.produced.arrive()); } unused(smem.qBar.consumed.arrive()); } else if (warpIdx.z == 1) { // XV GEMM for (auto& b : smem.vBar) { unused(b.consumed.arrive()); } #if !SWAP_AB for (auto& b : smem.vtBar) { unused(b.consumed.arrive()); } #endif for (auto& b : smem.xBar) { unused(b.consumed.arrive()); } if (threadIdx.x < smem.gemm1AccColMax.size) { auto const idx = threadIdx.x; smem.gemm1AccColMax[idx] = mha::numeric_limits::lowest(); smem.gemm1AccColSum[idx] = 0; } smem.gemm1WarpGrpBar.arrive_and_wait(); uint32_t const warpRank = warpIdx.x; constexpr float xScale = 1.f / kE4M3_MAX; #if LOW_PREC_OUTPUT float const oScale = rcpOutScale[0]; #else constexpr float oScale = 1.F; #endif float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. gmma::fence(); static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens, "not implemented"); for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; auto const idxVBuf = idxIter % SharedMem::nbVBuf; auto const idxXBuf = idxVBuf; auto& vBar = smem.vBar[idxVBuf]; arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); auto const& vBuf = smem.vBuf(idxVBuf); #if !SWAP_AB CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; auto& vtBuf = smem.vtBuf(idxVBuf); vtBar.consumed.arrive_and_wait(); transposeVTile(warpRank, laneId(), vtBuf, vBuf); vBar.consumed.arrive(); vtBar.produced.arrive(); #endif auto& xBar = smem.xBar[idxXBuf]; xBar.produced.arrive_and_wait(); #if !defined(NDEBUG) && DBG_PRINT #if SWAP_AB if (threadIdx.x == 0) { printf("colMax:\n"); for (int i = 0; i < ctaNbQHeads; i++) { printf("%f, ", smem.xColMax[idxXBuf][i]); } printf("\n"); printf("colSum:\n"); for (int n = 0; n < 4; n++) { for (int i = 0; i < ctaNbQHeads; i++) { printf("%f, ", smem.xColSum[idxXBuf][n][i]); } printf("\n"); } printf("\n"); printf("X:\n"); for (int i = 0; i < ctaNbQHeads; i++) { for (int j = 0; j < gemm0CtaTileNbTokens; j++) { auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); auto const e = reinterpret_cast&>( smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; printf("%.2f, ", float(e)); if (j % 16 == 15) { printf("| "); } } printf("\n\n"); } } smem.gemm1WarpGrpBar.arrive_and_wait(); #endif #endif #if SWAP_AB // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead. 
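// Rescale step (sketch of what rescaleGemm1AccForNew*Max_sync is expected to do): when the running max
// grows from accMax to xMax, the existing accumulator and running sum are scaled by exp(accMax - xMax)
// before the new tile's X*V contribution and xColSum/xRowSum are folded in; xvoScale
// (xScale * kvCacheScale * oScale) is applied only once, in finalizeAndWriteOut_sync.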
rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar); #else rescaleGemm1AccForNewRowMax_sync( warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum); #endif auto& xBuf = smem.xBuf(idxXBuf); auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) .raw(); #if CACHE_ELEM_ENUM == 0 auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) .raw(); #endif #if SWAP_AB //@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed. #pragma unroll for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) { #if CACHE_ELEM_ENUM == 2 Vec const fragA = loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); #if !defined(NDEBUG) && DBG_PRINT if (threadIdx.x == 0) { printf("fragA:\nidxInstK == %u\n", idxInstK); } smem.gemm1WarpGrpBar.arrive_and_wait(); for (int m = 0; m < 2; m++) { for (int w = 0; w < 4; w++) { if (warpRank == w) { if (laneId() == 0) { printf(" warpRank = %u\n", warpRank); } __syncwarp(); for (int a = 0; a < 2; a++) { for (int b = 0; b < 8; b++) { for (int c = 0; c < 2; c++) { for (int d = 0; d < 4; d++) { if (laneId() == b * 4 + d) { for (int e = 0; e < 4; e++) { auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>( fragA[m](0, c)(a, 0)); printf("%.2f, ", float(elem4[e])); } } __syncwarp(); } } if (laneId() == 0) { printf("\n"); } __syncwarp(); } if (laneId() == 0 && a == 0) { printf("----------------------\n"); } __syncwarp(); } } smem.gemm1WarpGrpBar.arrive_and_wait(); } } #endif #endif BoundedVal const kOffsetInGrains{grainsPerInstK * idxInstK}; auto const descX = addAddr(descXBase, &xBuf[kOffsetInGrains.template divBy().get()]( 0, kOffsetInGrains.template mod().get())); #if CACHE_ELEM_ENUM == 2 gmma::fence(); #endif #pragma unroll for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) { #if CACHE_ELEM_ENUM == 0 auto const descV = addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); gmma::mma_async_shmA( reinterpret_cast(acc(idxInstM, 0)), descV, descX, true); #elif CACHE_ELEM_ENUM == 2 gmma::mma_async_regA( reinterpret_cast(acc(idxInstM, 0)), reinterpret_cast(fragA[idxInstM]), descX, true); #endif } gmma::commit_group(); //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of // gmma. gmma::wait_group<0>(); } #else auto const descVTBase = gmma::makeMatDesc( nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::VTBuffer{})) .raw(); vtBar.produced.arrive_and_wait(); // if (idxIter == 1 && threadIdx.x == 0) { // printf("vtBuf:\n"); // dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf); // } #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) { BoundedVal const kOffsetInGrains{grainsPerInstK * k}; auto const descX = addAddr(descXBase, &xBuf[kOffsetInGrains.template divBy().get()]( gmma::instM * m, kOffsetInGrains.template mod().get())); auto const descVT = addAddr( descVTBase, &vtBuf(0, kOffsetInGrains.template mod().get())); gmma::mma_async_shmA( reinterpret_cast(acc(m, 0)), descX, descVT, true); } } gmma::commit_group(); //@fixme: delay wait and consumption to next tile. 
            // Note that fragA must also persist until finish of gmma.
            gmma::wait_group<0>();
#endif
        if (idxIter == nbIters - 1)
        {
            // gmma::wait_group should have already synchronized threads, so this may be unnecessary.
            smem.gemm1WarpGrpBar.arrive_and_wait();
            assert(idxXBuf == idxVBuf);
            if (isMultiBlockMode)
            {
                ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit};
                uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp;
                uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxSubSeq;
                uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq;
                // save row max/sum
                static_assert(ctaNbValidQHeads <= gmmaWarpsPerGrp * warp_size);
                if (threadIdx.x < ctaNbValidQHeads)
                {
                    float const colMax = smem.gemm1AccColMax[threadIdx.x];
                    float const colSum = smem.gemm1AccColSum[threadIdx.x];
                    ScratchMem::SumMax sumMax;
                    sumMax.sum = colSum;
                    sumMax.max = colMax;
                    (scratchMem.rowSumMax() + idxChunk).template cast()[threadIdx.x] = sumMax;
                }
                // compute scratch ptr for output writing
                IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast();
#if SWAP_AB
                finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
                    smem.gemm1WarpGrpBar, smem.gemm1AccColSum);
#else
                finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
                    smem.gemm1AccColSum, 1, ctaNbValidTokens);
#endif
            }
            else
            {
                uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp);
                OutputHead* const dst = &output[outOffset];
#if SWAP_AB
                finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
                    smem.gemm1WarpGrpBar, smem.gemm1AccColSum);
#else
                finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale,
                    smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens);
#endif
            }
        }
        unused(xBar.consumed.arrive());
#if SWAP_AB
        unused(vBar.consumed.arrive());
#else
        unused(vtBar.consumed.arrive());
#endif
    }
}
else
{ // IO warps
    static_assert(beamWidth == 1);
#if ENABLE_FDL
    preExit();
#endif
#if ENABLE_FDL == 1
    acqBulk();
#endif
    assert(warpIdx.z == 2);
    uint32_t const newTokenPos = cacheSeqLen - 1;
    if (warpIdx.x < nbQLdWarps)
    {
        // load Q. Use register to load fp16 data and store fp8 to shared mem.
        // @fixme: If register pressure is high and shared mem pressure is low, switch to TMA instead.
        using QCvt = F16QToF8Converter;
        static_assert(beamWidth == 1);
#if USE_INPUT_KV
        TinyPtr const qData{qkv, headGrpSize * idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq};
        constexpr bool isNeox = (ROPE_STYLE == 1);
        constexpr uint32_t thrdsPerHead = mha::min(warp_size, divUp(headElems, 4U));
        uint32_t const lane = laneId();
        uint32_t const idxThrd = warpIdx.x * warp_size + lane;
        uint32_t const idxThrdGrp = (thrdsPerHead % 32 == 0 ?
makeWarpUniform(this_warp(), idxThrd / thrdsPerHead) : idxThrd / thrdsPerHead); constexpr uint32_t nbThrdGrps = exactDiv(warp_size * nbQLdWarps, thrdsPerHead); uint32_t const tid = idxThrd % thrdsPerHead; smem.qBar.consumed.arrive_and_wait(); #if ROPE_STYLE != 0 auto const& ropeCosSinHead = reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); auto const cosSinPairs = loadHead(ropeCosSinHead, tid); #endif #if ENABLE_FDL == 2 acqBulk(); #endif #pragma unroll for (uint32_t iter = 0; iter < divUp(headGrpSize, nbThrdGrps); iter++) { uint32_t const idxHead = nbThrdGrps * iter + idxThrdGrp; if (idxHead >= headGrpSize) { break; } #if ROPE_STYLE == 0 auto const rotatedPairs = loadHead(qData[idxHead], tid); #else auto const pairs = loadHead(qData[idxHead], tid); auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); #endif storeRotatedPairsForQ(smem.q, rotatedPairs, idxHead, tid); } #else TinyPtr const qData{q, headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp)}; #if ENABLE_FDL == 2 acqBulk(); #endif auto const f16QData = QCvt::load(threadIdx.x, qData, nbKHeads, ctaNbValidTokens); smem.qBar.consumed.arrive_and_wait(); QCvt::store(threadIdx.x, smem.q, f16QData); #endif // the release semantics of arrive does not work for async consumers like gmma/utcmma. additional fence is // needed. asm volatile("fence.proxy.async.shared::cta;\n"); unused(smem.qBar.produced.arrive()); } else if (warpIdx.x == nbQLdWarps) { // load k KVTilePartLoader kTilePartLoader { true, nbKHeads, cacheList, idxReq, idxHeadGrp, tensorMap #if USE_PAGED_KV_CACHE , nbPages, smem.pages[0] #endif }; for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; kTilePartLoader.loadPages(idxKTile); #if USE_INPUT_KV || ENABLE_FDL == 2 #if SPEC_DEC static_assert(SLIDING_WINDOW == 0); bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) > cacheSeqLen - inputSeqLen); #else bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) >= cacheSeqLen); #endif if (anyNewTokens) { #if ENABLE_FDL == 2 acqBulk(); #endif #if USE_INPUT_KV static_assert(beamWidth == 1); uint32_t const inputKHeadOffset = headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; IOHead const& inKHead = qkv[inputKHeadOffset]; uint32_t const lane = laneId(); float const rcpKScale = 1.F / kvCacheScale[0]; #if ROPE_STYLE == 0 constexpr bool isNeox = false; auto const pairs = loadHead(inKHead, lane) * rcpKScale; Vec, decltype(pairs)::size> convertedPairs; constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; reinterpret_cast&>(convertedPairs) = convert(reinterpret_cast const&>(pairs)); storeRotatedPairsForKV( kTilePartLoader.getHead(newTokenPos), convertedPairs, lane); #else constexpr bool isNeox = (ROPE_STYLE == 1); auto const pairs = loadHead(inKHead, lane) * rcpKScale; auto const& ropeCosSinHead = reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); auto const cosSinPairs = loadHead(ropeCosSinHead, lane); auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), rotatedPairs, lane); #endif static_assert(inputSeqLen == 1); __syncwarp(); #endif } #endif for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; auto& kBar = smem.kBar[idxKBuf]; kBar.consumed.arrive_and_wait(); if (warpElectSync()) { kTilePartLoader.loadData(smem.k[idxKBuf], idxKTile, idxPart, kBar.produced); } 
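// K parts are fetched by a single elected lane per warp via TMA; kBar.produced is a transaction
// barrier, so completion is signaled once the byte count expected by the gemm0 warp group's
// arrive_tx_and_wait has arrived, and the __syncwarp() below only re-converges the loader warp.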
__syncwarp(); } } } else if (warpIdx.x == nbQLdWarps + 1) { // load v KVTilePartLoader vTileLoader { false, nbKHeads, cacheList, idxReq, idxHeadGrp, tensorMap #if USE_PAGED_KV_CACHE , nbPages, smem.pages[1] #endif }; for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; vTileLoader.loadPages(idxVTile); #if USE_INPUT_KV || ENABLE_FDL == 2 #if SPEC_DEC static_assert(SLIDING_WINDOW == 0); bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) > cacheSeqLen - inputSeqLen); #else bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) >= cacheSeqLen); #endif if (anyNewTokens) { #if ENABLE_FDL == 2 acqBulk(); #endif #if USE_INPUT_KV static_assert(beamWidth == 1); uint32_t const inputVHeadOffset = (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; IOHead const& inVHead = qkv[inputVHeadOffset]; uint32_t const lane = laneId(); float const rcpVScale = 1.F / kvCacheScale[0]; constexpr bool isNeox = false; auto const pairs = loadHead(inVHead, lane) * rcpVScale; Vec, decltype(pairs)::size> convertedPairs; constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; reinterpret_cast&>(convertedPairs) = convert(reinterpret_cast const&>(pairs)); static_assert(SPEC_DEC == 0); storeRotatedPairsForKV(vTileLoader.getHead(newTokenPos), convertedPairs, lane); __syncwarp(); #endif } #endif uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; auto& vBar = smem.vBar[idxVBuf]; vBar.consumed.arrive_and_wait(); if (warpElectSync()) { #pragma unroll for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); } } __syncwarp(); } } } __syncthreads(); uint32_t const nbBarriers = &smem.gemm1WarpGrpBar - &smem.qBar.produced + 1; uint32_t const tid = threadIdx.x + blockDim.x * threadIdx.y + blockDim.x * blockDim.y * threadIdx.z; assert(nbBarriers <= blockDim.x * blockDim.y * blockDim.z); if (tid < nbBarriers) { (&smem.qBar.produced)[tid].~CtaBarrier(); } if (!isMultiBlockMode) { return; } bool& smemIsLastCta = smem.isLastCta; if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0) { uint32_t const lastOld = nbSubSeq - 1; ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; uint32_t old; uint32_t const idxSemaphore = idxSeq * nbInputSeqSplit + idxInputSubSeq; auto const pSemaphore = &semaphores[idxSemaphore]; asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" : "=r"(old) : "l"(pSemaphore), "r"(lastOld)); smemIsLastCta = (old == lastOld); } { assert(dynamicSmemSize() >= sizeof(MultiBlockSMem)); #ifndef __CUDACC_RTC__ assert(sizeof(MultiBlockSMem) < offsetof(SharedMem, isLastCta)); #endif auto& smem = *reinterpret_cast(&smemByteBuf[0]); assert(blockDim.x >= MultiBlockSMem::nbBuf); constexpr uint32_t nbMathWarps = gemm0NbWarps + gemm1NbWarps; static_assert(nbWarps >= MultiBlockSMem::nbBuf); if (wid < MultiBlockSMem::nbBuf) { if (warpElectSync()) { smem.barriers[wid].initialize(isHeadPadded ? 
warp_size : 1U, nbMathWarps * warp_size); smem.barriers[wid].consumed.arrive(nbMathWarps * warp_size); } } __syncthreads(); if (!smemIsLastCta) { return; } if (wid < nbMathWarps) { constexpr uint32_t headsPerWarp = divUp(ctaNbValidQHeads, nbMathWarps); using Acc = Vec; struct HeadState { Acc acc; float sum; float max; }; Vec states{}; for (auto& s : states.data) { s.max = safeInitRowMax; } uint32_t const lane = laneId(); for (uint32_t idxBlock = 0; idxBlock < nbSubSeq; idxBlock++) { uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; auto& bar = smem.barriers[idxBuf]; bar.produced.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); for (uint32_t i = 0; i < headsPerWarp; i++) { uint32_t const idxHead = wid + nbMathWarps * i; if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { break; } HeadState& state = states[i]; auto const sumMax = smem.rowSumMax[idxBuf][idxHead]; auto const data = convert( reinterpret_cast&>(smem.tokens[idxBuf][idxHead][Acc::size * lane])); if (sumMax.max > state.max) { float const scale = expf(state.max - sumMax.max); state.max = sumMax.max; state.sum = state.sum * scale + sumMax.sum; state.acc = state.acc * scale + data * sumMax.sum; } else { float const scale = expf(sumMax.max - state.max); state.sum = state.sum + sumMax.sum * scale; state.acc = state.acc + data * (sumMax.sum * scale); } } unused(bar.consumed.arrive()); } __syncthreads(); uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); auto const dst = &output[outOffset]; for (uint32_t i = 0; i < headsPerWarp; i++) { uint32_t const idxHead = wid + nbMathWarps * i; if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { break; } #if SPEC_DEC uint32_t const idxToken = idxHead / headGrpSize; if (idxToken >= ctaNbValidTokens) { break; } uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); uint32_t const idxDstHead = idxHead + idxToken * tokenPad; #else uint32_t const idxDstHead = idxHead; #endif auto const& s = states[i]; auto const outData = convert(s.acc * (1.f / s.sum)); if (Acc::size * lane < validElemsPerHead) { reinterpret_cast&>(dst[idxDstHead][Acc::size * lane]) = outData; } } } else if (wid < nbMathWarps + MultiBlockSMem::nbIOWarps) { static_assert(MultiBlockSMem::nbIOWarps <= MultiBlockSMem::nbBuf); ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; uint32_t const initIdxBlock = wid - nbMathWarps; // each warp loads data for a block for (uint32_t idxBlock = initIdxBlock; idxBlock < nbSubSeq; idxBlock += MultiBlockSMem::nbIOWarps) { uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxBlock; uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; auto& bar = smem.barriers[idxBuf]; bar.consumed.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); auto const lane = laneId(); #pragma unroll for (uint32_t iter = 0; iter < divUp(ctaNbValidQHeads, warp_size); iter++) { uint32_t const i = iter * warp_size + lane; if (ctaNbValidQHeads % warp_size != 0 && i >= ctaNbValidQHeads) { break; } ldgsts::copyAsync( &smem.rowSumMax[idxBuf][i], &scratchMem.rowSumMax()[idxChunk][i]); } ldgsts::barArrive(bar.produced, false); if constexpr (isHeadPadded) { static_assert(grainsPerPaddedInputHead <= warp_size); constexpr uint32_t headsPerIter = exactDiv(warp_size, grainsPerPaddedInputHead); constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, 
headsPerIter); constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; #pragma unroll for (uint32_t i = 0; i < nbIters; i++) { uint32_t const idxHead = headsPerIter * i + BoundedVal{lane}.template divBy().get(); uint32_t const idxGrain = BoundedVal{lane}.template mod().get(); if (i < nbWholeIters || idxHead < ctaNbValidQHeads) { constexpr uint32_t nbElemsPerGrain = exactDiv(grainBytes, sizeof(MultiBlockSMem::Elem)); auto const dst = &smem.tokens[idxBuf][idxHead][nbElemsPerGrain * idxGrain]; auto const src = idxGrain < grainsPerIOHead ? &scratchMem.tokens()[idxChunk][idxHead][nbElemsPerGrain * idxGrain] : nullptr; ldgsts::copyAsync(dst, src, idxGrain < grainsPerIOHead ? grainBytes : 0U); } } ldgsts::barArrive(bar.produced, true); } else { if (warpElectSync()) { tma::load1DAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk], sizeof(smem.tokens[idxBuf]), bar.produced); arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1); } } } __syncthreads(); uint32_t const idxBar = tid - warp_size * nbMathWarps; if (idxBar < MultiBlockSMem::nbBuf * 2) { reinterpret_cast(&smem.barriers[0])[idxBar].~CtaBarrier(); } } } #else #if GENERATE_CUBIN static_assert("This kernel is for Hopper only"); #else asm volatile("trap;\n"); #endif #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && BEAM_WIDTH == 1 } #if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 template __device__ inline typename F16QToF8Converter::RegData F16QToF8Converter::load( uint32_t tid, TinyPtr const& src, uint32_t const nbKHeads /*for beam search only*/, uint32_t nbTokens) { #if !(SPEC_DEC) assert(nbTokens == 1); nbTokens = 1; #endif typename F16QToF8Converter::RegData dst; #pragma unroll for (uint32_t iter = 0; iter < nbIters; iter++) { uint32_t const idxGrain = nbThrds * iter + tid; if (idxGrain >= totalGrains) { break; } #if SPEC_DEC uint32_t const idxToken = idxGrain / grainsPerPaddedInputQHeadGrp; uint32_t const tokenPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); uint32_t offsetInGrains = idxGrain + tokenPad * idxToken; static_assert(beamWidth == 1); #else uint32_t const idxBeam = beamWidth == 1 ? 0 : idxGrain / grainsPerPaddedInputQHeadGrp; uint32_t const beamPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); uint32_t offsetInGrains = idxGrain + beamPad * idxBeam; #endif bool isGrainInBound = true; if constexpr (isHeadPadded) { uint32_t const idxGrainInsideHead = offsetInGrains % grainsPerPaddedInputHead; offsetInGrains = offsetInGrains / grainsPerPaddedInputHead * grainsPerIOHead + idxGrainInsideHead; isGrainInBound = (idxGrainInsideHead < grainsPerIOHead); } #if SPEC_DEC isGrainInBound = isGrainInBound && (idxToken < nbTokens); #endif LdGrain const srcGrain = isGrainInBound ? src.template cast()[offsetInGrains] : LdGrain{}; static_assert(inputElemSize == 2); auto const& fp16Data = reinterpret_cast const&>(srcGrain); dst[iter] = idxGrain % grainsPerPaddedInputHead < grainsPerIOHead ? 
fp16Data : mha::decay_t{}; } return dst; } template __device__ inline void F16QToF8Converter::store( uint32_t tid, SharedMem::QBuffer& dst, F16QToF8Converter::RegData const& data) { #pragma unroll for (uint32_t iter = 0; iter < nbIters; iter++) { uint32_t const idxGrain = nbThrds * iter + tid; if (idxGrain >= totalGrains) { break; } #if CACHE_ELEM_ENUM == 0 static_assert(inputElemSize == cacheElemSize); ShmVec const& shmData = data[iter]; uint32_t const r = idxGrain / grainsPerPaddedInputHead; BoundedVal const c = {idxGrain % grainsPerPaddedInputHead}; dst[c.template divBy().get()].template at(r, c.template mod().get()) = reinterpret_cast(shmData); #else auto const& fp16Data = data[iter]; ShmVec shmData; #pragma unroll for (uint32_t i = 0; i < fp16Data.size; i++) { shmData[i] = CacheElem{fp16Data[i]}; } uint32_t const dstIdxGrain = idxGrain / 2; uint32_t const dstIdxHalfGrain = idxGrain % 2; constexpr uint32_t grainsPerCacheHead = exactDiv(paddedCacheHeadBytes, grainBytes); uint32_t const r = dstIdxGrain / grainsPerCacheHead; BoundedVal const c = {dstIdxGrain % grainsPerCacheHead}; reinterpret_cast&>(dst[c.template divBy().get()].template at( r, c.template mod().get()))[dstIdxHalfGrain] = shmData; #endif } } #endif __device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads, KVCacheList const& cacheList, uint32_t idxReq, uint32_t idxHeadGrp, CUtensorMap const& tensorMap #if USE_PAGED_KV_CACHE , uint32_t nbPages, Vec& pageBuf #endif ) : nbKHeads{nbKHeads} , cacheList{cacheList} , idxReq{idxReq} , idxHeadGrp{idxHeadGrp} , tensorMap{tensorMap} #if USE_PAGED_KV_CACHE , nbPages{nbPages} , pages{pageBuf} , baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} #else , baseOffset{(idxReq * beamWidth) * 2 + (isK ? 0 : 1)} #endif { } // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ inline void KVTilePartLoader::loadData( Array2D& dst, uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar) { static_assert(nbTokens == gemm0CtaTileNbTokens); #if USE_PAGED_KV_CACHE assert(idxTile == idxTileRef); if constexpr (nbTokens < tokensPerPage) { assert(nbPagesPerTile == 1); uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}, bar); } else { #pragma unroll for (uint32_t i = 0; i < nbPagesPerTile; i++) { tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t) pages[i]}, bar); } } #else tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); #endif } __device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) { #if USE_PAGED_KV_CACHE uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage ? nbPagesPerTile * idxTile : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); #pragma unroll for (uint32_t i = 0; i < nbPagesPerTile; i++) { uint32_t const idxPage = idxPageBeg + i; auto const page = idxPage < nbPages ? 
cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; if (warpElectSync()) { pages[i] = page; } } idxTileRef = idxTile; __syncwarp(); #endif } __device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) { constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; #if USE_PAGED_KV_CACHE uint32_t const idxTile = pos / nbTokens; assert(idxTile == idxTileRef); uint32_t const offset = pos % tokensPerPage; return cacheList.pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + offset]; #else // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; #endif } #if SWAP_AB // smemColMax is persistent across multiple iterations __device__ inline RegColWiseVec computeWarpGrpColMax_sync( CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src) { auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { #pragma unroll for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { colMax[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : fmax(colMax[n][j], src(m, n)(i, j)); } } } } #pragma unroll for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll for (uint32_t j = 0; j < 2; j++) { auto& x = colMax[n][j]; x = fmax(x, __shfl_xor_sync(~0U, x, xorMask)); } } } uint32_t const lane = laneId(); if (lane < 4) { #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll for (uint32_t j = 0; j < 2; j++) { atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); } } } warpGrpBar.arrive_and_wait(); uint32_t const idxInQuad = lane % 4; #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; } } warpGrpBar.arrive_and_wait(); return colMax; } __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec) { RegColWiseVec ret; constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); auto const idx = laneId() % nbThrdsPerInstNBase; #pragma unroll for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); ret[i] = reinterpret_cast< Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( smemVec)[i * nbThrdsPerInstNBase + idx]; } return ret; } __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd) { uint32_t const idxInQuad = laneId() % 4; uint32_t const idxQuad = laneId() / 4; #pragma unroll for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { uint32_t const row = 64 * m + 16 * warpRank + 8 * i + idxQuad; if (row >= validRowBeg && row < validRowEnd) { continue; } #pragma unroll for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { acc(m, n)(i, j) = mha::numeric_limits::lowest(); } } } } } __device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax) { #pragma unroll for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll for (uint32_t j 
= 0; j < GmmaAccCoreMat::cols; j++) { float const maxVal = colMax[n][j]; float const bias = maxVal * log2e; #pragma unroll for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { float& elem = acc(m, n)(i, j); assert(maxVal >= elem); elem = exp2f(elem * log2e - bias); } } } } } __device__ inline RegColWiseVec computeWarpColSum(Gemm0Acc& src) { auto colSum = RegColWiseVec::filled(Vec::filled(0)); #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { #pragma unroll for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { colSum[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : colSum[n][j] + src(m, n)(i, j); } } } } #pragma unroll for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { auto& x = colSum[n][j]; x += __shfl_xor_sync(~0U, x, xorMask); } } } return colSum; } __device__ inline void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc) { #if CACHE_ELEM_ENUM == 0 using F16Acc = Array2D, Gemm0Acc::rows, Gemm0Acc::cols>; F16Acc f16Acc; reinterpret_cast&>(f16Acc) = convert(reinterpret_cast const&>(acc)); static_assert(Gemm0Acc::size == 1 || Gemm0Acc::size % 2 == 0); uint32_t const idxHalf = lane / 16; uint32_t const idxInHalf = lane % 16; uint32_t const idxOctInsideHalf = idxInHalf / 8; uint32_t const idxRowInsideOct = lane % 8; uint32_t const warpBaseC = 16 * warpRank; auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair { uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; return {accR, accC}; }; auto const getDstAddr = [&](uint32_t idxAccCoreMat) -> LdGrain* { auto const [accR, accC] = toAccCoords(idxAccCoreMat); static_assert(sizeof(MathElem) * gemm0CtaTileNbTokens == xPartBytes); uint32_t const idxPart = 0; uint32_t const dstR = accC * 8 + idxRowInsideOct; uint32_t const dstC = exactDiv(gmma::instM * accR + warpBaseC + 8 * idxOctInsideHalf, cacheElemsPerGrain); assert(dstC / exactDiv(xPartBytes, grainBytes) == idxPart); return &smemX[idxPart].template at(dstR, dstC); }; auto const getAccData = [&](uint32_t idxAccCoreMat) { auto const [accR, accC] = toAccCoords(idxAccCoreMat); return f16Acc(accR, accC); }; barConsumed.arrive_and_wait(); #pragma unroll for (uint32_t iter = 0; iter < Gemm0Acc::size / 2; iter++) { auto const dstAddr = getDstAddr(iter * 2 + idxHalf); Vec const data[2] = {getAccData(iter * 2), getAccData(iter * 2 + 1)}; stmatrix(dstAddr, reinterpret_cast(data)); } if constexpr (Gemm0Acc::size % 2 != 0) { auto const dstAddr = lane < 16 ? 
getDstAddr(Gemm0Acc::size - 1) : nullptr; stmatrix(dstAddr, getAccData(Gemm0Acc::size - 1)); } #elif CACHE_ELEM_ENUM == 2 using F8Acc = Array2D; F8Acc f8Acc; #pragma unroll for (uint32_t i = 0; i < acc.rows; i++) { #pragma unroll for (uint32_t j = 0; j < acc.cols; j++) { auto const& core = acc(i, j); static_assert(mha::is_same_v); Vec const f8Data = {__nv_cvt_float2_to_fp8x2(float2{core(0, 0), core(1, 0)}, __NV_SATFINITE, __NV_E4M3), __nv_cvt_float2_to_fp8x2(float2{core(0, 1), core(1, 1)}, __NV_SATFINITE, __NV_E4M3)}; f8Acc(i, j) = reinterpret_cast(f8Data); } } if constexpr (F8Acc::size == 4 || F8Acc::size == 2 || F8Acc::size == 1) { LdGrain* dst = nullptr; if (F8Acc::size == 4 || lane < 8 * F8Acc::size) { uint32_t const idxCore = lane / 8; uint32_t const srcRow = idxCore / F8Acc::cols; uint32_t const srcCol = idxCore % F8Acc::cols; uint32_t const dstCoreRow = lane % 8; uint32_t const dstRow = srcCol * 8 + dstCoreRow; BoundedVal const dstCol{srcRow * 4 + warpRank}; dst = &smemX[dstCol.template divBy().get()].template at( dstRow, dstCol.template mod().get()); } barConsumed.arrive_and_wait(); stmatrix(dst, reinterpret_cast const&>(f8Acc)); } else { // we need to use loops assert(false); trap(); } #endif } #else __device__ inline RegRowWiseVec warpRowWiseReduce( RegRowWiseVec const& init, Gemm0Acc const& src, float (*op)(float, float)) { RegRowWiseVec vec = init; #pragma unroll for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { #pragma unroll for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { // @fixme: check if compiler is reordering these op to hide latency. vec[m][i] = op(vec[m][i], src(m, n)(i, j)); } } } } #pragma unroll for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) { #pragma unroll for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { auto& x = vec[m][i]; x = op(x, __shfl_xor_sync(~0U, x, xorMask)); } } } return vec; } __device__ inline RegRowWiseVec computeWarpGrpRowMax_sync( uint32_t warpRank, ShmQWiseVec& smemRowMax, Gemm0Acc const& src) { assert(warpRank < 4); RegRowWiseVec const init = loadShmRowWiseVecWithDup(warpRank, smemRowMax); RegRowWiseVec rowMax = warpRowWiseReduce(init, src, fmax); storeShmRowWiseVec(warpRank, smemRowMax, rowMax); __syncwarp(); return rowMax; } #if SPEC_DEC __device__ inline void warpGrpApplyMask( Gemm0Acc& acc, SpecDec const& specDec, uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { static_assert(!SLIDING_WINDOW, "SpecDec is not supported for sliding window"); constexpr uint32_t tileSize = gemm0CtaTileNbTokens; uint32_t const offset = tileSize * idxTile; uint32_t const nbValidCols = mha::min(offset < cacheSeqLen ? cacheSeqLen - offset : 0U, tileSize); bool const ctaNeedEndMask = (nbValidCols < tileSize); bool const ctaNeedSpecDecMask = specDec.needMask(idxTile, 0); bool const needMask = ctaNeedEndMask || ctaNeedSpecDecMask; if (!needMask) { return; } static_assert(tileSize == 64, "not implemented"); auto const endMask = (~uint64_t{0} >> (tileSize - nbValidCols)); uint32_t const idxInQuad = laneId() % 4; uint32_t const idxQuad = laneId() / 4; #pragma unroll for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; auto const specDecMask = specDec.needMask(idxTile, row / headGrpSize) ? 
specDec.loadTileMaskRow(idxTile, row / headGrpSize) : SpecDec::TileMaskRow{~0U, ~0U}; auto const mask = endMask & reinterpret_cast(specDecMask); if (mask == ~uint64_t{0}) { continue; } #if DBG_PRINT if (idxInQuad == 0) { printf("mask at row %d: %lx\n", row, mask); } #endif #pragma unroll for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; assert((col < nbValidCols) == bool(endMask & (1ULL << col))); if (((mask >> col) & 1) == 0) { acc(m, n)(i, j) = mha::numeric_limits::lowest(); } } } } } } #else __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd) { uint32_t const idxInQuad = laneId() % 4; #pragma unroll for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; if (col >= validColBeg && col < validColEnd) { continue; } #pragma unroll for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { acc(m, n)(i, j) = mha::numeric_limits::lowest(); } } } } } #endif __device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& rowMax) { #pragma unroll for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { float const maxVal = rowMax[m][i]; float const bias = maxVal * log2e; #pragma unroll for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { float& elem = acc(m, n)(i, j); assert(maxVal >= elem); elem = exp2f(elem * log2e - bias); } } } } } __device__ inline RegRowWiseVec computeWarpRowSum(Gemm0Acc& src) { return warpRowWiseReduce(RegRowWiseVec{}, src, [](float a, float b) { return a + b; }); } __device__ inline RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec) { RegRowWiseVec vec; uint32_t const idxQuad = laneId() / 4; #pragma unroll for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { #pragma unroll for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { vec[m][i] = smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad]; } } return vec; } __device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, RegRowWiseVec const& regVec) { uint32_t const lane = laneId(); uint32_t const idxQuad = lane / 4; uint32_t const idxInQuad = lane % 4; bool const enable = (idxInQuad == 0); #pragma unroll for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { #pragma unroll for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { assert(__shfl_sync(~0U, regVec[m][i], idxQuad * 4) == regVec[m][i]); if (enable) { smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad] = regVec[m][i]; } } } } // for X // order: 0,8,1,9, 2,10,3,11, 4,12,5,13, 6,14,7,15, ... 
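// Each lane packs the column pairs (2*n, 2*n+1) of one accumulator core matrix into a single fp8x4 word before the
// stmatrix_4x store (see storeGemm0AccToShm below). A lane whose quad index covers token columns c and c+1 therefore
// stores bytes in the order {c, c+8, c+1, c+9}, which is where the interleaved token order listed above comes from;
// gemm1 reads X back in this swizzled layout.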
__device__ inline void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc) { uint32_t const idxMat = lane / 8; uint32_t const idxRow = lane % 8; barConsumed.arrive_and_wait(); #pragma unroll for (uint32_t m = 0; m < Gemm0Acc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { Vec fp8Data; #pragma unroll for (uint32_t n = 0; n < exactDiv(Gemm0Acc::cols, 2); n++) { reinterpret_cast&>(fp8Data[n]) = {__nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0)}), __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)})}; } static_assert(decltype(fp8Data)::size == 4); stmatrix_4x( this_warp(), &smemX[m].template at(16 * warpRank + 8 * i + idxRow, idxMat), fp8Data); } } } #endif #if SWAP_AB __device__ inline Vec loadVTileTransposed( uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK) { Vec fragA; constexpr uint32_t instK = gmma::instK; #pragma unroll for (uint32_t i = 0; i < gemm1NbGmmaInstM; i++) { static_assert(exactDiv(gmma::instM, gmmaWarpsPerGrp) == grainBytes); constexpr uint32_t grainsPerPart = exactDiv(cacheHeadPartBytes, grainBytes); #if CACHE_ELEM_ENUM == 0 uint32_t idxRow = lane % 8; uint32_t idxMat = lane / 8; uint32_t c = idxMat % 2; uint32_t r = idxMat / 2; auto const col = BoundedVal<2 * gmmaWarpsPerGrp * gemm1NbGmmaInstM>{2 * (gmmaWarpsPerGrp * i + warpRank) + c}; auto const src = &smemV[col.template divBy().get()].template at( instK * idxGmmaInstK + 8 * r + idxRow, col.template mod().get()); auto const data = ldmatrix(src); fragA[i] = reinterpret_cast(data); #elif CACHE_ELEM_ENUM == 2 auto const col = BoundedVal{gmmaWarpsPerGrp * i + warpRank}; LdGrain const* src = &smemV[col.template divBy().get()].template at( instK * idxGmmaInstK + lane, col.template mod().get()); auto const data = ldmatrix(src); fragA[i](0, 0)(0, 0) = prmt(data[0], data[1], {0, 4, 2, 6}); fragA[i](0, 0)(1, 0) = prmt(data[0], data[1], {1, 5, 3, 7}); fragA[i](0, 1)(0, 0) = prmt(data[2], data[3], {0, 4, 2, 6}); fragA[i](0, 1)(1, 0) = prmt(data[2], data[3], {1, 5, 3, 7}); #endif } return fragA; } #else __device__ inline void transposeVTile( uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src) { uint32_t const idxMat = lane / 8; uint32_t const idxRow = lane % 8; #pragma unroll for (uint32_t m = 0; m < exactDiv(SharedMem::VTBuffer::rows, gmma::instM); m++) { static_assert(cacheHeadPartElems >= gmma::instM); uint32_t const idxPart = gmma::instM * m / cacheHeadPartElems; constexpr uint32_t grainsPerCacheHeadPart = exactDiv(cacheHeadPartElems, cacheElemsPerGrain); #pragma unroll for (uint32_t n = 0; n < exactDiv(SharedMem::VTBuffer::cols, 2); n++) { LdGrain const a = ldmatrix_4x(this_warp(), &src[idxPart].template at(32 * n + lane, exactDiv(gmma::instM, cacheElemsPerGrain) * m - grainsPerCacheHeadPart * idxPart + warpRank)); LdGrain const b = {prmt(a[0], a[1], {0, 4, 2, 6}), prmt(a[0], a[1], {1, 5, 3, 7}), prmt(a[2], a[3], {0, 4, 2, 6}), prmt(a[2], a[3], {1, 5, 3, 7})}; uint32_t const i = idxMat % 2; uint32_t const j = idxMat / 2; stmatrix_4x( this_warp(), &dst.template at(gmma::instM * m + 16 * warpRank + 8 * i + idxRow, 2 * n + j), b); } } } #endif #if SWAP_AB __device__ inline Vec loadShmColWiseVecNoDup(ShmQWiseVec const& shmVec) { Vec ret; #pragma unroll for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) { uint32_t const idx = i * warp_size + laneId(); bool const inBound = ((ShmQWiseVec::size % 
warp_size == 0) || (idx < ShmQWiseVec::size)); ret[i] = (inBound ? shmVec[idx] : 0); } return ret; } __device__ inline void storeShmColWiseVecNoDup( ShmQWiseVec& shmVec, Vec const& src) { #pragma unroll for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) { uint32_t const idx = i * warp_size + laneId(); bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); if (inBound) { shmVec[idx] = src[i]; } } } #else __device__ inline Vec loadShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec const& shmVec) { constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); Vec ret; uint32_t const lane = laneId(); uint32_t const idxHalf = lane / (gmma::instM / 4); uint32_t const idxInHalf = lane % (gmma::instM / 4); #pragma unroll for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) { uint32_t const idx = gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || (idx < ShmQWiseVec::size)); ret[i] = (inBound ? shmVec[idx] : 0); } return ret; } __device__ inline void storeShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec& shmVec, Vec const& src) { constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); Vec ret; uint32_t const lane = laneId(); uint32_t const idxHalf = lane / (gmma::instM / 4); uint32_t const idxInHalf = lane % (gmma::instM / 4); #pragma unroll for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) { uint32_t const idx = gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || (idx < ShmQWiseVec::size)); if (inBound) { shmVec[idx] = src[i]; } } } #endif #if SWAP_AB __device__ inline void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, CtaBarrier& gemm1WarpGrpBar) { auto accColSum = loadShmColWiseVecNoDup(shmAccColSum); auto const xColMax = loadShmColWiseVecNoDup(shmXColMax); auto const accColMax = loadShmColWiseVecNoDup(shmAccColMax); auto token = gemm1WarpGrpBar.arrive(); auto const needRescaleVec = (accColMax < xColMax); UniformNeedRescaleMask rescaleMask; bool anyNeedRescale = false; #pragma unroll for (uint32_t i = 0; i < rescaleMask.size; i++) { assert(accColMax[i] <= xColMax[i]); rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); } if (anyNeedRescale) { auto const scaleVec = expf(accColMax - xColMax); auto const lane = laneId(); #pragma unroll for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { uint32_t const vecIdx = gmma::instNBase * n / warp_size; uint32_t const offset = gmma::instNBase * n % warp_size; constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { auto const mask = ((rescaleMask[vecIdx] >> (offset + j)) & 0b01010101U); auto getScale = [&] { return __shfl_sync( ~0U, scaleVec[vecIdx], offset + lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols + j); }; assert((getScale() != 1) == ((mask >> lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols) & 0x1U)); bool const needRescale = (mask != 0); if (!needRescale) { // this branch is warp-uniform continue; } float const scale = getScale(); #pragma unroll for 
(uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { acc(m, n)(i, j) *= scale; } } } } accColSum = accColSum * scaleVec; } gemm1WarpGrpBar.wait(mha::move(token)); // @fixme: with atomic, we can let the first warp reaching here to do the update, instead of always warp 3. uint32_t const warpRankForUpdate = gmmaWarpsPerGrp - 1; if (warpRank == warpRankForUpdate) { if (anyNeedRescale) { storeShmColWiseVecNoDup(shmAccColMax, xColMax); } #pragma unroll for (uint32_t i = 0; i < gemm0NbWarps; i++) { accColSum = accColSum + loadShmColWiseVecNoDup(shmXColSum[i]); } storeShmColWiseVecNoDup(shmAccColSum, accColSum); } gemm1WarpGrpBar.arrive_and_wait(); } #else __device__ inline void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, ShmQWiseVec const& shmXRowSum, ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, ShmQWiseVec& shmAccRowSum) { auto accRowSum = loadShmRowWiseVecNoDup(warpRank, shmAccRowSum); auto const xRowMax = loadShmRowWiseVecNoDup(warpRank, shmXRowMax); auto const accRowMax = loadShmRowWiseVecNoDup(warpRank, shmAccRowMax); assert(all(xRowMax >= accRowMax)); auto const needRescaleVec = (accRowMax < xRowMax); UniformNeedRescaleMask rescaleMask; bool anyNeedRescale = false; #pragma unroll for (uint32_t i = 0; i < rescaleMask.size; i++) { assert(accRowMax[i] <= xRowMax[i]); rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); } if (anyNeedRescale) { auto const scaleVec = expf(accRowMax - xRowMax); auto const lane = laneId(); #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; bool const needRescale = (mask != 0); if (needRescale) { // this branch is warp-uniform float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); #pragma unroll for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { acc(m, n)(i, j) *= scale; } } } } } accRowSum = accRowSum * scaleVec; } __syncwarp(); auto const xRowSum = loadShmRowWiseVecNoDup(warpRank, shmXRowSum); storeShmRowWiseVecNoDup(warpRank, shmAccRowSum, accRowSum + xRowSum); storeShmRowWiseVecNoDup(warpRank, shmAccRowMax, xRowMax); __syncwarp(); } #endif #if SWAP_AB __device__ inline void rescaleAcc(Gemm1Acc& acc, RegColWiseVec const& scale) { #pragma unroll for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { acc(m, n)(i, j) *= scale[n][j]; } } } } } #else __device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) { #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { #pragma unroll for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { #pragma unroll for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { acc(m, n)(i, j) *= scale[m][i]; } } } } } #endif #if SWAP_AB // @fixme: consider make this noinline template __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc const& acc, CtaBarrier& warpGrpBar) { uint32_t const lane = laneId(); #if CACHE_ELEM_ENUM == 0 uint32_t const idxMat = lane / 8; uint32_t const 
idxRow = lane % 8; #elif CACHE_ELEM_ENUM == 2 uint32_t const idxQuad = lane / 4; uint32_t const idxInQuad = lane % 4; #endif #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { auto const& core = acc(m, n); #if CACHE_ELEM_ENUM == 0 Vec f16Core; reinterpret_cast&>(f16Core) = convert(reinterpret_cast const&>(core)); auto const dst = idxMat < 2 ? &swizzleBuf.template at(idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat) : nullptr; stmatrix(dst, f16Core); #elif CACHE_ELEM_ENUM == 2 // each row is part of a b16 8x8 matrix and is transposed Array2D coreTrans; for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { static_assert(GmmaAccCoreMat::cols == 2 && sizeof(InputElem) == 2); InputElem2 const coreRow = float2ToInputElem2({core(i, 0), core(i, 1)}); auto const coreRowTrans = movmatrix(reinterpret_cast(coreRow)); reinterpret_cast(coreTrans(i, 0)) = coreRowTrans; } // expect compiler to generate two PRMT instructions Vec const data = {coreTrans(0, 0), coreTrans(1, 0), coreTrans(0, 1), coreTrans(1, 1)}; swizzleBuf.template at(gmma::instNBase * n + idxQuad, (gmma::instM * m + exactDiv(gmma::instM, gmmaWarpsPerGrp) * warpRank) / 16)[idxInQuad] = data; #endif } } warpGrpBar.arrive_and_wait(); constexpr uint32_t headsPerIter = exactDiv(grainBytes * gemm1NbThrds, paddedInputHeadBytes); constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; constexpr uint32_t nbGrainsPerHead = exactDiv(paddedInputHeadBytes, grainBytes); uint32_t const idxHeadBase = threadRank / nbGrainsPerHead; uint32_t const idxGrain = threadRank % nbGrainsPerHead; #pragma unroll for (uint32_t iter = 0; iter < nbIters; iter++) { uint32_t const idxHead = idxHeadBase + iter * headsPerIter; if ((iter < nbWholeIters || idxHead < ctaNbValidQHeads) && (!isHeadPadded || idxGrain < grainsPerIOHead)) { #if CACHE_ELEM_ENUM == 0 auto const data = swizzleBuf.template at(idxHead, idxGrain); #elif CACHE_ELEM_ENUM == 2 auto const data = reinterpret_cast&>(swizzleBuf.template at(idxHead, idxGrain / 2))[idxGrain % 2]; #endif constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize); auto const outVec = convert(reinterpret_cast const&>(data)); reinterpret_cast, nbGrainsPerHead>&>(dst[idxHead])[idxGrain] = outVec; } } } template __device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum) { // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of // mufu.rcp"); auto const regColSum = loadShmColWiseVecWithDup(accColSum); auto const regOutScale = __frcp_rn(regColSum) * xvoScale; rescaleAcc(acc, regOutScale); saveTransposedOutput(threadRank, warpRank, dst, swizzleBuf, acc, warpGrpBar); warpGrpBar.arrive_and_wait(); } #else template __device__ inline void finalizeAndWriteOut_sync(uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, ShmQWiseVec const& accRowSum, uint32_t nbKHeads /* for spec dec. 
set to 1 for workspace*/, uint32_t ctaNbValidTokens) { auto const regRowSum = loadShmRowWiseVecWithDup(warpRank, accRowSum); auto const regOutScale = __frcp_rn(regRowSum) * xvoScale; rescaleAcc(acc, regOutScale); using DstElem = typename DstHead::Elem; auto const lane = laneId(); uint32_t const idxQuad = lane / 4; uint32_t const idxInQuad = lane % 4; using Atom = Vec, 4>; using SwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; static_assert(sizeof(SwizzleBuf) <= sizeof(swizzleBuf)); auto& buf = reinterpret_cast(swizzleBuf); #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { uint32_t const r = gmma::instM * m + 16 * warpRank + 8 * i + idxQuad; static_assert(SwizzleBuf::cols == exactDiv(Gemm1Acc::cols, 2)); #pragma unroll for (uint32_t n = 0; n < exactDiv(Gemm1Acc::cols, 2); n++) { Vec const v = convert(Vec{ acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0), acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)}); //@fixme: without reinterpret_cast to V, the compiler generates wrong code, and require a __syncwarp() // after rescaleAcc() to work around. Likely a bug of the compiler. //@todo: report a compiler bug. using V = Vec; reinterpret_cast(buf.template at(r, n)[idxInQuad]) = reinterpret_cast(v); // buf.template at(r, n)[idxInQuad] = v; } } } __syncwarp(); #pragma unroll for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { constexpr uint32_t srcHeadBytes = sizeof(DstElem) * headElems; constexpr uint32_t grpSize = exactDiv(srcHeadBytes, grainBytes); constexpr uint32_t nbGrps = exactDiv(warp_size, grpSize); uint32_t const idxGrp = lane / grpSize; constexpr uint32_t grainsPerAtom = exactDiv(sizeof(Atom), grainBytes); uint32_t const rowBase = gmma::instM * m + 16 * warpRank; constexpr uint32_t totalNbGrains = grainsPerAtom * SwizzleBuf::cols * 16; uint32_t const nbIters = divUp(totalNbGrains, nbGrps); constexpr bool wholeIters = (totalNbGrains % nbGrps == 0); constexpr bool wholeHeads = (validElemsPerHead == headElems); #pragma unroll for (uint32_t iter = 0; iter < nbIters; iter++) { uint32_t const idxGrain = nbGrps * iter + idxGrp; constexpr uint32_t grainsPerSrcHead = exactDiv(srcHeadBytes, grainBytes); uint32_t const r = idxGrain / grainsPerSrcHead; if (!wholeIters && r >= 16) { break; } uint32_t const cGrain = idxGrain % grainsPerSrcHead; uint32_t const cAtom = cGrain / grainsPerAtom; constexpr uint32_t grainsPerDstHead = exactDiv(sizeof(DstHead), grainBytes); uint32_t const glbRow = gmma::instM * m + 16 * warpRank + r; if (ctaNbValidQHeads != ctaNbQHeads && glbRow >= ctaNbValidQHeads) { break; } if (wholeHeads || cGrain < grainsPerDstHead) { uint32_t const srcRow = rowBase + r; auto const data = reinterpret_cast( buf.template at(srcRow, cAtom))[cGrain % grainsPerAtom]; #if SPEC_DEC static_assert(beamWidth == 1); uint32_t const idxToken = srcRow / headGrpSize; // inside CTA if (idxToken >= ctaNbValidTokens) { break; } uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); uint32_t const dstRow = srcRow + idxToken * tokenPad; #else uint32_t const dstRow = srcRow; #endif reinterpret_cast(dst[dstRow])[cGrain] = data; } } } } #endif template __device__ inline Vec, ropeNbPairsPerThrd> loadHead( Vec const& head, uint32_t tid) { constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); static_assert(nbPairs % 
nbPairsPerThrd == 0); Vec, nbPairsPerThrd> ret; if constexpr (forNeox) { auto const& pairs = reinterpret_cast, nbWorkingThrds>, 2> const&>(head); auto const data = isWorkingThrd ? Vec, 2>{pairs[0][tid], pairs[1][tid]} : Vec, 2>{}; Vec, 2> const tmp = {convert(data[0]), convert(data[1])}; #pragma unroll for (uint32_t i = 0; i < nbPairsPerThrd; i++) { ret[i][0] = tmp[0][i]; ret[i][1] = tmp[1][i]; } } else { auto const data = isWorkingThrd ? reinterpret_cast, nbPairsPerThrd> const*>(&head)[tid] : Vec, nbPairsPerThrd>{}; #pragma unroll for (uint32_t i = 0; i < nbPairsPerThrd; i++) { ret[i] = convert(data[i]); } } return ret; } template __device__ inline mha::conditional_t, 2>, Vec, nbPairsPerThrd>> applyRoPE(Vec, nbPairsPerThrd> const& data, Vec, nbPairsPerThrd> const& ropeCosSin) { Vec, nbPairsPerThrd> r; #pragma unroll for (uint32_t i = 0; i < nbPairsPerThrd; i++) { float const x = data[i][0]; float const y = data[i][1]; float const c = ropeCosSin[i][0]; float const s = ropeCosSin[i][1]; r[i] = Vec{c * x - s * y, s * x + c * y}; } if constexpr (forNeox) { Vec, 2> tmp; #pragma unroll for (uint32_t i = 0; i < nbPairsPerThrd; i++) { tmp[0][i] = r[i][0]; tmp[1][i] = r[i][1]; } return Vec, 2>{convert(tmp[0]), convert(tmp[1])}; } else { Vec, nbPairsPerThrd> ret; #pragma unroll for (uint32_t i = 0; i < nbPairsPerThrd; i++) { ret[i] = convert(r[i]); } return ret; } } template __device__ inline void storeRotatedPairsForKV(GMemCacheHead& dst, mha::conditional_t>, 2>, Vec, ropeNbPairsPerThrd>> const& src, uint32_t tid) { constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); static_assert(nbPairs % nbPairsPerThrd == 0); if (!isWorkingThrd) { return; } if constexpr (forNeox) { auto& pairs = reinterpret_cast, nbWorkingThrds>, 2>&>(dst); pairs[0][tid] = src[0]; pairs[1][tid] = src[1]; } else { reinterpret_cast, nbPairsPerThrd>*>(&dst)[tid] = src; } } template __device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst, mha::conditional_t>, 2>, Vec, ropeNbPairsPerThrd>> const& src, uint32_t row, uint32_t tid) { constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); static_assert(nbPairs % nbPairsPerThrd == 0); if (isWorkingThrd) { if constexpr (forNeox) { #pragma unroll for (uint32_t i = 0; i < 2; i++) { auto const byteOffset = BoundedVal{cacheElemSize * nbPairsPerThrd * (nbWorkingThrds * i + tid)}; uint32_t const idxPart = byteOffset.template divBy().get(); auto const byteOffsetInsidePart = byteOffset.template mod(); uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); LdGrain& grain = dst[idxPart].template at(row, idxGrain); uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); static_assert( cacheElemSize * nbPairsPerThrd <= grainBytes && grainBytes % (cacheElemSize * nbPairsPerThrd) == 0); reinterpret_cast&>( reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src[i]; } } else { auto const byteOffset = BoundedVal{cacheElemSize * 2 * nbPairsPerThrd * tid}; uint32_t const idxPart = byteOffset.template divBy().get(); auto const byteOffsetInsidePart = byteOffset.template mod(); uint32_t const idxGrain = byteOffsetInsidePart.template 
divBy().get(); LdGrain& grain = dst[idxPart].template at(row, idxGrain); uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); static_assert(cacheElemSize * 2 * nbPairsPerThrd <= grainBytes && grainBytes % (cacheElemSize * 2 * nbPairsPerThrd) == 0); reinterpret_cast, nbPairsPerThrd>&>( reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src; } } static_assert(validElemsPerHead % 16 == 0); __syncwarp(); if constexpr (validElemsPerHead < headElems) { static_assert(validElemsPerHead >= headElems - exactDiv(headElems, nbQParts)); constexpr uint32_t nbPadGrainsPerHead = exactDiv(headElems - validElemsPerHead, cacheElemsPerGrain); constexpr uint32_t nbPadGrains = nbPadGrainsPerHead * ctaNbQHeads; uint32_t const nbIters = divUp(nbPadGrains, nbThrds); #pragma unroll for (uint32_t iter = 0; iter < nbIters; iter++) { uint32_t idx = tid + nbThrds * iter; if (idx >= nbPadGrains) { break; } uint32_t const r = idx / nbPadGrainsPerHead; uint32_t const c = grainsPerQPart - nbPadGrainsPerHead + idx % nbPadGrainsPerHead; dst[dst.size - 1].template at(r, c) = LdGrain{}; } } } #ifndef GENERATE_CUBIN constexpr uint32_t getElemBytes(CUtensorMapDataType_enum dataType) { switch (dataType) { case CU_TENSOR_MAP_DATA_TYPE_UINT8: return 1; case CU_TENSOR_MAP_DATA_TYPE_UINT16: return 2; case CU_TENSOR_MAP_DATA_TYPE_UINT32: return 4; case CU_TENSOR_MAP_DATA_TYPE_INT32: return 4; case CU_TENSOR_MAP_DATA_TYPE_UINT64: return 8; case CU_TENSOR_MAP_DATA_TYPE_INT64: return 8; case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: return 2; case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: return 4; case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: return 8; case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: return 2; case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: return 4; case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: return 4; case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: return 4; } throw std::runtime_error("unsupported data type"); } static CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t nbKHeads, uint32_t maxCacheLen, uint32_t beamWidth, uint32_t batchSize, uint32_t nbTokensPerTile) { CUtensorMap tensorMap{}; uint64_t const globalDims[] = {headElems, maxCacheLen, nbKHeads, 2 * beamWidth * batchSize}; uint32_t elemBytes = getElemBytes(dataType); uint32_t const headBytes = elemBytes * headElems; uint64_t const globalStrides[] = {headBytes, headBytes * maxCacheLen, headBytes * maxCacheLen * nbKHeads}; assert(headElems <= 256); uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 
128 : 256);
    uint32_t const partElems = mha::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
    uint32_t const boxDims[] = {partElems, nbTokensPerTile, 1, 1};
    uint32_t const elemStrides[] = {1, 1, 1, 1};
    auto const swizzle = [&]
    {
        switch (partElems)
        {
        case 128: return CU_TENSOR_MAP_SWIZZLE_128B;
        case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
        default: throw std::runtime_error("unsupported cache head size");
        }
    }();
    checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims, globalStrides,
        boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensorMap;
}

static CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType,
    uint32_t headElems, uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t nbTokensPerTile)
{
    CUtensorMap tensorMap{};
    uint32_t elemBytes = getElemBytes(dataType);
    uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31};
    uint32_t const headBytes = elemBytes * headElems;
    uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, headBytes * tokensPerPage * nbKHeads};
    assert(headElems <= 256);
    uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
    uint32_t const partBytes = mha::min(elemBytes * paddedHeadElems, 128U);
    uint32_t const partElems = partBytes / elemBytes;
    uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1};
    uint32_t const elemStrides[] = {1, 1, 1, 1};
    auto const swizzle = [&]
    {
        switch (partBytes)
        {
        case 128: return CU_TENSOR_MAP_SWIZZLE_128B;
        case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
        default: throw std::runtime_error("unsupported cache head size");
        }
    }();
    checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims, globalStrides,
        boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensorMap;
}

void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
    uint32_t slidingWinSize,
#endif
    float qScale, OutputHead* output,
#if LOW_PREC_OUTPUT
    float const* rcpOutScale,
#endif
#if USE_INPUT_KV
    InputHead const* qkv,
#if ROPE_STYLE != 0
    Vec const* ropeCosSin,
#endif
#else
    InputHead const* q,
#endif
#if USE_PAGED_KV_CACHE
    GMemCacheHead* pool,                     // global pool of pages
    KVCachePageIndex const* kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq]
#else
    GMemKVCacheHead* kvCacheData,
#endif
    uint32_t maxSeqLen, uint32_t const* seqLen,
#if USE_BEAM_SEARCH
    BeamSearchParams const& beamSearchParams,
#endif
    uint32_t batchSize,
    float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
                                            // int8/fp8 KV cache.
#if SPEC_DEC
    SpecDecParams const& specDecParams,
#endif
    uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
    if (beamWidth != 1) { throw std::runtime_error("not implemented"); }
    static uint32_t const hostSmemSize = [&]()
    {
        uint32_t size;
        checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
        checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
        return size;
    }();
    // printf("smemSize = %u\n", hostSmemSize);
    uint32_t const nbVHeads = nbKHeads;
    uint32_t const nbQHeads = nbKHeads * headGrpSize;
    uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
    uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
    {
        auto const env = std::getenv("XQA_NB_SUB_SEQ");
        if (env != nullptr)
        {
            int32_t const val = std::stoi(env);
            if (val > 0) { return val; }
        }
        float const factor = 0.25f;
        return mha::min(
            mha::max(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)),
            divUp(maxSeqLen, gemm0CtaTileNbTokens));
    }();
#if SPEC_DEC
    uint32_t const qSeqLen = specDecParams.qSeqLen;
#else
    uint32_t const qSeqLen = 1;
#endif
    // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == nbInputSeqSplit
    dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
    dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
    auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_FDL != 0);
#if USE_PAGED_KV_CACHE
    uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
    KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
    auto const dtype = []
    {
        if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; }
        else if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; }
        else if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; }
        throw std::runtime_error("unsupported cache element type");
    }();
    auto const tensorMap = makeTensorMapForPagedKVCache(
        pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, gemm0CtaTileNbTokens);
    cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads,
#if SLIDING_WINDOW
        slidingWinSize,
#endif
        qScale, output,
#if LOW_PREC_OUTPUT
        rcpOutScale,
#endif
#if USE_INPUT_KV
        qkv,
#if ROPE_STYLE != 0
        ropeCosSin,
#endif
#else
        q,
#endif
        cacheList,
#if USE_BEAM_SEARCH
        beamSearchParams,
#endif
        batchSize, kvCacheScale, tensorMap,
#if SPEC_DEC
        specDecParams,
#endif
        semaphores, scratch);
#else
    KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen};
    static_assert(!usePagedKVCache);
    assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens);
    auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8,
        validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, gemm0CtaTileNbTokens);
    cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads,
#if SLIDING_WINDOW
        slidingWinSize,
#endif
        qScale, output,
#if LOW_PREC_OUTPUT
        rcpOutScale,
#endif
#if USE_INPUT_KV
        qkv,
#if ROPE_STYLE != 0
        ropeCosSin,
#endif
#else
        q,
#endif
        cacheList,
#if USE_BEAM_SEARCH
        beamSearchParams,
#endif
        batchSize, kvCacheScale, tensorMap, semaphores, scratch);
#endif
    checkCuda(err);
}
#endif
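// Hedged usage sketch (kept under "#if 0", never compiled): one way the launcher above could be driven in the
// paged-KV configuration with the optional features (SPEC_DEC, SLIDING_WINDOW, LOW_PREC_OUTPUT, USE_INPUT_KV,
// USE_BEAM_SEARCH) disabled. All sizes, buffer shapes, and the elided allocation code are illustrative
// assumptions, not part of this file; types come from mha.h.
#if 0
void exampleDecodeStep(cudaStream_t stream)
{
    // Illustrative problem shape (assumed values).
    uint32_t const batchSize = 8;
    uint32_t const nbKHeads = 8;     // K/V heads; Q heads per K head = headGrpSize
    uint32_t const maxSeqLen = 4096; // assumed to be a multiple of tokensPerPage (see exactDiv above)

    cudaDeviceProp prop{};
    checkCuda(cudaGetDeviceProperties(&prop, 0));

    // Device buffers; allocation and population are elided.
    OutputHead* output = nullptr;                      // attention output heads
    InputHead const* q = nullptr;                      // query heads
    GMemCacheHead* pool = nullptr;                     // global pool of KV cache pages
    KVCachePageIndex const* kvCachePageList = nullptr; // [batchSize][beamWidth][2][maxNbPagesPerSeq]
    uint32_t const* seqLen = nullptr;                  // per-sequence KV cache lengths
    float const* kvCacheScale = nullptr;               // device scalar shared by K and V (fp8/int8 cache only)
    uint32_t* semaphores = nullptr;                    // used by the multi-block (nbSubSeqPerSeq > 1) path
    void* scratch = nullptr;                           // workspace for partial results

    launchHopperF8MHA(prop, nbKHeads, /*qScale=*/1.f, output, q, pool, kvCachePageList, maxSeqLen, seqLen,
        batchSize, kvCacheScale, semaphores, scratch, stream);
}
#endif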