/* * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. */ #include "defines.h" #include "mha.h" #if IS_MLA #include "barriers.cuh" #include "mhaUtils.cuh" #include "mha_components.cuh" #include "mha_stdheaders.cuh" #include "mla_sm120.cuh" #include "mma.cuh" #include "tma.h" #include "utils.cuh" #include "utils.h" #ifndef GENERATE_CUBIN #include "hostUtils.h" #include "tensorMap.h" #include #endif __constant__ constexpr XQAKernelType kernelType = XQAKernelType::kSM120_MLA; inline constexpr bool allowMultipleInputTokens = true; inline constexpr uint32_t partElemsK = 64; // @fixme: change this to 128 to save L2 traffic inline constexpr uint32_t nbKParts = exactDiv(validElemsPerKHead, partElemsK); inline constexpr uint32_t nbQParts = nbKParts; inline constexpr uint32_t tokensPerTile = 64; inline constexpr uint32_t partElemsV = 128; inline constexpr uint32_t nbVSplit = 2; inline constexpr uint32_t gemm1V = exactDiv(validElemsPerVHead, nbVSplit); inline constexpr uint32_t nbProducerCtasPerCga = nbVSplit; inline constexpr uint32_t multiBlockMinNbTilesPerCta = 2; inline constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2; using MathElem = CacheElem; inline constexpr uint32_t mathElemBytes = sizeof(MathElem); inline constexpr uint32_t grainsPerPartK = exactDiv(partElemsK * mathElemBytes, grainBytes); inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes); inline constexpr float xScale = 1.f / kE4M3_MAX; __constant__ constexpr float rcpXScale = kE4M3_MAX; inline constexpr uint32_t nbRegsForIOWarps = 32; inline constexpr uint32_t nbRegsForMathWarps = 232; inline constexpr bool computeRowSumFromF8 = true; struct KVTilePartLoader { #if USE_PAGED_KV_CACHE static_assert(tokensPerPage % tokensPerTile == 0 || tokensPerTile % tokensPerPage == 0); static inline constexpr uint32_t nbPagesPerTile = tokensPerTile >= tokensPerPage ? exactDiv(tokensPerTile, tokensPerPage) : 1; #endif static inline constexpr uint32_t const nbKHeads = 1; KVCacheList const& cacheList; uint32_t const idxReq; static inline constexpr uint32_t const idxHeadGrp = 0; CUtensorMap const& tensorMap; // if greater than 1, then we need unrolling for the loading loop. Seems 1 is fine for latency. 
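// KVTilePartLoader resolves the page indices for a tile first (loadPages) and then issues TMA
// copies of 64-token tile slices into shared memory (loadData); pages beyond nbPages are replaced
// with kBAD_PAGE_INDEX so the bound check happens once per tile rather than once per part.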
static inline constexpr uint32_t nbPageBuffers = 1; #if USE_PAGED_KV_CACHE uint32_t const nbPages; // for bound check Vec pageBuffers[nbPageBuffers]; uint32_t idxTileRef; // idxTile used to load the pages #endif uint32_t const baseOffset; __device__ KVTilePartLoader( KVCacheList const& cacheList, uint32_t idxReq, CUtensorMap const& tensorMap #if USE_PAGED_KV_CACHE , uint32_t nbPages #endif ); // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ void loadData(Array2D& dst, uint32_t idxTile, uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf); __device__ void loadPages(uint32_t idxTile, uint32_t idxPageBuf); }; __device__ inline KVTilePartLoader::KVTilePartLoader( KVCacheList const& cacheList, uint32_t idxReq, CUtensorMap const& tensorMap #if USE_PAGED_KV_CACHE , uint32_t nbPages #endif ) : cacheList{cacheList} , idxReq{idxReq} , tensorMap{tensorMap} #if USE_PAGED_KV_CACHE , nbPages{nbPages} , baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq} #else , baseOffset{(idxReq * beamWidth) * 2} #endif { } // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ inline void KVTilePartLoader::loadData(Array2D& dst, uint32_t idxTile, uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf) { static_assert(nbTokens == tokensPerTile); #if USE_PAGED_KV_CACHE assert(idxTile == idxTileRef); auto const& pages = pageBuffers[idxPageBuf]; if constexpr (nbTokens < tokensPerPage) { assert(nbPagesPerTile == 1); uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); if (warpElectSync()) { tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, offset, idxHeadGrp, (uint32_t) pages[0]}, bar); } } else { #pragma unroll for (uint32_t i = 0; i < nbPagesPerTile; i++) { if (warpElectSync()) { tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, DimsLE<4>{idxElemBeg, 0, idxHeadGrp, (uint32_t) pages[i]}, bar); } } } #else if (warpElectSync()) { tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); } #endif } __device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile, uint32_t idxPageBuf) { #if USE_PAGED_KV_CACHE uint32_t const idxPageBeg = tokensPerTile >= tokensPerPage ? nbPagesPerTile * idxTile : idxTile / exactDiv(tokensPerPage, tokensPerTile); auto& pages = pageBuffers[idxPageBuf]; #pragma unroll for (uint32_t i = 0; i < nbPagesPerTile; i++) { uint32_t const idxPage = idxPageBeg + i; pages[i] = idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; } idxTileRef = idxTile; #endif } using Mat16x32 = Vec; template class Mat16x32Loader { public: using Src = Array2D; // default r and c are for mat A. 
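// With ldmatrix.x4 over a 16-row x 32-byte tile, lanes 0..15 provide the shared-memory addresses
// of the first 16-byte column (r = lane % 16, c = 0) and lanes 16..31 those of the second column
// (c = 1), which is the mat-A fragment layout expected by the fp8 QMMA; B-side users (e.g. the K
// tile in Producer::compute) pass their own r/c mapping instead.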
__device__ inline Mat16x32Loader( Src const& src, uint32_t baseRow, uint32_t idxInstK, uint32_t r = laneId() % 16, uint32_t c = laneId() / 16) : src{src} , baseRow{baseRow} , idxInstK{idxInstK} , r{r} , c{c} , basePtr{getPtrRef(0)} { static_assert((grainBytes * srcCols * qmmaShape.m) % 1024 == 0); } __device__ inline Mat16x32 load(uint32_t idxInstM) const { return ldmatrix(getPtr(idxInstM)); } template __device__ inline Vec loadWholeCol() const { uint32_t const nbInstM = exactDiv(tileM, qmmaShape.m); Vec ret; #pragma unroll for (uint32_t i = 0; i < nbInstM; i++) { ret[i] = load(i); } return ret; } __device__ inline LdGrain const* const getPtr(uint32_t idxInstM) const { return checkedVal(basePtr + idxInstM * qmmaShape.m * srcCols, getPtrRef(idxInstM)); } private: __device__ inline LdGrain const* const getPtrRef(uint32_t idxInstM) const { return &src.template at( baseRow + idxInstM * qmmaShape.m + r, idxInstK * exactDiv(qmmaShape.k, grainElems) + c); } Src const& src; uint32_t const baseRow; uint32_t const idxInstK; uint32_t const r; uint32_t const c; LdGrain const* const basePtr; }; using InstAcc = Array2D; using XBuffer = Array2D; struct CgaXBuffer { XBuffer x; Vec rowSum; }; struct PingPongMutex { using ShmStorage = CtaBarrier[2]; ShmStorage& barriers; uint32_t const idxGrp; static __device__ inline void initStorage(ShmStorage& barriers, uint32_t thrdsPerGrp) { new (&barriers[0]) CtaBarrier(thrdsPerGrp); new (&barriers[1]) CtaBarrier(thrdsPerGrp); barriers[0].arrive(thrdsPerGrp); } __device__ inline PingPongMutex(ShmStorage& shmStorage, uint32_t idxGrp) : barriers{shmStorage} , idxGrp{idxGrp} { } __device__ inline void lock(uint32_t iter) { barriers[idxGrp].wait_parity(toParity<1>(iter)); } __device__ inline void unlock() { barriers[idxGrp ^ 1U].arrive(); } }; struct PartialResult { static constexpr uint32_t nbChunks = 4; static constexpr uint32_t nbRowsPerChunk = exactDiv(headGrpSize, nbChunks); struct Chunk { Vec data; Vec rowSum; Vec rowMaxLog2e; }; Chunk chunks[nbChunks]; }; constexpr uint32_t nbMathWarpsA = 8; constexpr uint32_t nbComputeWarpsB = 8; constexpr uint32_t nbMathGrpsA = 2; constexpr uint32_t nbMathWarpsB = 8; constexpr uint32_t nbMultiBlockBufs = 2; constexpr uint32_t multiBlockMathWarps = 8; #define USE_REG_Q 1 struct SharedMemA { static inline constexpr uint32_t nbKBufs = 4; static inline constexpr uint32_t nbXBufs = 2; #if USE_REG_Q static inline constexpr uint32_t regQParts = 2; #else static inline constexpr uint32_t regQParts = 0; #endif static inline constexpr uint32_t shmQParts = nbQParts - regQParts; using ShmQPart = Array2D; using ShmKPart = Array2D; Vec q; ShmKPart k[nbKBufs]; XBuffer x[nbXBufs]; Vec rowSum[nbXBufs]; Vec drain; // data does not matter. Used to help avoid fence. // scaled by log2e. Write by last CGA iteration (from the other producer CTA) and read by current producer CTA. Vec rowMaxLog2e; // sync rowMaxLog2e between two producer CTAs and .consumed means the buffer for next iteration (in next producer) // is ready. The 4 groups from 2 producers CTAs form a ring CgaBarrier rowMaxLog2eBar[nbMathGrpsA]; PingPongMutex::ShmStorage tensorCoreMutex; CtaBarrierPair kBars[nbKBufs]; CtaBarrierPair xBars[nbXBufs]; #if USE_REG_Q static constexpr uint32_t nbRegQBars = 2; CtaBarrierPair regQBars[nbRegQBars]; #endif CtaBarrier shmQBar; CgaBarrier cgaXBufConsumed; // for X PingPongMutex::ShmStorage rowMaxTransferMutex; // protect the order of rowMax transfer to consumers CgaBarrier consumerRowMaxConsumedBar; // arrive by consumer CTAs. 
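// Used only when nbSubSeq > 1: double-buffered barriers pacing the per-chunk partial results that
// mergePartialOutputs() streams through shared memory during the final cross-sub-sequence merge.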
CtaBarrierPair multiBlockBars[nbMultiBlockBufs]; __device__ inline void invalidateBarriers(uint32_t thrdIdx) { constexpr uint32_t nbBars = USE_REG_Q ? 25 : 21; #ifndef __CUDACC_RTC__ constexpr uint32_t nbBarsRef = exactDiv(offsetof(SharedMemA, qkScaleLog2e) - offsetof(SharedMemA, rowMaxLog2eBar), 8); assert(nbBars == nbBarsRef); #endif if (thrdIdx < nbBars) { reinterpret_cast(&rowMaxLog2eBar[0])[thrdIdx].~CtaBarrier(); } } __device__ inline Vec& getMultiBlockBufs() { #ifndef __CUDACC_RTC__ assert(sizeof(Vec) < offsetof(SharedMemA, rowMaxLog2eBar)); #endif return *reinterpret_cast*>(this); } float qkScaleLog2e; bool isLastSubSeq; }; struct SharedMemB { static inline constexpr uint32_t nbXVBufs = 2; static inline constexpr uint32_t nbXBufs = nbXVBufs; static inline constexpr uint32_t nbVBufs = nbXVBufs; using VBuffer = Vec, exactDiv(gemm1V, partElemsV)>; // x and v are using gemmK=128 per iteration. If we see high pressure on shared memory capacity, we can change to 64 // in the future. struct XVBuffer { XBuffer x; VBuffer v; XBuffer pad; // for output swizzling }; XVBuffer xv[nbXVBufs]; __device__ inline XBuffer& x(uint32_t idx) { return xv[idx].x; } __device__ inline VBuffer& v(uint32_t idx) { return xv[idx].v; } Vec xRowSum[nbXBufs]; static inline constexpr uint32_t nbAccRowMaxSumCopies = 2; Vec accRowMaxLog2e[nbAccRowMaxSumCopies]; Vec accRowSum[nbAccRowMaxSumCopies]; Vec xRowMaxLog2e[nbProducerCtasPerCga]; CgaBarrier xRowMaxLog2eProducedBar[nbProducerCtasPerCga]; CtaBarrierPair xBars[nbXBufs]; CtaBarrierPair vBars[nbVBufs]; CgaBarrier cgaXBufProduced[nbProducerCtasPerCga]; CtaBarrier mathWarpsBar; CtaBarrierPair multiBlockBars[nbMultiBlockBufs]; __device__ inline void invalidateBarriers(uint32_t thrdIdx) { constexpr uint32_t nbBars = 17; #ifndef __CUDACC_RTC__ constexpr uint32_t nbBarsRef = exactDiv(offsetof(SharedMemB, isLastSubSeq) - offsetof(SharedMemB, xRowMaxLog2eProducedBar), 8); assert(nbBars == nbBarsRef); #endif if (thrdIdx < nbBars) { reinterpret_cast(&xRowMaxLog2eProducedBar[0])[thrdIdx].~CtaBarrier(); } } __device__ inline Vec& getMultiBlockBufs() { #ifndef __CUDACC_RTC__ static_assert( sizeof(Vec) < offsetof(SharedMemB, xRowMaxLog2eProducedBar)); #endif return *reinterpret_cast*>(this); } bool isLastSubSeq; }; __device__ void mergePartialOutputs(uint32_t& semaphore, Vec& dst, PartialResult const* reqPartialResults, uint32_t nbSubSeq, uint32_t ctaRank, uint32_t warpRank, uint2 warpIdx, void* sharedMem); struct KernelArgs { CUtensorMap const& tensorMapQ; // MhaIOHead[nbQHeads * totalNbInputTokens] CUtensorMap const& tensorMapK; CUtensorMap const& tensorMapV; float const& qScale; OutputHead* __restrict__ const& output; // [totalNbIntputTokens][nbQHeads] KVCacheList const& cacheList; uint32_t const& batchSize; float const* __restrict__ const& kvCacheScale; // Device memory scalar. Same scale for K and V cache. Used only for int8/fp8 KV cache. 
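// Per-CGA L2 scratch buffer: producer CTAs write the fp8 X tile and its row sums here
// (Producer::sendX) and consumer CTAs read them back (Consumer::loadX).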
Vec* __restrict__ const& cgaXBuf; // [totalNbInputTokens][maxNbSubSeq] uint32_t* __restrict__ const& semaphores; // [totalNbInputTokens] PartialResult* __restrict__ const& partialResults; // [totalNbInputTokens][maxNbSubSeq] }; struct Producer { static inline constexpr uint32_t nbMathGrps = nbMathGrpsA; static inline constexpr uint32_t nbMathWarps = nbMathWarpsA; static inline constexpr uint32_t nbMathThrds = nbMathWarps * warp_size; static inline constexpr uint32_t warpsPerGrp = exactDiv(nbMathWarps, nbMathGrps); static inline constexpr uint32_t thrdsPerGrp = warpsPerGrp * warp_size; static inline constexpr uint2 warpTile = {tokensPerTile, exactDiv(headGrpSize, warpsPerGrp)}; using WarpAcc = WarpAccT; using ThrdRegRowMax = ThrdRegRowMaxT; using QuadRegRowMax = QuadRegRowMaxT; KernelArgs const& args; SharedMemA& smem; uint32_t const maxNbSubSeq; uint32_t const idxReq; uint32_t const idxInputTokenGlobal; uint32_t const nbSubSeq; uint32_t const idxSubSeq; uint32_t const seqLen; uint32_t const ctaRank; uint32_t const warpRank; uint2 const warpIdx; __device__ inline Producer(KernelArgs const& args, SharedMemA& smem, uint32_t const maxNbSubSeq, uint32_t const idxReq, uint32_t idxInputTokenGlobal, uint32_t const seqLen, uint32_t const nbSubSeq, uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank, uint2 const warpIdx) : args(args) , smem(smem) , maxNbSubSeq(maxNbSubSeq) , idxReq(idxReq) , idxInputTokenGlobal(idxInputTokenGlobal) , seqLen(seqLen) , nbSubSeq(nbSubSeq) , idxSubSeq(idxSubSeq) , ctaRank(ctaRank) , warpRank(warpRank) , warpIdx(warpIdx) { #ifndef NDEBUG if (threadIdx.x == 0) { asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemA)) : "memory"); } __syncthreads(); #endif if (threadIdx.x == 0) { smem.qkScaleLog2e = args.qScale * args.kvCacheScale[0] * log2e; } if (threadIdx.x < headGrpSize) { smem.rowMaxLog2e[threadIdx.x] = safeInitRowMax; } if (warpElectSync()) { if (warpRank < SharedMemA::nbKBufs) { auto& b = smem.kBars[warpRank]; b.initialize(1, thrdsPerGrp); b.consumed.arrive(thrdsPerGrp); } if (warpRank < SharedMemA::nbXBufs) { auto& b = smem.xBars[warpRank]; b.initialize(thrdsPerGrp, 1); b.consumed.arrive(1); } #if USE_REG_Q if (warpRank < SharedMemA::nbRegQBars) { auto& b = smem.regQBars[warpRank]; b.initialize(1, nbMathThrds); b.consumed.arrive(nbMathThrds); } #endif if (warpRank < nbMathGrpsA) { auto& b = smem.rowMaxLog2eBar[warpRank]; init(&b, thrdsPerGrp); } if (ctaRank == 0 && warpRank == 0) { smem.rowMaxLog2eBar[0].arrive(thrdsPerGrp); } if (warpRank == 0) { init(&smem.shmQBar, 1); init(&smem.cgaXBufConsumed, 1 * nbVSplit); smem.cgaXBufConsumed.arrive(1 * nbVSplit); PingPongMutex::initStorage(smem.tensorCoreMutex, thrdsPerGrp); PingPongMutex::initStorage(smem.rowMaxTransferMutex, thrdsPerGrp); init(&smem.consumerRowMaxConsumedBar, warp_size * nbComputeWarpsB * nbVSplit); smem.consumerRowMaxConsumedBar.arrive( warp_size * nbComputeWarpsB * nbVSplit); } if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs) { auto& b = smem.multiBlockBars[warpRank]; b.initialize(1, warp_size * multiBlockMathWarps); b.consumed.arrive(warp_size * multiBlockMathWarps); } } clusterBarArrive(); clusterBarWait(); } __device__ inline ~Producer() { clusterBarArrive(); clusterBarWait(); smem.invalidateBarriers(threadIdx.x); } __device__ inline void run() { if (warpIdx.y == 2) { // IO warps asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps)); if (warpIdx.x == 0) { // q loadQ(); } else if (warpIdx.x == 1) { // k loadK(); } else if (warpIdx.x == 2) { // x 
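// sendX(): copy the finished fp8 X tile and its row sums from shared memory to the per-CGA L2
// scratch (args.cgaXBuf) with TMA, then signal cgaXBufProduced on both consumer CTAs.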
sendX(); } } else { // Compute warps asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps)); compute(); } if (nbSubSeq > 1) { mergePartialOutputs(args.semaphores[idxInputTokenGlobal], reinterpret_cast&>( args.output[headGrpSize * idxInputTokenGlobal + PartialResult::nbRowsPerChunk * ctaRank]), args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq, ctaRank, warpRank, warpIdx, &smem); } } private: __device__ inline uint32_t iterStride() const { return nbSubSeq * nbProducerCtasPerCga; } __device__ inline uint32_t idxTileBeg() const { return nbProducerCtasPerCga * idxSubSeq + ctaRank; } __device__ inline uint32_t nbTiles() const { return divUp(seqLen, tokensPerTile); } __device__ inline SharedMemB& getConsumerShm(uint32_t const idxConsumer) { return *mapa(reinterpret_cast(&smem), nbProducerCtasPerCga + idxConsumer); }; __device__ inline void loadQ() { #if USE_REG_Q #pragma unroll 1 for (uint32_t i = 0; i < SharedMemA::regQParts; i++) { uint32_t const idxBuf = i % SharedMemA::nbRegQBars; auto& bar = smem.regQBars[idxBuf]; bar.consumed.wait_parity(toParity(i)); if (warpElectSync()) { tma::loadAsync(&smem.q[idxBuf], args.tensorMapQ, DimsLE<2>{partElemsK * i, headGrpSize * idxInputTokenGlobal}, bar.produced); bar.produced.arrive_tx(sizeof(SharedMemA::ShmQPart)); } } #endif #pragma unroll 1 for (uint32_t i = 0; i < SharedMemA::shmQParts; i++) { uint32_t const idxPart = SharedMemA::regQParts + i; #if USE_REG_Q if (i < SharedMemA::nbRegQBars) { static_assert(SharedMemA::regQParts % SharedMemA::nbRegQBars == 0); uint32_t const idxBuf = idxPart % SharedMemA::nbRegQBars; assert(idxBuf == i); auto& bar = smem.regQBars[idxBuf]; bar.consumed.wait_parity(toParity(idxPart)); } #endif if (warpElectSync()) { tma::loadAsync(&smem.q[i], args.tensorMapQ, DimsLE<2>{partElemsK * idxPart, headGrpSize * idxInputTokenGlobal}, smem.shmQBar); } } if (warpElectSync()) { smem.shmQBar.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::shmQParts); } } __device__ inline void loadK(); __device__ inline void sendX(); __device__ inline void compute() { class KBarWaiter { public: __device__ inline KBarWaiter(SharedMemA& smem, uint32_t ctaIter, uint32_t const idxPartInit) : smem{smem} , idxPartGlobalNext{nbKParts * ctaIter + idxPartInit} , idxBufNext{idxPartGlobalNext % SharedMemA::nbKBufs} { testWaitNext(); } __device__ inline void testWaitNext() { #if 0 skipKBarWaitNext = smem.kBars[idxBufNext].produced.test_wait_parity(toParity(idxPartGlobalNext)); #else skipKBarWaitNext = false; #endif } __device__ inline void wait() { if (!skipKBarWait) { getKBar().produced.wait_parity(toParity(idxPartGlobal)); } } __device__ inline bool next() { idxPartGlobal = idxPartGlobalNext; idxBuf = idxBufNext; idxPartGlobalNext = idxPartGlobal + 1; idxBufNext = idxPartGlobalNext % SharedMemA::nbKBufs; skipKBarWait = skipKBarWaitNext; return skipKBarWait; } __device__ inline void arrive() { getKBar().consumed.arrive(); } __device__ inline SharedMemA::ShmKPart& getK() { return smem.k[idxBuf]; } private: __device__ inline CtaBarrierPair& getKBar() { return smem.kBars[idxBuf]; } __device__ inline CtaBarrierPair& getKBarNext() { return smem.kBars[idxBufNext]; } private: SharedMemA& smem; uint32_t idxPartGlobal; uint32_t idxBuf; bool skipKBarWait; uint32_t idxPartGlobalNext; uint32_t idxBufNext; bool skipKBarWaitNext; }; uint32_t const grpIdx = warpIdx.y; uint32_t const tileBaseRow = warpTile.y * warpIdx.x; PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx}; PingPongMutex 
rowMaxTransferMutex{smem.rowMaxTransferMutex, grpIdx}; using AtomA = Vec; // for 16x32 data, working as mat A of QMMA.16832 using RegQPartCol = Vec; using RegQPart = Vec; using RegQ = Vec; uint32_t const lane = laneId(); uint32_t const rA = lane % 16; uint32_t const cA = lane / 16; uint32_t const rB = (lane / 16) * 8 + lane % 8; uint32_t const cB = (lane % 16) / 8; #if USE_REG_Q // load regQ RegQ regQ; #pragma unroll for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++) { uint32_t const idxBuf = idxPart % SharedMemA::nbRegQBars; auto& bar = smem.regQBars[idxBuf]; bar.produced.wait_parity(toParity(idxPart)); #pragma unroll for (uint32_t j = 0; j < RegQPart::size; j++) { Mat16x32Loader const loader(smem.q[idxBuf], tileBaseRow, j, rA, cA); regQ[idxPart][j] = loader.loadWholeCol(); } bar.consumed.arrive(); } #endif smem.shmQBar.wait_parity(false); // main loop #pragma unroll 1 for (uint32_t grpIter = 0; true; grpIter++) { uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps; uint32_t const idxTile = idxTileBeg() + iterStride() * ctaIter; if (idxTile >= nbTiles()) { break; } WarpAcc acc{}; using AtomBx2 = Vec; // one AtomB is 8x32 and AtomBx2 is 16x32 // wait until it's our turn tensorCoreMutex.lock(grpIter); KBarWaiter kBarWaiter{smem, ctaIter, 0}; #if USE_REG_Q #pragma unroll for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++) { kBarWaiter.next(); kBarWaiter.wait(); #pragma unroll for (uint32_t idxInstK = 0; idxInstK < exactDiv(partElemsK, qmmaShape.k); idxInstK++) { Mat16x32Loader const loaderK(kBarWaiter.getK(), 0, idxInstK, rB, cB); #pragma unroll for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < exactDiv(tokensPerTile, qmmaShape.n * 2); idxAtomBx2++) { AtomBx2 const atomBx2 = loaderK.load(idxAtomBx2); if (idxInstK == exactDiv(partElemsK, qmmaShape.k) - 1 && idxAtomBx2 == exactDiv(tokensPerTile, qmmaShape.n * 2) - 2 && idxPart != nbQParts - 1) { kBarWaiter.testWaitNext(); } #pragma unroll for (uint32_t i = 0; i < WarpAcc::rows; i++) { #pragma unroll for (uint32_t j = 0; j < 2; j++) { mma<__nv_fp8_e4m3>(reinterpret_cast(acc(i, 2 * idxAtomBx2 + j)), reinterpret_cast(regQ[idxPart][idxInstK][i]), reinterpret_cast(atomBx2[2 * j])); } } } } kBarWaiter.arrive(); } #endif #pragma unroll 1 for (uint32_t idxPart = SharedMemA::regQParts; idxPart < nbQParts; idxPart++) { kBarWaiter.next(); kBarWaiter.wait(); #pragma unroll for (uint32_t idxInstK = 0; idxInstK < exactDiv(partElemsK, qmmaShape.k); idxInstK++) { Mat16x32Loader const loaderQ( smem.q[idxPart - SharedMemA::regQParts], tileBaseRow, idxInstK, rA, cA); auto const qPart = loaderQ.loadWholeCol(); Mat16x32Loader const loaderK(kBarWaiter.getK(), 0, idxInstK, rB, cB); #pragma unroll for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < exactDiv(tokensPerTile, qmmaShape.n * 2); idxAtomBx2++) { AtomBx2 const atomBx2 = loaderK.load(idxAtomBx2); if (idxInstK == exactDiv(partElemsK, qmmaShape.k) - 1 && idxAtomBx2 == exactDiv(tokensPerTile, qmmaShape.n * 2) - 2 && idxPart != nbQParts - 1) { kBarWaiter.testWaitNext(); } #pragma unroll for (uint32_t i = 0; i < WarpAcc::rows; i++) { #pragma unroll for (uint32_t j = 0; j < 2; j++) { mma<__nv_fp8_e4m3>(reinterpret_cast(acc(i, 2 * idxAtomBx2 + j)), reinterpret_cast(qPart[i]), reinterpret_cast(atomBx2[2 * j])); } } } } kBarWaiter.arrive(); } tensorCoreMutex.unlock(); // let the other group to use tensor cores uint32_t const validTokens = seqLen - tokensPerTile * idxTile; if (validTokens < tokensPerTile) { applyMask(this_warp(), acc, 0, validTokens); } WarpAcc const xF32 = 
scaleAndSoftmax(acc, grpIdx, grpIter, tileBaseRow, rowMaxTransferMutex); // convert to fp8 WarpAcc const xF32Quant = xF32 * rcpXScale; // 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15 Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> xF8; #pragma unroll for (uint32_t i = 0; i < WarpAcc::rows; i++) { #pragma unroll for (uint32_t m = 0; m < exactDiv(qmmaShape.m, 8); m++) { #pragma unroll for (uint32_t j = 0; j < WarpAcc::cols; j += 2) { auto& dst = reinterpret_cast<__nv_fp8x2_e4m3(&)[2]>(xF8(i, j / 2)(m, 0)); dst[0] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j)(m, 0), xF32Quant(i, j)(m, 1)}); dst[1] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j + 1)(m, 0), xF32Quant(i, j + 1)(m, 1)}); } } } // use tensor core to compute rowSum ThrdRegRowMax const rowSum = computeRowSumFromF8 ? computeRowSumF8(this_warp(), xF8) : computeRowSumF32(this_warp(), xF32); // store xF8 and rowSum into L2 scratch buffer uint32_t const idxXBuf = checkedVal(grpIdx, ctaIter % SharedMemA::nbXBufs); auto& xBar = smem.xBars[idxXBuf]; xBar.consumed.wait_parity(checkedVal(grpIter % 2, toParity(ctaIter))); storeRowMax(smem.rowSum[idxXBuf], rowSum, tileBaseRow, lane); storeXToShm(smem.x[idxXBuf], xF8, tileBaseRow, lane); xBar.produced.arrive(); } } __device__ inline WarpAcc scaleAndSoftmax(WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow, PingPongMutex& rowMaxTransferMutex); __device__ inline void storeXToShm(XBuffer& dst, Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src, uint32_t const tileBaseRow, uint32_t const lane = laneId()); }; __device__ inline void Producer::loadK() { KVTilePartLoader loader { args.cacheList, idxReq, args.tensorMapK #if USE_PAGED_KV_CACHE , divUp(seqLen, tokensPerPage) #endif }; for (uint32_t iter = 0; true; iter++) { uint32_t const idxTile = idxTileBeg() + iterStride() * iter; if (idxTile >= nbTiles()) { break; } uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers; loader.loadPages(idxTile, idxPageBuf); for (uint32_t idxPart = 0; idxPart < nbKParts; idxPart++) { uint32_t const idxPartGlobal = iter * nbKParts + idxPart; uint32_t const idxBuf = idxPartGlobal % SharedMemA::nbKBufs; auto& bar = smem.kBars[idxBuf]; bar.consumed.wait_parity(toParity(idxPartGlobal)); loader.loadData(smem.k[idxBuf], idxTile, partElemsK * idxPart, bar.produced, idxPageBuf); if (warpElectSync()) { bar.produced.arrive_tx(sizeof(SharedMemA::ShmKPart)); } } } } __device__ inline void Producer::sendX() { for (uint32_t iter = 0; true; iter++) { uint32_t const idxTile = idxTileBeg() + iterStride() * iter; if (idxTile >= nbTiles()) { break; } uint32_t const idxBuf = iter % SharedMemA::nbXBufs; auto& xBar = smem.xBars[idxBuf]; xBar.produced.wait_parity(toParity(iter)); smem.cgaXBufConsumed.wait_parity(toParity<1>(iter)); if (warpElectSync()) { auto& dst = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank]; tma::store1DAsync(&dst.x, &smem.x[idxBuf], sizeof(XBuffer)); tma::store1DAsync(&dst.rowSum, &smem.rowSum[idxBuf], sizeof(smem.rowSum[0])); tma::commitGroup(); tma::waitGroup<0>(); xBar.consumed.arrive(); asm volatile("fence.release.cluster;\n"); #pragma unroll for (uint32_t i = 0; i < nbVSplit; i++) { auto& producedBar = getConsumerShm(i).cgaXBufProduced[ctaRank]; producedBar.arrive(); } } } } __device__ inline Producer::WarpAcc Producer::scaleAndSoftmax( WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow, PingPongMutex& rowMaxTransferMutex) { uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps; uint32_t const cgaIter = ctaRank + ctaIter 
* nbProducerCtasPerCga; auto const warp = this_warp(); uint32_t const lane = laneId(); uint32_t const idxProducer = ctaRank; assert(ctaRank < nbProducerCtasPerCga); auto const accLog2e = acc * smem.qkScaleLog2e; bool const skipWaitLastShmRowMax = smem.rowMaxLog2eBar[grpIdx].test_wait_parity(toParity<1>(grpIter)); QuadRegRowMax const tileRowMaxLog2e = computeRowMax(accLog2e); // get max with previous CTA's rowMax if (!skipWaitLastShmRowMax) { smem.rowMaxLog2eBar[grpIdx].wait_parity(toParity<1>(grpIter)); } auto const lastRowMaxLog2e = loadShmRowMax(smem.rowMaxLog2e, tileBaseRow, lane); auto const quadRowMaxLog2e = fmaxf(tileRowMaxLog2e, replicateForQuad(warp, lastRowMaxLog2e)); // transfer new row max to the other producer CTA for next iteration SharedMemA& smemNext = mapa(smem, ctaRank ^ 1U); CgaBarrier& nextRowMaxLog2eBar = smemNext.rowMaxLog2eBar[(cgaIter + 1) % (nbMathGrps * nbProducerCtasPerCga) / nbMathGrps]; ThrdRegRowMax const rowMaxLog2e = dedupFromQuad(warp, quadRowMaxLog2e); storeRowMaxAsync(nextRowMaxLog2eBar, smemNext.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane); nextRowMaxLog2eBar.arrive_tx_relaxed(sizeof(rowMaxLog2e)); // notify that the next CTA can read rowMax now. // transfer rowMax to consumers. rowMaxTransferMutex.lock(grpIter); // @fixme: use test_wait_parity() early to avoid latency. smem.consumerRowMaxConsumedBar.wait_parity(checkedVal(grpIdx, toParity<1>(ctaIter))); for (uint32_t idxConsumer = 0; idxConsumer < nbVSplit; idxConsumer++) { auto& smemB = getConsumerShm(idxConsumer); storeRowMaxAsync(smemB.xRowMaxLog2eProducedBar[idxProducer], smemB.xRowMaxLog2e[idxProducer], rowMaxLog2e, tileBaseRow, lane); smemB.xRowMaxLog2eProducedBar[idxProducer].arrive_tx_relaxed(sizeof(rowMaxLog2e)); } rowMaxTransferMutex.unlock(); WarpAcc x; // apply softmax #pragma unroll for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll for (uint32_t i = 0; i < InstAcc::rows; i++) { float const maxVal = quadRowMaxLog2e[m * InstAcc::rows + i]; #pragma unroll for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll for (uint32_t j = 0; j < InstAcc::cols; j++) { float elem = accLog2e(m, n)(i, j); assert(maxVal >= elem); x(m, n)(i, j) = exp2f(elem - maxVal); } } } } return x; } __device__ inline void Producer::storeXToShm(XBuffer& dst, Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src, uint32_t const tileBaseRow, uint32_t const lane) { uint32_t const r = lane % 16; uint32_t const c = lane / 16; #pragma unroll for (uint32_t idxInstK = 0; idxInstK < exactDiv(src.cols, 2); idxInstK++) { Mat16x32Loader const loader(dst, tileBaseRow, idxInstK, r, c); #pragma unroll for (uint32_t idxInstM = 0; idxInstM < src.rows; idxInstM++) { stmatrix(const_cast(loader.getPtr(idxInstM)), reinterpret_cast(src(idxInstM, idxInstK * 2))); } } } struct Consumer { static inline constexpr uint32_t nbMathWarps = nbMathWarpsB; static inline constexpr uint32_t nbMathThrds = warp_size * nbMathWarps; static inline constexpr uint2 ctaShape = {2, 4}; static_assert(SharedMemB::nbAccRowMaxSumCopies == ctaShape.x); static_assert(ctaShape.x * ctaShape.y == nbMathWarps); static inline constexpr uint2 warpTile = {exactDiv(gemm1V, ctaShape.x), exactDiv(headGrpSize, ctaShape.y)}; static inline constexpr uint32_t nbWarpOutSwizzleBuf = nbMathWarps; using WarpOutSwizzleBuf = Array2D; static_assert(WarpOutSwizzleBuf::rows % 8 == 0); using WarpAcc = WarpAccT; using ThrdRegRowMax = ThrdRegRowMaxT; using UniformNeedRescaleMask = Vec; KernelArgs const& args; SharedMemB& smem; uint32_t const maxNbSubSeq; uint32_t const idxReq; 
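// Each consumer CTA (rank 2 or 3) owns one of the nbVSplit halves of the V head: idxConsumer()
// = ctaRank - 2 selects which half it loads (loadV) and which output columns it writes
// (storeOutput).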
uint32_t const idxInputTokenGlobal; uint32_t const nbSubSeq; uint32_t const idxSubSeq; uint32_t const seqLen; uint32_t const ctaRank; uint32_t const warpRank; uint2 const warpIdx; __device__ inline uint32_t iterStride() const { return nbSubSeq * nbProducerCtasPerCga; } __device__ inline uint32_t idxTileBeg() const { return nbProducerCtasPerCga * idxSubSeq; } __device__ inline uint32_t nbTiles() const { return divUp(seqLen, tokensPerTile); } __device__ inline uint32_t idxConsumer() const { return ctaRank - 2; } __device__ inline Consumer(KernelArgs const& args, SharedMemB& smem, uint32_t const maxNbSubSeq, uint32_t const idxReq, uint32_t const idxInputTokenGlobal, uint32_t const seqLen, uint32_t const nbSubSeq, uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank, uint2 const warpIdx) : args(args) , smem(smem) , maxNbSubSeq(maxNbSubSeq) , idxReq(idxReq) , idxInputTokenGlobal(idxInputTokenGlobal) , seqLen(seqLen) , nbSubSeq(nbSubSeq) , idxSubSeq(idxSubSeq) , ctaRank(ctaRank) , warpRank(warpRank) , warpIdx(warpIdx) { #ifndef NDEBUG if (threadIdx.x == 0) { asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemB)) : "memory"); } __syncthreads(); #endif if (threadIdx.x < headGrpSize) { for (uint32_t i = 0; i < SharedMemB::nbAccRowMaxSumCopies; i++) { smem.accRowMaxLog2e[i][threadIdx.x] = safeInitRowMax; smem.accRowSum[i][threadIdx.x] = 0; } } if (warpElectSync()) { if (warpRank < nbProducerCtasPerCga) { init(&smem.xRowMaxLog2eProducedBar[warpRank], Producer::thrdsPerGrp); init(&smem.cgaXBufProduced[warpRank], 1); } if (warpRank < SharedMemB::nbXBufs) { auto& bar = smem.xBars[warpRank]; bar.initialize(1, nbMathThrds); bar.consumed.arrive(nbMathThrds); } if (warpRank < SharedMemB::nbVBufs) { auto& bar = smem.vBars[warpRank]; bar.initialize(1, nbMathThrds); bar.consumed.arrive(nbMathThrds); } if (warpRank == 0) { init(&smem.mathWarpsBar, warp_size * nbMathWarps); } if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs) { auto& b = smem.multiBlockBars[warpRank]; b.initialize(1, warp_size * multiBlockMathWarps); b.consumed.arrive(warp_size * multiBlockMathWarps); } } clusterBarArrive(); clusterBarWait(); } __device__ inline ~Consumer() { clusterBarArrive(); clusterBarWait(); smem.invalidateBarriers(threadIdx.x); } __device__ inline void run() { if (warpIdx.y == 2) { asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps)); if (warpIdx.x == 0) { loadX(); } else if (warpIdx.x == 1) { loadV(); } } else { asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps)); compute(); } if (nbSubSeq > 1) { mergePartialOutputs(args.semaphores[idxInputTokenGlobal], reinterpret_cast&>( args.output[headGrpSize * idxInputTokenGlobal + PartialResult::nbRowsPerChunk * ctaRank]), args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq, ctaRank, warpRank, warpIdx, &smem); } } __device__ inline void loadX(); __device__ inline void loadV(); __device__ inline void compute(); __device__ inline uint32_t iterToTile(uint32_t iter) const { return idxTileBeg() + iterStride() * (iter / 2) + iter % 2; } __device__ inline SharedMemA& getProducerShm(uint32_t idxProducer) const { return mapa(reinterpret_cast(smem), idxProducer); } using WarpOutputTile = Array2D; __device__ inline WarpOutputTile finalize( WarpAcc const& acc, ThrdRegRowMax const& accRowSum, float xvScale, uint32_t lane = laneId()); __device__ inline void storeOutput(Vec& dst, uint32_t dstBaseCol, WarpOutputTile const& regTile, WarpOutSwizzleBuf& swizzleBuf, uint32_t lane = laneId()); }; __device__ 
inline void Consumer::compute() { uint2 const tileIdx = {warpIdx.y, warpIdx.x}; uint2 const tileBase = {tileIdx.x * warpTile.x, tileIdx.y * warpTile.y}; uint32_t const lane = laneId(); uint32_t const idxHalf = lane / 16; uint32_t const laneInHalf = lane % 16; uint32_t const rA = laneInHalf; uint32_t const cA = idxHalf; uint32_t const rB = idxHalf * 16 + laneInHalf / 4 * 2 + laneInHalf % 4 / 2 * 8 + lane % 2; uint32_t const cB = 0; WarpAcc acc{}; uint32_t idxXVBufLast; for (uint32_t iter = 0; true; iter++) { uint32_t const idxTile = iterToTile(iter); if (idxTile >= nbTiles()) { break; } ThrdRegRowMax accRowMaxLog2e = loadShmRowMax(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane); ThrdRegRowMax accRowSum = loadShmRowMax(smem.accRowSum[tileIdx.x], tileBase.y, lane); uint32_t const idxProducer = iter % nbProducerCtasPerCga; smem.xRowMaxLog2eProducedBar[idxProducer].wait_parity(toParity(iter)); ThrdRegRowMax const xRowMaxLog2e = loadShmRowMax(smem.xRowMaxLog2e[idxProducer], tileBase.y, lane); auto& prodSmem = getProducerShm(idxProducer); uint32_t const drainData = hashRegData(xRowMaxLog2e); tma::storeAsync(&prodSmem.drain[lane], drainData, prodSmem.consumerRowMaxConsumedBar); prodSmem.consumerRowMaxConsumedBar.template arrive_tx(sizeof(drainData)); assert(all(accRowMaxLog2e <= xRowMaxLog2e)); auto const needRescaleVec = (xRowMaxLog2e > accRowMaxLog2e); UniformNeedRescaleMask rescaleMask; #pragma unroll for (uint32_t i = 0; i < rescaleMask.size; i++) { rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); } bool const anyNeedRescale = any(rescaleMask != UniformNeedRescaleMask::filled(0)); if (anyNeedRescale) { auto const scaleVec = exp2f(accRowMaxLog2e - xRowMaxLog2e); #pragma unroll for (uint32_t m = 0; m < WarpAcc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < InstAcc::rows; i++) { uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; bool const needRescale = (mask != 0); if (needRescale) { // this branch is warp-uniform float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); #pragma unroll for (uint32_t n = 0; n < WarpAcc::cols; n++) { #pragma unroll for (uint32_t j = 0; j < InstAcc::cols; j++) { acc(m, n)(i, j) *= scale; } } } } } accRowSum = accRowSum * scaleVec; } accRowMaxLog2e = xRowMaxLog2e; storeRowMax(smem.accRowMaxLog2e[tileIdx.x], accRowMaxLog2e, tileBase.y, lane); uint32_t const idxXBuf = iter % SharedMemB::nbXBufs; uint32_t const idxVBuf = iter % SharedMemB::nbVBufs; auto& xBar = smem.xBars[idxXBuf]; auto& vBar = smem.vBars[idxVBuf]; // @fixme: merge these two barriers and use test_wait_parity() early to avoid latency. 
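// X (staged from a producer CTA through the L2 scratch) and V (loaded straight from the KV cache
// via TMA) use separate double buffers, so both produced barriers must be passed before this
// iteration's X*V QMMA. The rescale above keeps the accumulator consistent with the running row
// max, flash-attention style: when xRowMaxLog2e grows, acc and accRowSum are scaled by
// exp2(accRowMaxLog2e - xRowMaxLog2e), and the new tile's xRowSum is added after the wait below.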
vBar.produced.wait_parity(toParity(iter)); xBar.produced.wait_parity(toParity(iter)); auto const& xBuf = smem.x(idxXBuf); auto const& vBuf = smem.v(idxVBuf)[tileIdx.x]; auto const xRowSum = loadShmRowMax(smem.xRowSum[idxXBuf], tileBase.y, lane); accRowSum = accRowSum + xRowSum; storeRowMax(smem.accRowSum[tileIdx.x], accRowSum, tileBase.y, lane); #pragma unroll for (uint32_t idxInstK = 0; idxInstK < exactDiv(tokensPerTile, qmmaShape.k); idxInstK++) { Mat16x32Loader const loaderX(xBuf, tileBase.y, idxInstK, rA, cA); Vec const x = loaderX.loadWholeCol(); using AtomB = Vec; #pragma unroll for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < exactDiv(warpTile.x, qmmaShape.n * 2); idxAtomBx2++) { auto const data = ldmatrix_16x16_trans<2>(&vBuf.template at(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB)); AtomB const v[2] = {data[0], data[2], data[1], data[3]}; #pragma unroll for (uint32_t i = 0; i < WarpAcc::rows; i++) { #pragma unroll for (uint32_t j = 0; j < 2; j++) { #if 1 mma<__nv_fp8_e4m3>( #else mmaF8_k32_2inst( #endif reinterpret_cast(acc(i, 2 * idxAtomBx2 + j)), reinterpret_cast(x[i]), reinterpret_cast(v[j])); } } } } bool const isLastIter = (iterToTile(iter + 1) >= nbTiles()); if (isLastIter) { idxXVBufLast = idxXBuf; assert(idxXBuf == idxVBuf); } else { xBar.consumed.arrive(); vBar.consumed.arrive(); } } smem.mathWarpsBar.arrive(); ThrdRegRowMax const accRowSum = loadShmRowMax(smem.accRowSum[tileIdx.x], tileBase.y, lane); float const xvScale = computeRowSumFromF8 ? args.kvCacheScale[0] : args.kvCacheScale[0] * xScale; WarpOutputTile const output = finalize(acc, accRowSum, xvScale, lane); bool const isMultiBlockMode = (nbSubSeq != 1); static_assert(PartialResult::nbRowsPerChunk == warpTile.y); auto& dst = isMultiBlockMode ? args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq].chunks[tileIdx.y].data : reinterpret_cast&>(args.output[headGrpSize * idxInputTokenGlobal + tileBase.y]); assert(warpRank < nbMathWarps); WarpOutSwizzleBuf& swizzleBuf = reinterpret_cast&>(smem.xv[idxXVBufLast])[warpRank]; // make sure all math warps have finished using XVBuffer. smem.mathWarpsBar.wait_parity(false); storeOutput(dst, gemm1V * idxConsumer() + tileBase.x, output, swizzleBuf, lane); if (isMultiBlockMode && tileIdx.x == 0) { ThrdRegRowMax const accRowMaxLog2e = loadShmRowMax(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane); auto& chunk = args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq].chunks[tileIdx.y]; #pragma unroll for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { chunk.rowMaxLog2e[warp_size * i + lane] = accRowMaxLog2e[i]; chunk.rowSum[warp_size * i + lane] = accRowSum[i]; } } smem.xBars[idxXVBufLast].consumed.arrive(); smem.vBars[idxXVBufLast].consumed.arrive(); } __device__ inline void Consumer::loadX() { #pragma unroll 1 for (uint32_t iter = 0; true; iter++) { uint32_t const idxTile = iterToTile(iter); if (idxTile >= nbTiles()) { break; } // @todo: merge these two barriers. 
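// Tiles alternate between the two producer CTAs: even iterations consume producer 0's cgaXBuf
// slot and odd iterations producer 1's, so idxScratchXBuf below doubles as the producer index.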
uint32_t const idxScratchXBuf = iter % nbProducerCtasPerCga; auto& srcProducedBar = smem.cgaXBufProduced[idxScratchXBuf]; srcProducedBar.wait_parity(toParity(iter)); uint32_t const idxXBuf = iter % SharedMemB::nbXBufs; auto& xBar = smem.xBars[idxXBuf]; xBar.consumed.wait_parity(toParity(iter)); if (warpElectSync()) { auto& src = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf]; auto& dstX = smem.x(idxXBuf); auto& dstRowSum = smem.xRowSum[idxXBuf]; tma::load1DAsync(&dstX, &src.x, sizeof(smem.x(0)), xBar.produced); tma::load1DAsync(&dstRowSum, &src.rowSum, sizeof(smem.xRowSum[0]), xBar.produced); xBar.produced.arrive_tx(sizeof(smem.x(0)) + sizeof(smem.xRowSum[0])); xBar.produced.wait_parity(toParity(iter)); uint32_t const idxProducer = idxScratchXBuf; // @fixme: check if this works. If it doesn't, randomly pick some data from dstX and dstRowSum and use // STAS + arrive_tx to avoid fence. getProducerShm(idxProducer).cgaXBufConsumed.arrive(); } } } __device__ inline void Consumer::loadV() { KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV #if USE_PAGED_KV_CACHE , divUp(seqLen, tokensPerPage) #endif ); for (uint32_t iter = 0; true; iter++) { uint32_t const idxTile = iterToTile(iter); if (idxTile >= nbTiles()) { break; } uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers; loader.loadPages(idxTile, idxPageBuf); uint32_t const idxVBuf = iter % SharedMemB::nbVBufs; auto& vBar = smem.vBars[idxVBuf]; vBar.consumed.wait_parity(toParity(iter)); #pragma unroll for (uint32_t idxPart = 0; idxPart < SharedMemB::VBuffer::size; idxPart++) { loader.loadData(smem.v(idxVBuf)[idxPart], idxTile, gemm1V * idxConsumer() + exactDiv(gemm1V, SharedMemB::VBuffer::size) * idxPart, vBar.produced, idxPageBuf); } if (warpElectSync()) { vBar.produced.arrive_tx(sizeof(SharedMemB::VBuffer)); } } } __device__ inline Array2D Consumer::finalize(WarpAcc const& acc, ThrdRegRowMax const& accRowSum, float const xvScale, uint32_t const lane) { ThrdRegRowMax const scaleVec = 1.F / (accRowSum) *xvScale; WarpOutputTile ret; #pragma unroll for (uint32_t m = 0; m < WarpAcc::rows; m++) { #pragma unroll for (uint32_t i = 0; i < InstAcc::rows; i++) { uint32_t retRow = m * InstAcc::rows + i; float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); #pragma unroll for (uint32_t n = 0; n < WarpAcc::cols; n++) { float data[InstAcc::cols]; #pragma unroll for (uint32_t j = 0; j < InstAcc::cols; j++) { data[j] = acc(m, n)(i, j) * scale; } assert(InstAcc::cols == 2); reinterpret_cast<__nv_bfloat162&>(ret(retRow, n)) = __float22bfloat162_rn(float2{data[0], data[1]}); } } } return ret; } __device__ inline void Consumer::storeOutput(Vec& dst, uint32_t dstBaseCol, WarpOutputTile const& src, WarpOutSwizzleBuf& swizzleBuf, uint32_t lane) { using Dst = mha::decay_t; static_assert(Dst::size == WarpOutputTile::rows * 8 && Dst::size % WarpOutSwizzleBuf::rows == 0); uint32_t const nbIters = exactDiv(Dst::size, WarpOutSwizzleBuf::rows); uint32_t const rS = lane % 8; uint32_t const cS = lane / 8; uint32_t const thrdsPerRow = exactDiv(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols, grainBytes); static_assert(thrdsPerRow <= 32); uint32_t const rL = lane / thrdsPerRow; uint32_t const cL = lane % thrdsPerRow; #pragma unroll for (uint32_t iter = 0; iter < nbIters; iter++) { #pragma unroll for (uint32_t j = 0; j < WarpOutputTile::cols; j += 4) { auto const baseSwzPtr = &swizzleBuf.template at(rS, j + cS); constexpr uint32_t srcRowsPerIter = exactDiv(WarpOutputTile::rows, 
nbIters); #pragma unroll for (uint32_t i = 0; i < srcRowsPerIter; i++) { static_assert(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols * 8 % 1024 == 0); auto const swzPtr = checkedVal( baseSwzPtr + WarpOutputTile::cols * 8 * i, &swizzleBuf.template at(8 * i + rS, j + cS)); stmatrix( swzPtr, reinterpret_cast const&>(src(srcRowsPerIter * iter + i, j))); } } __syncwarp(); uint32_t const dstRowsPerIter = WarpOutSwizzleBuf::rows; uint32_t const rowsPerOp = exactDiv(warp_size, thrdsPerRow); LdGrain* const baseDstPtr = reinterpret_cast( &dst[dstRowsPerIter * iter + rL][dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL]); #pragma unroll for (uint32_t i = 0; i < dstRowsPerIter; i += rowsPerOp) { LdGrain* const dstPtr = checkedVal(baseDstPtr + i * exactDiv(sizeof(OutputHead), grainBytes), reinterpret_cast( &dst[dstRowsPerIter * iter + i + rL][dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL])); LdGrain* const srcPtr = &swizzleBuf.template at(i + rL, cL); *dstPtr = *srcPtr; } __syncwarp(); } } __device__ inline void mergePartialOutputs(uint32_t& semaphore, Vec& dst, PartialResult const* reqPartialResults, uint32_t nbSubSeq, uint32_t ctaRank, uint32_t warpRank, uint2 warpIdx, void* sharedMem) { assert(nbSubSeq > 1); clusterBarArrive(); clusterBarWait(); bool const isProducer = (ctaRank < nbProducerCtasPerCga); bool& shmIsLastSubSeq = isProducer ? static_cast(sharedMem)->isLastSubSeq : static_cast(sharedMem)->isLastSubSeq; if (ctaRank == 3 && threadIdx.x == 0) { uint32_t old; uint32_t const lastOld = nbSubSeq - 1; asm volatile("atom.relaxed.gpu.global.inc.u32 %0, [%1], %2;\n" : "=r"(old) : "l"(&semaphore), "r"(lastOld)); bool const isLastSubSeq = (old == lastOld); #pragma unroll for (uint32_t i = 0; i < nbProducerCtasPerCga; i++) { static_cast(mapa(sharedMem, i))->isLastSubSeq = isLastSubSeq; } mapa(shmIsLastSubSeq, 2) = isLastSubSeq; shmIsLastSubSeq = isLastSubSeq; } clusterBarArrive(); clusterBarWait(); bool const isLastCga = shmIsLastSubSeq; if (!isLastCga) { return; } CtaBarrierPair(&bars)[nbMultiBlockBufs] = isProducer ? static_cast(sharedMem)->multiBlockBars : static_cast(sharedMem)->multiBlockBars; Vec& shmBufs = isProducer ? 
static_cast(sharedMem)->getMultiBlockBufs() : static_cast(sharedMem)->getMultiBlockBufs(); constexpr uint32_t nbShmBufs = nbMultiBlockBufs; if (warpIdx.y == 2) { asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps)); if (warpIdx.x == 0) { #pragma unroll 1 for (uint32_t idxSubSeq = 0; idxSubSeq < nbSubSeq; idxSubSeq++) { uint32_t const idxBuf = idxSubSeq % nbShmBufs; auto& bar = bars[idxBuf]; bar.consumed.wait_parity(toParity(idxSubSeq)); if (warpElectSync()) { tma::load1DAsync(&shmBufs[idxBuf], &reqPartialResults[idxSubSeq].chunks[ctaRank], sizeof(PartialResult::Chunk), bar.produced); bar.produced.arrive_tx(sizeof(PartialResult::Chunk)); } } } } else { asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps)); constexpr uint32_t nbMathWarps = 8; constexpr uint32_t rowsPerWarp = exactDiv(PartialResult::nbRowsPerChunk, nbMathWarps); constexpr uint32_t regGrainsPerRow = exactDiv(sizeof(OutputHead), grainBytes * warp_size); constexpr uint32_t grainOutElems = exactDiv(grainBytes, sizeof(OutputElem)); uint32_t const lane = laneId(); uint32_t const tileRowBase = rowsPerWarp * warpRank; using RowWise = Vec; using RegChunk = Array2D, rowsPerWarp, regGrainsPerRow>; auto loadBuf = [&](RowWise& rowMaxLog2e, RowWise& rowSum, RegChunk& regChunk, PartialResult::Chunk const& chunk) { auto loadRowWise = [&](Vec const& src) { return reinterpret_cast(src[tileRowBase]); }; rowMaxLog2e = loadRowWise(chunk.rowMaxLog2e); rowSum = loadRowWise(chunk.rowSum); regChunk; #pragma unroll for (uint32_t i = 0; i < rowsPerWarp; i++) { #pragma unroll for (uint32_t j = 0; j < regGrainsPerRow; j++) { regChunk(i, j) = reinterpret_cast const&>( chunk.data[tileRowBase + i][grainOutElems * (warp_size * j + lane)]); } } }; uint32_t const idxSubSeqInit = 0; uint32_t const idxBufInit = idxSubSeqInit % nbShmBufs; bars[idxBufInit].produced.wait_parity(toParity(idxSubSeqInit)); RowWise accRowMaxLog2e; RowWise accRowSum; RegChunk chunk; loadBuf(accRowMaxLog2e, accRowSum, chunk, shmBufs[idxBufInit]); bars[idxBufInit].consumed.arrive(); using Acc = Array2D, rowsPerWarp, regGrainsPerRow>; Acc acc; #pragma unroll for (uint32_t i = 0; i < rowsPerWarp; i++) { #pragma unroll for (uint32_t j = 0; j < regGrainsPerRow; j++) { acc(i, j) = convert(chunk(i, j)) * accRowSum[i]; } } #pragma unroll 1 for (uint32_t idxSubSeq = idxSubSeqInit + 1; idxSubSeq < nbSubSeq; idxSubSeq++) { uint32_t const idxBuf = idxSubSeq % nbShmBufs; auto& bar = bars[idxBuf]; bar.produced.wait_parity(toParity(idxSubSeq)); RowWise chunkRowMaxLog2e; RowWise chunkRowSum; loadBuf(chunkRowMaxLog2e, chunkRowSum, chunk, shmBufs[idxBuf]); bar.consumed.arrive(); #pragma unroll for (uint32_t i = 0; i < rowsPerWarp; i++) { bool const newChunkGreater = (chunkRowMaxLog2e[i] > accRowMaxLog2e[i]); if (newChunkGreater) { float const scale = exp2f(accRowMaxLog2e[i] - chunkRowMaxLog2e[i]); #pragma unroll for (uint32_t j = 0; j < regGrainsPerRow; j++) { acc(i, j) = acc(i, j) * scale + convert(chunk(i, j)) * chunkRowSum[i]; } accRowSum[i] = accRowSum[i] * scale + chunkRowSum[i]; accRowMaxLog2e[i] = chunkRowMaxLog2e[i]; } else { float const scale = exp2f(chunkRowMaxLog2e[i] - accRowMaxLog2e[i]); float const fusedScale = scale * chunkRowSum[i]; #pragma unroll for (uint32_t j = 0; j < regGrainsPerRow; j++) { acc(i, j) = acc(i, j) + convert(chunk(i, j)) * fusedScale; } accRowSum[i] = accRowSum[i] + chunkRowSum[i] * scale; } } } #pragma unroll for (uint32_t i = 0; i < rowsPerWarp; i++) { float const scale = 1.F / accRowSum[i]; auto const dstHead = 
reinterpret_cast*>(&dst[tileRowBase + i]); #pragma unroll for (uint32_t j = 0; j < regGrainsPerRow; j++) { dstHead[warp_size * j + lane] = convert(acc(i, j) * scale); } } } } inline constexpr uint32_t cgaSize = nbProducerCtasPerCga + nbVSplit; CUBIN_EXPORT __global__ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha( __grid_constant__ CUtensorMap const tensorMapQ, // MhaIOHead[nbQHeads * totalNbInputTokens], __grid_constant__ CUtensorMap const tensorMapK, // with box=64 for the least significant dim __grid_constant__ CUtensorMap const tensorMapV, // with box=128 for the least significant dim float const qScale, OutputHead* __restrict__ const output, // [totalNbIntputTokens][nbQHeads] KVCacheList const cacheList, uint32_t const batchSize, float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for // int8/fp8 KV cache. Vec* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq] uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens] PartialResult* __restrict__ const partialResults = nullptr) // [totalNbInputTokens][maxNbSubSeq] { assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1); extern __shared__ char smemBuf[]; uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size); uint2 const warpIdx = {warpRank % 4, warpRank / 4}; uint3 const& cgaId = clusterId(); uint32_t const& idxReq = cgaId.z; uint32_t const& maxNbSubSeq = nbClusters().y; uint32_t const& idxSubSeq = cgaId.y; uint32_t const inputSeqLen = (allowMultipleInputTokens ? exactDiv(gridDim.x, cgaSize) : checkedVal(1U, exactDiv(gridDim.x, cgaSize))); uint32_t const reqIdxInputToken = (allowMultipleInputTokens ? blockIdx.x / cgaSize : checkedVal(0U, blockIdx.x / cgaSize)); uint32_t const idxInputTokenGlobal = inputSeqLen * idxReq + reqIdxInputToken; uint32_t const cacheSeqLen = cacheList.seqLenList[idxReq] - (inputSeqLen - 1) + reqIdxInputToken; assert(beamWidth == 1); uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tokensPerTile) : 0; bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTiles >= multiBlockMinNbTiles); uint32_t const nbSubSeq = isMultiBlockMode ? 
mha::min(nbTiles / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); assert(isMultiBlockMode == (nbSubSeq > 1)); if (idxSubSeq >= nbSubSeq) { return; } uint32_t const ctaRank = clusterCtaRank(); bool const isProducer = (ctaRank < nbProducerCtasPerCga); KernelArgs const args{tensorMapQ, tensorMapK, tensorMapV, qScale, output, cacheList, batchSize, kvCacheScale, cgaXBuf, semaphores, partialResults}; if (isProducer) { Producer{args, *reinterpret_cast(smemBuf), maxNbSubSeq, idxReq, idxInputTokenGlobal, cacheSeqLen, nbSubSeq, idxSubSeq, ctaRank, warpRank, warpIdx} .run(); } else { Consumer{args, *reinterpret_cast(smemBuf), maxNbSubSeq, idxReq, idxInputTokenGlobal, cacheSeqLen, nbSubSeq, idxSubSeq, ctaRank, warpRank, warpIdx} .run(); } } __constant__ constexpr uint32_t smemSize = mha::max(sizeof(SharedMemA), sizeof(SharedMemB)); static_assert(smemSize <= 99 * 1024, "Shared memory size exceeded"); #endif // is_MLA #ifndef GENERATE_CUBIN #if IS_MLA CUtensorMap makeTensorMapForQ( void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems) { CUtensorMap tensorMap{}; uint64_t const globalDims[] = {headElems, totalNbHeads}; uint32_t elemBytes = getElemBytes(dataType); uint32_t const headBytes = elemBytes * headElems; uint64_t const globalStrides[] = {headBytes}; uint32_t const boxDims[] = {partElems, headGrpSize}; uint32_t const elemStrides[] = {1, 1}; auto const swizzle = CU_TENSOR_MAP_SWIZZLE_64B; checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 2, const_cast(addr), globalDims, globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); return tensorMap; } #endif // IS_MLA void launchMLA(cudaDeviceProp const& prop, uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, #if USE_PAGED_KV_CACHE GMemCacheHead* pool, // global pool of pages KVCachePageIndex const* kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] #else GMemKVCacheHead* kvCacheData, #endif uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for // int8/fp8 KV cache. 
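// semaphores: one counter per input token, bumped with atom.inc so the last finishing
// sub-sequence triggers the merge in mergePartialOutputs(). scratch: backs cgaXBuf followed by
// the PartialResult array (see the pointer arithmetic in the paged-KV launch path below).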
uint32_t* semaphores, void* scratch, cudaStream_t stream) { #if IS_MLA static_assert( SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0, "not implemented"); if (beamWidth != 1) { throw std::runtime_error("not implemented"); } static uint32_t const hostSmemSize = [&]() { // printf("smemSize = %u\n", smemSize); uint32_t size; checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); return size; }(); uint32_t const nbKHeads = 1; uint32_t const nbVHeads = nbKHeads; uint32_t const nbQHeads = nbKHeads * headGrpSize; uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { auto const env = std::getenv("XQA_NB_SUB_SEQ"); if (env != nullptr) { int32_t const val = std::stoi(env); if (val > 0) { return val; } } float const factor = 4.f; return mha::min( mha::max(1U, (uint32_t) round(prop.multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)), divUp(maxSeqLen, tokensPerTile * 2)); }(); // printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq); // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == nbInputSeqSplit dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize}; dim3 const dimCta{warp_size * 4 * 3, 1, 1}; auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); #if USE_PAGED_KV_CACHE uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; auto const dtype = [] { if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if (std::is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } throw std::runtime_error("unsupported cache element type"); }(); auto const tensorMapQ = makeTensorMapForQ(q, dtype, validElemsPerHead, headGrpSize * inputSeqLen * batchSize, partElemsK); auto const tensorMapK = makeTensorMapForPagedKVCache( pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile); auto const tensorMapV = makeTensorMapForPagedKVCache( pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile); uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z; auto const cgaXBuf = static_cast*>(scratch); auto const partialResults = reinterpret_cast(cgaXBuf + nbCgas); cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale, output, cacheList, batchSize, kvCacheScale, cgaXBuf, semaphores, partialResults); #else KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; static_assert(!usePagedKVCache); assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, gemm0CtaTileNbTokens); cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, #if SLIDING_WINDOW slidingWinSize, #endif qScale, output, #if LOW_PREC_OUTPUT rcpOutScale, #endif #if USE_INPUT_KV qkv, #if ROPE_STYLE != 0 ropeCosSin, #endif #else q, #endif cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif batchSize, kvCacheScale, tensorMap, semaphores, scratch); #endif checkCuda(err); #endif } #endif
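// The kernel carves the scratch pointer into per-CGA X staging buffers followed by per-(input
// token, sub-sequence) partial results. The guarded sketch below shows how a caller might bound
// the scratch allocation under that layout; it is an illustrative estimate only: the helper name
// and the use of maxNbSubSeq as an upper bound are assumptions, not an API defined by this file.
#if 0
// Assumed element type for the scratch: one Vec<CgaXBuffer, nbProducerCtasPerCga> per CGA, then
// one PartialResult per (input token, sub-sequence), matching kernel_mha()'s pointer arithmetic.
inline size_t mlaScratchSizeUpperBound(uint32_t batchSize, uint32_t inputSeqLen, uint32_t maxNbSubSeq)
{
    size_t const nbCgas = size_t{batchSize} * inputSeqLen * maxNbSubSeq;
    size_t const xBufBytes = nbCgas * sizeof(Vec<CgaXBuffer, nbProducerCtasPerCga>);
    size_t const partialResultBytes
        = size_t{batchSize} * inputSeqLen * maxNbSubSeq * sizeof(PartialResult);
    return xBufBytes + partialResultBytes;
}
#endif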