TensorRT-LLMs/cpp/kernels/xqa/mla_sm120.cu

/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "defines.h"
#include "mha.h"
#if IS_MLA
#include "barriers.cuh"
#include "mhaUtils.cuh"
#include "mha_components.cuh"
#include "mha_stdheaders.cuh"
#include "mla_sm120.cuh"
#include "mma.cuh"
#include "tma.h"
#include "utils.cuh"
#include "utils.h"
#ifndef GENERATE_CUBIN
#include "hostUtils.h"
#include "tensorMap.h"
#include <cuda_runtime.h>
#endif
#define USE_REG_Q 1
__constant__ constexpr XQAKernelType kernelType = XQAKernelType::kSM120_MLA;
inline constexpr bool allowMultipleInputTokens = true;
inline constexpr uint32_t partElemsK = 64; // @fixme: change this to 128 to save L2 traffic
inline constexpr uint32_t nbKParts = exactDiv(validElemsPerKHead, partElemsK);
inline constexpr uint32_t nbQParts = nbKParts;
inline constexpr uint32_t tokensPerTile = 64;
inline constexpr uint32_t partElemsV = 128;
inline constexpr uint32_t nbVSplit = 2;
inline constexpr uint32_t gemm1V = exactDiv(validElemsPerVHead, nbVSplit);
inline constexpr uint32_t nbProducerCtasPerCga = nbVSplit;
inline constexpr uint32_t multiBlockMinNbTilesPerCta = 2;
inline constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2;
using MathElem = CacheElem;
inline constexpr uint32_t mathElemBytes = sizeof(MathElem);
inline constexpr uint32_t grainsPerPartK = exactDiv(partElemsK * mathElemBytes, grainBytes);
inline constexpr uint32_t grainElems = exactDiv(grainBytes, mathElemBytes);
inline constexpr float xScale = 1.f / kE4M3_MAX;
__constant__ constexpr float rcpXScale = kE4M3_MAX;
inline constexpr uint32_t nbRegsForIOWarps = 32;
inline constexpr uint32_t nbRegsForMathWarps = 232;
inline constexpr bool computeRowSumFromF8 = true;
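// Loads one tokensPerTile-token tile of the K or V cache into shared memory via TMA, one
// partElemsK/partElemsV-wide slice of the head at a time. With a paged KV cache, loadPages() first translates
// the tile into page indices (bound-checked against nbPages) and loadData() issues the per-page TMA loads.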
struct KVTilePartLoader
{
#if USE_PAGED_KV_CACHE
static_assert(tokensPerPage % tokensPerTile == 0 || tokensPerTile % tokensPerPage == 0);
static inline constexpr uint32_t nbPagesPerTile
= tokensPerTile >= tokensPerPage ? exactDiv(tokensPerTile, tokensPerPage) : 1;
#endif
static inline constexpr uint32_t const nbKHeads = 1;
KVCacheList<usePagedKVCache> const& cacheList;
uint32_t const idxReq;
static inline constexpr uint32_t const idxHeadGrp = 0;
CUtensorMap const& tensorMap;
// If greater than 1, the loading loop needs unrolling. 1 seems to be fine for latency.
static inline constexpr uint32_t nbPageBuffers = 1;
#if USE_PAGED_KV_CACHE
uint32_t const nbPages; // for bound check
Vec<KVCachePageIndex, nbPagesPerTile> pageBuffers[nbPageBuffers];
uint32_t idxTileRef = ~0U; // idxTile used to load the pages
#endif
uint32_t const baseOffset;
__device__ KVTilePartLoader(
KVCacheList<usePagedKVCache> const& cacheList, uint32_t idxReq, CUtensorMap const& tensorMap
#if USE_PAGED_KV_CACHE
,
uint32_t nbPages
#endif
);
// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache
template <uint32_t nbTokens, uint32_t grainsPerPart, bool alignedForSwizzle>
__device__ void loadData(Array2D<LdGrain, nbTokens, grainsPerPart, alignedForSwizzle>& dst, uint32_t idxTile,
uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf);
__device__ void loadPages(uint32_t idxTile, uint32_t idxPageBuf);
};
__device__ inline KVTilePartLoader::KVTilePartLoader(
KVCacheList<usePagedKVCache> const& cacheList, uint32_t idxReq, CUtensorMap const& tensorMap
#if USE_PAGED_KV_CACHE
,
uint32_t nbPages
#endif
)
: cacheList{cacheList}
, idxReq{idxReq}
, tensorMap{tensorMap}
#if USE_PAGED_KV_CACHE
, nbPages{nbPages}
#if PAGED_KV_CACHE_LAYOUT == 1
, baseOffset{idxReq * cacheList.maxNbPagesPerSeq}
#else
, baseOffset{((idxReq * beamWidth) * 2) * cacheList.maxNbPagesPerSeq}
#endif
#else
, baseOffset{(idxReq * beamWidth) * 2}
#endif
{
#pragma unroll
for (auto& pageBuffer : pageBuffers)
{
pageBuffer.fill(kBAD_PAGE_INDEX);
}
}
// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache
template <uint32_t nbTokens, uint32_t grainsPerPart, bool alignedForSwizzle>
__device__ inline void KVTilePartLoader::loadData(Array2D<LdGrain, nbTokens, grainsPerPart, alignedForSwizzle>& dst,
uint32_t idxTile, uint32_t idxElemBeg, CtaBarrier& bar, uint32_t idxPageBuf)
{
static_assert(nbTokens == tokensPerTile);
#if USE_PAGED_KV_CACHE
assert(idxTile == idxTileRef);
auto const& pages = pageBuffers[idxPageBuf];
if constexpr (nbTokens < tokensPerPage)
{
assert(nbPagesPerTile == 1);
uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens));
if (warpElectSync())
{
#if PAGED_KV_CACHE_LAYOUT == 1
tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, idxHeadGrp, offset, (uint32_t) pages[0]}, bar);
#else
tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, offset, idxHeadGrp, (uint32_t) pages[0]}, bar);
#endif
}
}
else
{
#pragma unroll
for (uint32_t i = 0; i < nbPagesPerTile; i++)
{
if (warpElectSync())
{
#if PAGED_KV_CACHE_LAYOUT == 1
tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
DimsLE<4>{idxElemBeg, idxHeadGrp, 0, (uint32_t) pages[i]}, bar);
#else
tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap,
DimsLE<4>{idxElemBeg, 0, idxHeadGrp, (uint32_t) pages[i]}, bar);
#endif
}
}
}
#else
if (warpElectSync())
{
tma::loadAsync(&dst, tensorMap, DimsLE<4>{idxElemBeg, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar);
}
#endif
}
__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile, uint32_t idxPageBuf)
{
#if USE_PAGED_KV_CACHE
uint32_t const idxPageBeg
= tokensPerTile >= tokensPerPage ? nbPagesPerTile * idxTile : idxTile / exactDiv(tokensPerPage, tokensPerTile);
auto& pages = pageBuffers[idxPageBuf];
#pragma unroll
for (uint32_t i = 0; i < nbPagesPerTile; i++)
{
uint32_t const idxPage = idxPageBeg + i;
pages[i] = idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX;
}
idxTileRef = idxTile;
#endif
}
using Mat16x32 = Vec<uint32_t, 4>;
template <uint32_t srcRows, uint32_t srcCols>
class Mat16x32Loader
{
public:
using Src = Array2D<LdGrain, srcRows, srcCols>;
// default r and c are for mat A.
__device__ inline Mat16x32Loader(
Src const& src, uint32_t baseRow, uint32_t idxInstK, uint32_t r = laneId() % 16, uint32_t c = laneId() / 16)
: src{src}
, baseRow{baseRow}
, idxInstK{idxInstK}
, r{r}
, c{c}
, basePtr{getPtrRef(0)}
{
static_assert((grainBytes * srcCols * qmmaShape.m) % 1024 == 0);
}
__device__ inline Mat16x32 load(uint32_t idxInstM) const
{
return ldmatrix<false, 4>(getPtr(idxInstM));
}
template <uint32_t tileM>
__device__ inline Vec<Mat16x32, exactDiv(tileM, qmmaShape.m)> loadWholeCol() const
{
uint32_t const nbInstM = exactDiv(tileM, qmmaShape.m);
Vec<Mat16x32, nbInstM> ret;
#pragma unroll
for (uint32_t i = 0; i < nbInstM; i++)
{
ret[i] = load(i);
}
return ret;
}
__device__ inline LdGrain const* getPtr(uint32_t idxInstM) const
{
return checkedVal(basePtr + idxInstM * qmmaShape.m * srcCols, getPtrRef(idxInstM));
}
private:
__device__ inline LdGrain const* getPtrRef(uint32_t idxInstM) const
{
return &src.template at<true>(
baseRow + idxInstM * qmmaShape.m + r, idxInstK * exactDiv(qmmaShape.k, grainElems) + c);
}
Src const& src;
uint32_t const baseRow;
uint32_t const idxInstK;
uint32_t const r;
uint32_t const c;
LdGrain const* const basePtr;
};
using InstAcc = Array2D<float, 2, 2>;
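// XBuffer holds one FP8 probability tile X ([headGrpSize][tokensPerTile]). CgaXBuffer bundles it with the
// per-row sum and the log2e-scaled row max; producer CTAs stage it in shared memory and ship it through the
// cgaXBuf scratch buffer in global memory to the consumer CTAs of the same CGA.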
using XBuffer = Array2D<LdGrain, headGrpSize, exactDiv(mathElemBytes* tokensPerTile, grainBytes)>;
struct CgaXBuffer
{
XBuffer x;
Vec<float, headGrpSize> rowSum;
Vec<float, headGrpSize> rowMaxLog2e;
};
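// Ping-pong mutex built from two CTA barriers: each math group waits on its own barrier to lock and arrives
// on the other group's barrier to unlock, so the two groups take turns on the tensor cores. initStorage()
// pre-arrives barrier 0 so that group 0 gets the first turn.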
struct PingPongMutex
{
using ShmStorage = CtaBarrier[2];
ShmStorage& barriers;
uint32_t const idxGrp;
bool skipWait = false;
static __device__ inline void initStorage(ShmStorage& barriers, uint32_t thrdsPerGrp)
{
new (&barriers[0]) CtaBarrier(thrdsPerGrp);
new (&barriers[1]) CtaBarrier(thrdsPerGrp);
barriers[0].arrive(thrdsPerGrp);
}
__device__ inline PingPongMutex(ShmStorage& shmStorage, uint32_t idxGrp)
: barriers{shmStorage}
, idxGrp{idxGrp}
{
}
__device__ inline void test_lock(uint32_t iter)
{
skipWait = barriers[idxGrp].test_wait_parity(toParity<1>(iter));
}
__device__ inline void lock(uint32_t iter)
{
if (!skipWait)
{
barriers[idxGrp].wait_parity(toParity<1>(iter));
}
}
__device__ inline void unlock()
{
barriers[idxGrp ^ 1U].arrive();
skipWait = false;
}
};
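// Partial attention result of one sub-sequence in multi-block mode: the output (already normalized by the
// local row sum) plus per-row sum and log2e-scaled row max, split into nbChunks chunks so that each CTA of
// the CGA merges one chunk in mergePartialOutputs().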
struct PartialResult
{
static constexpr uint32_t nbChunks = 4;
static constexpr uint32_t nbRowsPerChunk = exactDiv(headGrpSize, nbChunks);
struct Chunk
{
Vec<OutputHead, nbRowsPerChunk> data;
Vec<float, nbRowsPerChunk> rowSum;
Vec<float, nbRowsPerChunk> rowMaxLog2e;
};
Chunk chunks[nbChunks];
};
constexpr uint32_t nbMathWarpsA = 8;
constexpr uint32_t nbComputeWarpsB = 8;
constexpr uint32_t nbMathGrpsA = 2;
constexpr uint32_t nbMathWarpsB = 8;
constexpr uint32_t nbMultiBlockBufs = 2;
constexpr uint32_t multiBlockMathWarps = 8;
constexpr bool useRegQ = USE_REG_Q;
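// Shared memory of a producer CTA: Q parts (with USE_REG_Q, regQParts of them are staged here only briefly
// before being moved to registers), a ring of nbKBufs K-part buffers fed by loadK(), a single CgaXBuffer
// staging area shared by the two math groups, the running row max exchanged with the other producer CTA, and
// the barriers that coordinate them. The front of the buffer is reused by getMultiBlockBufs() for the
// multi-block merge.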
struct SharedMemA
{
static inline constexpr uint32_t nbKBufs = 12;
static inline constexpr uint32_t regQParts = (useRegQ ? 4 : 0);
static inline constexpr uint32_t shmQParts = nbQParts - regQParts;
using ShmQPart = Array2D<LdGrain, headGrpSize, grainsPerPartK>;
using ShmKPart = Array2D<LdGrain, tokensPerTile, grainsPerPartK>;
Vec<ShmQPart, shmQParts> q;
ShmKPart k[nbKBufs];
// single buffer reused by two groups. sendX() warp will arbitrate the order of x buffer access via two xBars.
CgaXBuffer x;
// scaled by log2e. Write by last CGA iteration (from the other producer CTA) and read by current producer CTA.
Vec<float, headGrpSize> rowMaxLog2e;
// Syncs rowMaxLog2e between the two producer CTAs; .consumed means the buffer for the next iteration (in the
// next producer CTA) is ready. The 4 groups from the 2 producer CTAs form a ring.
CgaBarrier rowMaxLog2eBar[nbMathGrpsA];
PingPongMutex::ShmStorage tensorCoreMutex;
CtaBarrierPair kBars[nbKBufs];
static inline constexpr uint32_t nbXBars = nbMathGrpsA;
CtaBarrierPair xBars[nbXBars];
#if USE_REG_Q
CtaBarrierPair regQBar;
#endif
CtaBarrier shmQBar;
CgaBarrier cgaXBufConsumed; // for X
CtaBarrierPair multiBlockBars[nbMultiBlockBufs];
__device__ inline void invalidateBarriers(uint32_t thrdIdx)
{
constexpr uint32_t nbBars = (useRegQ ? 12 : 10) + 2 * (nbKBufs + nbXBars);
#ifndef __CUDACC_RTC__
constexpr uint32_t nbBarsRef
= exactDiv(offsetof(SharedMemA, qkScaleLog2e) - offsetof(SharedMemA, rowMaxLog2eBar), 8);
static_assert(nbBars == nbBarsRef);
#endif
if (thrdIdx < nbBars)
{
reinterpret_cast<CtaBarrier*>(&rowMaxLog2eBar[0])[thrdIdx].~CtaBarrier();
}
}
__device__ inline Vec<PartialResult::Chunk, nbMultiBlockBufs>& getMultiBlockBufs()
{
#ifndef __CUDACC_RTC__
assert(sizeof(Vec<PartialResult::Chunk, nbMultiBlockBufs>) < offsetof(SharedMemA, rowMaxLog2eBar));
#endif
return *reinterpret_cast<Vec<PartialResult::Chunk, nbMultiBlockBufs>*>(this);
}
float qkScaleLog2e;
bool isLastSubSeq;
};
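// Shared memory of a consumer CTA: double-buffered X+V tiles (padded so each XVBuffer can later be reused as
// the per-warp output swizzle buffers), running per-row max/sum accumulators (one copy per half of the warp
// grid along the gemm-N dimension), and barriers. The front is reused for the multi-block merge.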
struct SharedMemB
{
static inline constexpr uint32_t nbXVBufs = 2;
static inline constexpr uint32_t nbXBufs = nbXVBufs;
static inline constexpr uint32_t nbVBufs = nbXVBufs;
using VBuffer
= Vec<Array2D<LdGrain, tokensPerTile, exactDiv(partElemsV, grainElems)>, exactDiv(gemm1V, partElemsV)>;
// x and v use gemmK=128 per iteration. If shared memory capacity becomes a bottleneck, this can be reduced
// to 64 in the future.
struct XVBuffer
{
VBuffer v;
CgaXBuffer x;
uint8_t pad[headGrpSize * 128 * 2 - sizeof(VBuffer) - sizeof(CgaXBuffer)]; // for output swizzling
};
XVBuffer xv[nbXVBufs];
__device__ inline XBuffer& x(uint32_t idx)
{
return xv[idx].x.x;
}
__device__ inline VBuffer& v(uint32_t idx)
{
return xv[idx].v;
}
__device__ inline Vec<float, headGrpSize>& xRowSum(uint32_t idx)
{
return xv[idx].x.rowSum;
}
__device__ inline Vec<float, headGrpSize>& xRowMaxLog2e(uint32_t idx)
{
return xv[idx].x.rowMaxLog2e;
}
static inline constexpr uint32_t nbAccRowMaxSumCopies = 2;
Vec<float, headGrpSize> accRowMaxLog2e[nbAccRowMaxSumCopies];
Vec<float, headGrpSize> accRowSum[nbAccRowMaxSumCopies];
CtaBarrierPair xBars[nbXBufs];
CtaBarrierPair vBars[nbVBufs];
CgaBarrier cgaXBufProduced[nbProducerCtasPerCga];
CtaBarrier mathWarpsBar;
CtaBarrierPair multiBlockBars[nbMultiBlockBufs];
__device__ inline void invalidateBarriers(uint32_t thrdIdx)
{
constexpr uint32_t nbBars = 15;
#ifndef __CUDACC_RTC__
constexpr uint32_t nbBarsRef = exactDiv(offsetof(SharedMemB, isLastSubSeq) - offsetof(SharedMemB, xBars), 8);
static_assert(nbBars == nbBarsRef);
#endif
if (thrdIdx < nbBars)
{
reinterpret_cast<CtaBarrier*>(&xBars[0])[thrdIdx].~CtaBarrier();
}
}
__device__ inline Vec<PartialResult::Chunk, nbMultiBlockBufs>& getMultiBlockBufs()
{
#ifndef __CUDACC_RTC__
static_assert(sizeof(Vec<PartialResult::Chunk, nbMultiBlockBufs>) < offsetof(SharedMemB, xBars));
#endif
return *reinterpret_cast<Vec<PartialResult::Chunk, nbMultiBlockBufs>*>(this);
}
bool isLastSubSeq;
};
__device__ void mergePartialOutputs(uint32_t& semaphore, Vec<OutputHead, PartialResult::nbRowsPerChunk>& dst,
PartialResult const* reqPartialResults, uint32_t nbSubSeq, uint32_t ctaRank, uint32_t warpRank, uint2 warpIdx,
void* sharedMem);
struct KernelArgs
{
CUtensorMap const& tensorMapQ; // MhaIOHead[nbQHeads * totalNbInputTokens]
CUtensorMap const& tensorMapK;
CUtensorMap const& tensorMapV;
float const& qScale;
OutputHead* __restrict__ const& output; // [totalNbInputTokens][nbQHeads]
KVCacheList<usePagedKVCache> const& cacheList;
uint32_t const& batchSize;
float const* __restrict__ const&
kvCacheScale; // Device memory scalar. Same scale for K and V cache. Used only for int8/fp8 KV cache.
Vec<CgaXBuffer, nbProducerCtasPerCga>* __restrict__ const& cgaXBuf; // [totalNbInputTokens][maxNbSubSeq]
uint32_t* __restrict__ const& semaphores; // [totalNbInputTokens]
PartialResult* __restrict__ const& partialResults; // [totalNbInputTokens][maxNbSubSeq]
};
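// Producer CTA (cluster ranks 0 and 1). The two producers split the 64-token KV tiles of their sub-sequence
// between them. Each tile goes through: S = Q*K^T with FP8 QMMA (compute()), masking of the last partial
// tile, online softmax with the row max circulated around the ring of 4 math groups across both producer
// CTAs (scaleAndSoftmax()), FP8 conversion plus row-sum computation; finally sendX() copies the resulting
// CgaXBuffer to global scratch and notifies both consumer CTAs. The 8 math warps form 2 groups of 4 that
// alternate tiles and arbitrate tensor-core use via PingPongMutex; the IO warps run loadQ/loadK/sendX.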
struct Producer
{
static inline constexpr uint32_t nbMathGrps = nbMathGrpsA;
static inline constexpr uint32_t nbMathWarps = nbMathWarpsA;
static inline constexpr uint32_t nbMathThrds = nbMathWarps * warp_size;
static inline constexpr uint32_t warpsPerGrp = exactDiv(nbMathWarps, nbMathGrps);
static inline constexpr uint32_t thrdsPerGrp = warpsPerGrp * warp_size;
static inline constexpr uint2 warpTile = {tokensPerTile, exactDiv(headGrpSize, warpsPerGrp)};
using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
using ThrdRegRowMax = ThrdRegRowMaxT<warpTile.y>;
using QuadRegRowMax = QuadRegRowMaxT<warpTile.y>;
KernelArgs const& args;
SharedMemA& smem;
uint32_t const maxNbSubSeq;
uint32_t const idxReq;
uint32_t const idxInputTokenGlobal;
uint32_t const nbSubSeq;
uint32_t const idxSubSeq;
uint32_t const seqLen;
uint32_t const ctaRank;
uint32_t const warpRank;
uint2 const warpIdx;
__device__ inline Producer(KernelArgs const& args, SharedMemA& smem, uint32_t const maxNbSubSeq,
uint32_t const idxReq, uint32_t idxInputTokenGlobal, uint32_t const seqLen, uint32_t const nbSubSeq,
uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank, uint2 const warpIdx)
: args(args)
, smem(smem)
, maxNbSubSeq(maxNbSubSeq)
, idxReq(idxReq)
, idxInputTokenGlobal(idxInputTokenGlobal)
, seqLen(seqLen)
, nbSubSeq(nbSubSeq)
, idxSubSeq(idxSubSeq)
, ctaRank(ctaRank)
, warpRank(warpRank)
, warpIdx(warpIdx)
{
#ifndef NDEBUG
if (threadIdx.x == 0)
{
asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemA)) : "memory");
}
__syncthreads();
#endif
if (threadIdx.x == 0)
{
smem.qkScaleLog2e = args.qScale * args.kvCacheScale[0] * log2e;
}
if (threadIdx.x < headGrpSize)
{
smem.rowMaxLog2e[threadIdx.x] = safeInitRowMax;
}
if (warpElectSync())
{
if (warpRank < SharedMemA::nbKBufs)
{
auto& b = smem.kBars[warpRank];
b.initialize(1, thrdsPerGrp);
b.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(thrdsPerGrp);
}
if (warpRank < SharedMemA::nbXBars)
{
auto& b = smem.xBars[warpRank];
b.initialize(thrdsPerGrp, 1);
}
#if USE_REG_Q
if (warpRank == 0)
{
smem.regQBar.initialize(1, nbMathThrds);
smem.regQBar.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(nbMathThrds);
}
#endif
if (warpRank < nbMathGrpsA)
{
auto& b = smem.rowMaxLog2eBar[warpRank];
init(&b, thrdsPerGrp);
}
if (ctaRank == 0 && warpRank == 0)
{
smem.rowMaxLog2eBar[0].arrive<Scope::CTA, ArriveOrder::RELAXED>(thrdsPerGrp);
}
if (warpRank == 0)
{
init(&smem.shmQBar, 1);
init(&smem.cgaXBufConsumed, 1 * nbVSplit);
smem.cgaXBufConsumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(1 * nbVSplit);
PingPongMutex::initStorage(smem.tensorCoreMutex, thrdsPerGrp);
}
if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs)
{
auto& b = smem.multiBlockBars[warpRank];
b.initialize(1, warp_size * multiBlockMathWarps);
b.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(warp_size * multiBlockMathWarps);
}
}
clusterBarArrive();
clusterBarWait();
}
__device__ inline ~Producer()
{
clusterBarArrive();
clusterBarWait();
smem.invalidateBarriers(threadIdx.x);
}
__device__ inline void run()
{
if (warpIdx.y == 2)
{ // IO warps
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps));
if (warpIdx.x == 0)
{ // q
loadQ();
}
else if (warpIdx.x == 1)
{ // k
loadK();
}
else if (warpIdx.x == 2)
{ // x
sendX();
}
}
else
{ // Compute warps
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps));
compute();
}
if (nbSubSeq > 1)
{
mergePartialOutputs(args.semaphores[idxInputTokenGlobal],
reinterpret_cast<Vec<OutputHead, PartialResult::nbRowsPerChunk>&>(
args.output[headGrpSize * idxInputTokenGlobal + PartialResult::nbRowsPerChunk * ctaRank]),
args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq, ctaRank, warpRank, warpIdx, &smem);
}
}
private:
__device__ inline uint32_t iterStride() const
{
return nbSubSeq * nbProducerCtasPerCga;
}
__device__ inline uint32_t idxTileBeg() const
{
return nbProducerCtasPerCga * idxSubSeq + ctaRank;
}
__device__ inline uint32_t nbTiles() const
{
return divUp(seqLen, tokensPerTile);
}
__device__ inline SharedMemB& getConsumerShm(uint32_t const idxConsumer)
{
return *mapa(reinterpret_cast<SharedMemB*>(&smem), nbProducerCtasPerCga + idxConsumer);
};
static constexpr uint32_t regQPartShmBeg = SharedMemA::shmQParts - SharedMemA::regQParts;
__device__ inline void loadQ()
{
#if USE_REG_Q
static_assert(SharedMemA::regQParts <= SharedMemA::shmQParts);
smem.regQBar.consumed.wait_parity(toParity<1>(0));
#pragma unroll 1
for (uint32_t i = 0; i < SharedMemA::regQParts; i++)
{
if (warpElectSync())
{
tma::loadAsync(&smem.q[regQPartShmBeg + i], args.tensorMapQ,
DimsLE<2>{partElemsK * i, headGrpSize * idxInputTokenGlobal}, smem.regQBar.produced);
}
}
if (warpElectSync())
{
smem.regQBar.produced.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::regQParts);
}
#endif
#pragma unroll 1
for (uint32_t i = 0; i < SharedMemA::shmQParts; i++)
{
uint32_t const idxPart = SharedMemA::regQParts + i;
#if USE_REG_Q
if (i == regQPartShmBeg)
{
smem.regQBar.consumed.wait_parity(toParity<1>(1));
}
#endif
if (warpElectSync())
{
tma::loadAsync(&smem.q[i], args.tensorMapQ,
DimsLE<2>{partElemsK * idxPart, headGrpSize * idxInputTokenGlobal}, smem.shmQBar);
}
}
if (warpElectSync())
{
smem.shmQBar.arrive_tx(sizeof(SharedMemA::ShmQPart) * SharedMemA::shmQParts);
}
}
__device__ inline void loadK();
__device__ inline void sendX();
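// Main loop of one math group: for each assigned tile, multiply Q (nbQParts parts of partElemsK elements
// along the head; with USE_REG_Q the first regQParts parts are kept in registers) against the K parts
// streaming through the shm ring, software-pipelining the next Q/K ldmatrix fragments between QMMA calls.
// The finished tile is masked if it is a partial last tile, run through scaleAndSoftmax(), converted to FP8,
// reduced to a row sum, and staged into smem.x for sendX().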
__device__ inline void compute()
{
uint32_t const grpIdx = warpIdx.y;
uint32_t const tileBaseRow = warpTile.y * warpIdx.x;
PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx};
constexpr uint32_t partNbInstK = exactDiv(partElemsK, qmmaShape.k);
using AtomA = Vec<uint32_t, 4>; // for 16x32 data, working as mat A of QMMA.16832
using RegQPartCol = Vec<AtomA, exactDiv(warpTile.y, qmmaShape.m)>;
using RegQPart = Vec<RegQPartCol, partNbInstK>;
using RegQ = Vec<RegQPart, SharedMemA::regQParts>;
constexpr uint32_t tileNbAtomBx2 = exactDiv(tokensPerTile, qmmaShape.n * 2);
using AtomBx2 = Vec<uint32_t, 4>; // one AtomB is 8x32 and AtomBx2 is 16x32
using RegKPartCol = Vec<AtomBx2, tileNbAtomBx2>;
using RegKPart = Vec<RegKPartCol, partNbInstK>;
uint32_t const lane = laneId();
uint32_t const rA = lane % 16;
uint32_t const cA = lane / 16;
uint32_t const rB = (lane / 16) * 8 + lane % 8;
uint32_t const cB = (lane % 16) / 8;
auto loadRegQCol = [&](SharedMemA::ShmQPart const& q, uint32_t idxInstK) -> RegQPartCol
{
Mat16x32Loader const loaderQ(q, tileBaseRow, idxInstK, rA, cA);
return loaderQ.loadWholeCol<warpTile.y>();
};
auto loadRegKCol = [&](SharedMemA::ShmKPart const& k, uint32_t idxInstK) -> RegKPartCol
{
Mat16x32Loader const loaderK(k, 0, idxInstK, rB, cB);
return loaderK.loadWholeCol<warpTile.x>();
};
auto loadPart = [&](auto const& loadCol, auto const& shmPart)
{
mha::conditional_t<mha::is_same_v<SharedMemA::ShmQPart, mha::decay_t<decltype(shmPart)>>, RegQPart,
RegKPart>
regPart;
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++)
{
regPart[idxInstK] = loadCol(shmPart, idxInstK);
}
return regPart;
};
#if USE_REG_Q
// load regQ
smem.regQBar.produced.wait_parity(toParity<1>(0));
RegQ regQ;
#pragma unroll
for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++)
{
uint32_t const idxBuf = regQPartShmBeg + idxPart;
regQ[idxPart] = loadPart(loadRegQCol, smem.q[idxBuf]);
}
smem.regQBar.consumed.arrive();
#endif
// main loop
#pragma unroll 1
for (uint32_t grpIter = 0; true; grpIter++)
{
uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps;
uint32_t const idxTile = idxTileBeg() + iterStride() * ctaIter;
if (idxTile >= nbTiles())
{
break;
}
WarpAcc acc{};
// wait until it's our turn
tensorCoreMutex.lock(grpIter);
BarWaiter kBarWaiter(smem.kBars, ctaIter * nbKParts + 0);
kBarWaiter.testWait();
RegQPart regQBuf;
#if USE_REG_Q
static_assert(SharedMemA::regQParts > 0);
regQBuf[0] = regQ[0][0];
#else
regQBuf[0] = loadRegQCol(smem.q[0], 0);
#endif
kBarWaiter.wait();
RegKPart regKBuf;
regKBuf[0] = loadRegKCol(smem.k[kBarWaiter.idxBuf], 0);
auto shouldTestWait = [](uint32_t idxInstK, uint32_t idxAtomBx2)
{ return idxInstK == partNbInstK - 1 && idxAtomBx2 == tileNbAtomBx2 - 2; };
BarWaiter kBarWaiterNext = kBarWaiter.next();
#if USE_REG_Q
#pragma unroll
for (uint32_t idxPart = 0; idxPart < SharedMemA::regQParts; idxPart++)
{
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++)
{
bool const prefetchNextPart = (idxInstK == partNbInstK - 1);
uint32_t const idxPartPrefetch = prefetchNextPart ? idxPart + 1 : idxPart;
uint32_t const idxInstKPrefetch = prefetchNextPart ? 0 : idxInstK + 1;
bool const prefetch = (!prefetchNextPart || (idxPart < nbKParts - 1));
if (prefetchNextPart)
{
kBarWaiter = kBarWaiterNext;
kBarWaiterNext = kBarWaiter.next();
if (prefetch)
{
kBarWaiter.wait();
}
}
Mat16x32Loader const loaderK(smem.k[kBarWaiter.idxBuf], 0, idxInstKPrefetch, rB, cB);
#pragma unroll
for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < tileNbAtomBx2; idxAtomBx2++)
{
if (idxAtomBx2 == 2 && prefetch)
{
if (idxPartPrefetch < SharedMemA::regQParts)
{
regQBuf[idxInstKPrefetch] = regQ[idxPartPrefetch][idxInstKPrefetch];
}
else
{
regQBuf[idxInstKPrefetch]
= loadRegQCol(smem.q[idxPartPrefetch - SharedMemA::regQParts], idxInstKPrefetch);
}
}
AtomBx2 const& atomBx2 = regKBuf[idxInstK][idxAtomBx2];
if (shouldTestWait(idxInstKPrefetch, idxAtomBx2) && prefetch)
{
kBarWaiterNext.testWait();
}
#pragma unroll
for (uint32_t i = 0; i < WarpAcc::rows; i++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
}
}
if (prefetch)
{
regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2);
}
}
if (idxInstKPrefetch == partNbInstK - 1)
{
assert(prefetch);
kBarWaiter.consumed();
}
}
}
#endif
if (ctaIter == 0)
{
smem.shmQBar.wait_parity(false);
}
#pragma unroll
for (uint32_t idxPart = SharedMemA::regQParts; idxPart < nbQParts; idxPart++)
{
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < partNbInstK; idxInstK++)
{
bool const prefetchNextPart = (idxInstK == partNbInstK - 1);
uint32_t const idxPartPrefetch = prefetchNextPart ? idxPart + 1 : idxPart;
uint32_t const idxInstKPrefetch = prefetchNextPart ? 0 : idxInstK + 1;
bool const prefetch = (!prefetchNextPart || (idxPart < nbKParts - 1));
if (prefetchNextPart)
{
kBarWaiter = kBarWaiterNext;
kBarWaiterNext = kBarWaiter.next();
if (prefetch)
{
kBarWaiter.wait();
}
}
Mat16x32Loader const loaderK(smem.k[kBarWaiter.idxBuf], 0, idxInstKPrefetch, rB, cB);
#pragma unroll
for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < tileNbAtomBx2; idxAtomBx2++)
{
if (idxAtomBx2 == 2 && prefetch)
{
regQBuf[idxInstKPrefetch]
= loadRegQCol(smem.q[idxPartPrefetch - SharedMemA::regQParts], idxInstKPrefetch);
}
AtomBx2 const& atomBx2 = regKBuf[idxInstK][idxAtomBx2];
if (shouldTestWait(idxInstKPrefetch, idxAtomBx2) && prefetch)
{
kBarWaiterNext.testWait();
}
#pragma unroll
for (uint32_t i = 0; i < WarpAcc::rows; i++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
reinterpret_cast<uint32_t const(&)[2][2]>(regQBuf[idxInstK][i]),
reinterpret_cast<uint32_t const(&)[2][1]>(atomBx2[2 * j]));
}
}
if (prefetch)
{
regKBuf[idxInstKPrefetch][idxAtomBx2] = loaderK.load(idxAtomBx2);
}
}
if (idxInstKPrefetch == partNbInstK - 1)
{
assert(prefetch);
kBarWaiter.consumed();
if (idxPartPrefetch == nbKParts - 1)
{
tensorCoreMutex.unlock(); // let the other group use the tensor cores
}
}
}
}
uint32_t const validTokens = seqLen - tokensPerTile * idxTile;
if (validTokens < tokensPerTile)
{
applyMask(this_warp(), acc, 0, validTokens);
}
ThrdRegRowMax rowMaxLog2e;
WarpAcc const xF32 = scaleAndSoftmax(rowMaxLog2e, acc, grpIdx, grpIter, tileBaseRow);
auto& xBar = smem.xBars[grpIdx];
bool const skipXBarWait = xBar.consumed.test_wait_parity(toParity<1>(grpIter));
// convert to fp8
WarpAcc const xF32Quant = xF32 * rcpXScale;
// 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15
Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> xF8;
#pragma unroll
for (uint32_t i = 0; i < WarpAcc::rows; i++)
{
#pragma unroll
for (uint32_t m = 0; m < exactDiv(qmmaShape.m, 8); m++)
{
#pragma unroll
for (uint32_t j = 0; j < WarpAcc::cols; j += 2)
{
auto& dst = reinterpret_cast<__nv_fp8x2_e4m3(&)[2]>(xF8(i, j / 2)(m, 0));
dst[0] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j)(m, 0), xF32Quant(i, j)(m, 1)});
dst[1] = __nv_fp8x2_e4m3(float2{xF32Quant(i, j + 1)(m, 0), xF32Quant(i, j + 1)(m, 1)});
}
}
}
// use tensor core to compute rowSum
ThrdRegRowMax const rowSum = computeRowSumFromF8
? computeRowSumF8<warpTile.y, warpTile.x>(this_warp(), xF8)
: computeRowSumF32<warpTile.y, warpTile.x>(this_warp(), xF32);
// store xF8 and rowSum into L2 scratch buffer
if (!skipXBarWait)
{
xBar.consumed.wait_parity(toParity<1>(grpIter));
}
storeRowMax<warpTile.y>(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
storeRowMax<warpTile.y>(smem.x.rowSum, rowSum, tileBaseRow, lane);
storeOrderedXToShm(smem.x.x, xF8, tileBaseRow, lane);
xBar.produced.arrive();
}
}
__device__ inline WarpAcc scaleAndSoftmax(
ThrdRegRowMax& rowMaxLog2e, WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow);
__device__ inline void storeOrderedXToShm(XBuffer& dst,
Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src,
uint32_t const tileBaseRow, uint32_t const lane = laneId());
};
__device__ inline void Producer::loadK()
{
KVTilePartLoader loader
{
args.cacheList, idxReq, args.tensorMapK
#if USE_PAGED_KV_CACHE
,
divUp(seqLen, tokensPerPage)
#endif
};
#pragma unroll 1
for (uint32_t iter = 0; true; iter++)
{
uint32_t const idxTile = idxTileBeg() + iterStride() * iter;
if (idxTile >= nbTiles())
{
break;
}
uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers;
loader.loadPages(idxTile, idxPageBuf);
#pragma unroll 1
for (uint32_t idxPart = 0; idxPart < nbKParts; idxPart++)
{
uint32_t const idxPartGlobal = iter * nbKParts + idxPart;
uint32_t const idxBuf = idxPartGlobal % SharedMemA::nbKBufs;
auto& bar = smem.kBars[idxBuf];
bar.consumed.wait_parity(toParity<SharedMemA::nbKBufs>(idxPartGlobal));
loader.loadData(smem.k[idxBuf], idxTile, partElemsK * idxPart, bar.produced, idxPageBuf);
if (warpElectSync())
{
bar.produced.arrive_tx(sizeof(SharedMemA::ShmKPart));
}
}
}
}
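// Copies a finished CgaXBuffer (FP8 X + rowSum + rowMaxLog2e) from shared memory to this CGA's slot of the
// global cgaXBuf scratch with a 1D TMA store, alternates the xBars so the two math groups take turns filling
// smem.x, and signals cgaXBufProduced in both consumer CTAs.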
__device__ inline void Producer::sendX()
{
// let group 0 produce first.
if (warpElectSync())
{
smem.xBars[0].consumed.arrive();
}
for (uint32_t iter = 0; true; iter++)
{
uint32_t const idxTile = idxTileBeg() + iterStride() * iter;
if (idxTile >= nbTiles())
{
break;
}
uint32_t const idxBar = iter % SharedMemA::nbXBars;
auto& xBar = smem.xBars[idxBar];
xBar.produced.wait_parity(toParity<SharedMemA::nbXBars>(iter));
smem.cgaXBufConsumed.wait_parity(toParity<1>(iter));
if (warpElectSync())
{
auto& dst = args.cgaXBuf[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank];
tma::store1DAsync(&dst, &smem.x, sizeof(CgaXBuffer));
tma::commitGroup();
tma::waitGroup<0>();
// it's the other math group's turn to produce.
uint32_t const idxBarNext = (iter + 1) % SharedMemA::nbXBars;
auto& xBarNext = smem.xBars[idxBarNext];
xBarNext.consumed.arrive();
asm volatile("fence.release.cluster;\n");
#pragma unroll
for (uint32_t i = 0; i < nbVSplit; i++)
{
auto& producedBar = getConsumerShm(i).cgaXBufProduced[ctaRank];
producedBar.arrive<Scope::CGA, ArriveOrder::RELAXED>();
}
}
}
}
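// Online softmax for one warp tile. The per-row max is kept scaled by qkScale*log2(e): the tile row max is
// combined with the running max received from the previous math group in the ring, the updated max is
// forwarded to the next group (in the other producer CTA) via storeRowMaxAsync + arrive_tx, and the returned
// probabilities are exp2(qkScaleLog2e * s - rowMaxLog2e).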
__device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(
ThrdRegRowMax& rowMaxLog2e, WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow)
{
uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps;
uint32_t const cgaIter = ctaRank + ctaIter * nbProducerCtasPerCga;
auto const warp = this_warp();
uint32_t const lane = laneId();
uint32_t const idxProducer = ctaRank;
assert(ctaRank < nbProducerCtasPerCga);
float const qkScaleLog2e = smem.qkScaleLog2e;
bool const skipWaitLastShmRowMax = smem.rowMaxLog2eBar[grpIdx].test_wait_parity(toParity<1>(grpIter));
QuadRegRowMax const tileRowMaxLog2e = computeRowMax<warpTile.y, warpTile.x>(acc) * qkScaleLog2e;
// get max with previous CTA's rowMax
if (!skipWaitLastShmRowMax)
{
smem.rowMaxLog2eBar[grpIdx].wait_parity(toParity<1>(grpIter));
}
auto const lastRowMaxLog2e = loadShmRowMax<warpTile.y>(smem.rowMaxLog2e, tileBaseRow, lane);
auto const quadRowMaxLog2e = fmaxf(tileRowMaxLog2e, replicateForQuad(warp, lastRowMaxLog2e));
// transfer new row max to the other producer CTA for next iteration
SharedMemA& smemNext = mapa(smem, ctaRank ^ 1U);
CgaBarrier& nextRowMaxLog2eBar
= smemNext.rowMaxLog2eBar[(cgaIter + 1) % (nbMathGrps * nbProducerCtasPerCga) / nbMathGrps];
rowMaxLog2e = dedupFromQuad(warp, quadRowMaxLog2e);
storeRowMaxAsync<warpTile.y>(nextRowMaxLog2eBar, smemNext.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
nextRowMaxLog2eBar.arrive_tx_relaxed(sizeof(rowMaxLog2e)); // notify that the next CTA can read rowMax now.
WarpAcc x;
// apply softmax
#pragma unroll
for (uint32_t m = 0; m < acc.rows; m++)
{
#pragma unroll
for (uint32_t i = 0; i < InstAcc::rows; i++)
{
float const maxVal = quadRowMaxLog2e[m * InstAcc::rows + i];
#pragma unroll
for (uint32_t n = 0; n < acc.cols; n++)
{
#pragma unroll
for (uint32_t j = 0; j < InstAcc::cols; j++)
{
float elem = acc(m, n)(i, j);
assert(maxVal >= elem * qkScaleLog2e);
x(m, n)(i, j) = exp2f(elem * qkScaleLog2e - maxVal);
}
}
}
}
return x;
}
__device__ inline void Producer::storeOrderedXToShm(XBuffer& dst,
Array2D<Array2D<uint32_t, 2, 1>, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src, uint32_t const tileBaseRow,
uint32_t const lane)
{
uint32_t const r = lane % 16;
uint32_t const c = lane / 16;
using Src = mha::decay_t<decltype(src)>;
LdGrain* ptrs[exactDiv(Src::cols, 2)][Src::rows];
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < exactDiv(Src::cols, 2); idxInstK++)
{
Mat16x32Loader const loader(dst, tileBaseRow, idxInstK, r, c);
#pragma unroll
for (uint32_t idxInstM = 0; idxInstM < Src::rows; idxInstM++)
{
auto const p = const_cast<LdGrain*>(loader.getPtr(idxInstM));
stmatrix<false, 4>(p, reinterpret_cast<LdGrain const&>(src(idxInstM, idxInstK * 2)));
ptrs[idxInstK][idxInstM] = p;
}
}
// reorder from 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15
// to 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
__syncwarp();
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < exactDiv(Src::cols, 2); idxInstK++)
{
#pragma unroll
for (uint32_t idxInstM = 0; idxInstM < Src::rows; idxInstM++)
{
auto const p = ptrs[idxInstK][idxInstM];
auto const i = *p;
LdGrain const o = {prmt(i[0], i[1], PermuteOrder{0, 1, 4, 5}), prmt(i[2], i[3], PermuteOrder{0, 1, 4, 5}),
prmt(i[0], i[1], PermuteOrder{2, 3, 6, 7}), prmt(i[2], i[3], PermuteOrder{2, 3, 6, 7})};
*p = o;
}
}
}
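// Consumer CTA (cluster ranks 2 and 3). Each consumer owns one half of the value head (gemm1V columns) and
// walks over all tiles of the sub-sequence, alternating between the X buffers published by the two
// producers: it rescales the accumulator when the incoming row max exceeds the running one, accumulates
// O += X*V with FP8 QMMA, and finally normalizes by the accumulated row sum and writes either the final
// output or a PartialResult chunk (multi-block mode). 8 math warps in a 2x4 grid; IO warps run loadX/loadV.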
struct Consumer
{
static inline constexpr uint32_t nbMathWarps = nbMathWarpsB;
static inline constexpr uint32_t nbMathThrds = warp_size * nbMathWarps;
static inline constexpr uint2 ctaShape = {2, 4};
static_assert(SharedMemB::nbAccRowMaxSumCopies == ctaShape.x);
static_assert(ctaShape.x * ctaShape.y == nbMathWarps);
static inline constexpr uint2 warpTile = {exactDiv(gemm1V, ctaShape.x), exactDiv(headGrpSize, ctaShape.y)};
static inline constexpr uint32_t nbWarpOutSwizzleBuf = nbMathWarps;
using WarpOutSwizzleBuf = Array2D<LdGrain,
exactDiv(exactDiv(sizeof(SharedMemB::XVBuffer), sizeof(OutputElem) * warpTile.x), nbMathWarps),
exactDiv(sizeof(OutputElem) * warpTile.x, grainBytes)>;
static_assert(WarpOutSwizzleBuf::rows % 8 == 0);
using WarpAcc = WarpAccT<warpTile.y, warpTile.x>;
using ThrdRegRowMax = ThrdRegRowMaxT<warpTile.y>;
using UniformNeedRescaleMask = Vec<uint32_t, divUp(warpTile.y, warp_size)>;
KernelArgs const& args;
SharedMemB& smem;
uint32_t const maxNbSubSeq;
uint32_t const idxReq;
uint32_t const idxInputTokenGlobal;
uint32_t const nbSubSeq;
uint32_t const idxSubSeq;
uint32_t const seqLen;
uint32_t const ctaRank;
uint32_t const warpRank;
uint2 const warpIdx;
__device__ inline uint32_t iterStride() const
{
return nbSubSeq * nbProducerCtasPerCga;
}
__device__ inline uint32_t idxTileBeg() const
{
return nbProducerCtasPerCga * idxSubSeq;
}
__device__ inline uint32_t nbTiles() const
{
return divUp(seqLen, tokensPerTile);
}
__device__ inline uint32_t idxConsumer() const
{
return ctaRank - 2;
}
__device__ inline Consumer(KernelArgs const& args, SharedMemB& smem, uint32_t const maxNbSubSeq,
uint32_t const idxReq, uint32_t const idxInputTokenGlobal, uint32_t const seqLen, uint32_t const nbSubSeq,
uint32_t const idxSubSeq, uint32_t ctaRank, uint32_t const warpRank, uint2 const warpIdx)
: args(args)
, smem(smem)
, maxNbSubSeq(maxNbSubSeq)
, idxReq(idxReq)
, idxInputTokenGlobal(idxInputTokenGlobal)
, seqLen(seqLen)
, nbSubSeq(nbSubSeq)
, idxSubSeq(idxSubSeq)
, ctaRank(ctaRank)
, warpRank(warpRank)
, warpIdx(warpIdx)
{
#ifndef NDEBUG
if (threadIdx.x == 0)
{
asm("st.bulk.weak [%0], %1, 0;\n" ::"l"(&smem), "n"(sizeof(SharedMemB)) : "memory");
}
__syncthreads();
#endif
if (threadIdx.x < headGrpSize)
{
for (uint32_t i = 0; i < SharedMemB::nbAccRowMaxSumCopies; i++)
{
smem.accRowMaxLog2e[i][threadIdx.x] = safeInitRowMax;
smem.accRowSum[i][threadIdx.x] = 0;
}
}
if (warpElectSync())
{
if (warpRank < nbProducerCtasPerCga)
{
init(&smem.cgaXBufProduced[warpRank], 1);
}
if (warpRank < SharedMemB::nbXBufs)
{
auto& bar = smem.xBars[warpRank];
bar.initialize(1, nbMathThrds);
bar.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(nbMathThrds);
}
if (warpRank < SharedMemB::nbVBufs)
{
auto& bar = smem.vBars[warpRank];
bar.initialize(1, nbMathThrds);
bar.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(nbMathThrds);
}
if (warpRank == 0)
{
init(&smem.mathWarpsBar, warp_size * nbMathWarps);
}
if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs)
{
auto& b = smem.multiBlockBars[warpRank];
b.initialize(1, warp_size * multiBlockMathWarps);
b.consumed.arrive<Scope::CTA, ArriveOrder::RELAXED>(warp_size * multiBlockMathWarps);
}
}
clusterBarArrive();
clusterBarWait();
}
__device__ inline ~Consumer()
{
clusterBarArrive();
clusterBarWait();
smem.invalidateBarriers(threadIdx.x);
}
__device__ inline void run()
{
if (warpIdx.y == 2)
{
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps));
if (warpIdx.x == 0)
{
loadX();
}
else if (warpIdx.x == 1)
{
loadV();
}
}
else
{
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps));
compute();
}
if (nbSubSeq > 1)
{
mergePartialOutputs(args.semaphores[idxInputTokenGlobal],
reinterpret_cast<Vec<OutputHead, PartialResult::nbRowsPerChunk>&>(
args.output[headGrpSize * idxInputTokenGlobal + PartialResult::nbRowsPerChunk * ctaRank]),
args.partialResults + maxNbSubSeq * idxInputTokenGlobal, nbSubSeq, ctaRank, warpRank, warpIdx, &smem);
}
}
__device__ inline void loadX();
__device__ inline void loadV();
__device__ inline void compute();
__device__ inline uint32_t iterToTile(uint32_t iter) const
{
return idxTileBeg() + iterStride() * (iter / 2) + iter % 2;
}
__device__ inline SharedMemA& getProducerShm(uint32_t idxProducer) const
{
return mapa(reinterpret_cast<SharedMemA&>(smem), idxProducer);
}
using WarpOutputTile = Array2D<uint32_t, InstAcc::rows * Consumer::WarpAcc::rows, Consumer::WarpAcc::cols>;
__device__ inline WarpOutputTile finalize(
WarpAcc const& acc, ThrdRegRowMax const& accRowSum, float xvScale, uint32_t lane = laneId());
__device__ inline void storeOutput(Vec<OutputHead, warpTile.y>& dst, uint32_t dstBaseCol,
WarpOutputTile const& regTile, WarpOutSwizzleBuf& swizzleBuf, uint32_t lane = laneId());
};
__device__ inline void Consumer::compute()
{
uint2 const tileIdx = {warpIdx.y, warpIdx.x};
uint2 const tileBase = {tileIdx.x * warpTile.x, tileIdx.y * warpTile.y};
constexpr uint32_t tileNbInstK = exactDiv(tokensPerTile, qmmaShape.k);
constexpr uint32_t warpTileNbAtomBx2 = exactDiv(warpTile.x, qmmaShape.n * 2);
uint32_t const lane = laneId();
uint32_t const idxHalf = lane / 16;
uint32_t const laneInHalf = lane % 16;
uint32_t const rA = laneInHalf;
uint32_t const cA = idxHalf;
uint32_t const rB = lane;
uint32_t const cB = 0;
WarpAcc acc{};
uint32_t idxXVBufLast{};
for (uint32_t iter = 0; true; iter++)
{
uint32_t const idxTile = iterToTile(iter);
if (idxTile >= nbTiles())
{
break;
}
ThrdRegRowMax accRowMaxLog2e = loadShmRowMax<warpTile.y>(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane);
ThrdRegRowMax accRowSum = loadShmRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], tileBase.y, lane);
uint32_t const idxXBuf = iter % SharedMemB::nbXBufs;
uint32_t const idxVBuf = iter % SharedMemB::nbVBufs;
auto& xBar = smem.xBars[idxXBuf];
auto& vBar = smem.vBars[idxVBuf];
// @fixme: merge these two barriers and use test_wait_parity() early to avoid latency.
bool const skipVBarWait = vBar.produced.test_wait_parity(toParity<SharedMemB::nbVBufs>(iter));
xBar.produced.wait_parity(toParity<SharedMemB::nbXBufs>(iter));
ThrdRegRowMax const xRowMaxLog2e = loadShmRowMax<warpTile.y>(smem.xRowMaxLog2e(idxXBuf), tileBase.y, lane);
assert(all(accRowMaxLog2e <= xRowMaxLog2e));
auto const needRescaleVec = (xRowMaxLog2e > accRowMaxLog2e);
UniformNeedRescaleMask rescaleMask{};
#pragma unroll
for (uint32_t i = 0; i < rescaleMask.size; i++)
{
rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]);
}
bool const anyNeedRescale = any(rescaleMask != UniformNeedRescaleMask::filled(0));
if (anyNeedRescale)
{
auto const scaleVec = exp2f(accRowMaxLog2e - xRowMaxLog2e);
#pragma unroll
for (uint32_t m = 0; m < WarpAcc::rows; m++)
{
#pragma unroll
for (uint32_t i = 0; i < InstAcc::rows; i++)
{
uint8_t const mask = reinterpret_cast<uint8_t const(&)[2][2]>(rescaleMask[m / 2])[m % 2][i];
bool const needRescale = (mask != 0);
if (needRescale)
{ // this branch is warp-uniform
float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4);
#pragma unroll
for (uint32_t n = 0; n < WarpAcc::cols; n++)
{
#pragma unroll
for (uint32_t j = 0; j < InstAcc::cols; j++)
{
acc(m, n)(i, j) *= scale;
}
}
}
}
}
accRowSum = accRowSum * scaleVec;
}
accRowMaxLog2e = xRowMaxLog2e;
storeRowMax<warpTile.y>(smem.accRowMaxLog2e[tileIdx.x], accRowMaxLog2e, tileBase.y, lane);
if (!skipVBarWait)
{
vBar.produced.wait_parity(toParity<SharedMemB::nbVBufs>(iter));
}
auto const& xBuf = smem.x(idxXBuf);
auto const& vBuf = smem.v(idxVBuf)[tileIdx.x];
auto const xRowSum = loadShmRowMax<warpTile.y>(smem.xRowSum(idxXBuf), tileBase.y, lane);
accRowSum = accRowSum + xRowSum;
storeRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], accRowSum, tileBase.y, lane);
#pragma unroll
for (uint32_t idxInstK = 0; idxInstK < tileNbInstK; idxInstK++)
{
Mat16x32Loader const loaderX(xBuf, tileBase.y, idxInstK, rA, cA);
Vec<Mat16x32, exactDiv(warpTile.y, qmmaShape.m)> const x = loaderX.loadWholeCol<warpTile.y>();
using AtomB = Vec<uint32_t, 2>;
#pragma unroll
for (uint32_t idxAtomBx2 = 0; idxAtomBx2 < warpTileNbAtomBx2; idxAtomBx2++)
{
auto const data
= ldmatrix_16x16_trans<2>(&vBuf.template at<true>(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
AtomB const v[2] = {data[0], data[2], data[1], data[3]};
#pragma unroll
for (uint32_t i = 0; i < WarpAcc::rows; i++)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++)
{
#if 1
mma<__nv_fp8_e4m3>(
#else
mmaF8_k32_2inst(
#endif
reinterpret_cast<float(&)[2][2]>(acc(i, 2 * idxAtomBx2 + j)),
reinterpret_cast<uint32_t const(&)[2][2]>(x[i]),
reinterpret_cast<uint32_t const(&)[2][1]>(v[j]));
}
}
}
}
bool const isLastIter = (iterToTile(iter + 1) >= nbTiles());
if (isLastIter)
{
idxXVBufLast = idxXBuf;
assert(idxXBuf == idxVBuf);
}
else
{
xBar.consumed.arrive();
vBar.consumed.arrive();
}
}
smem.mathWarpsBar.arrive();
ThrdRegRowMax const accRowSum = loadShmRowMax<warpTile.y>(smem.accRowSum[tileIdx.x], tileBase.y, lane);
float const xvScale = computeRowSumFromF8 ? args.kvCacheScale[0] : args.kvCacheScale[0] * xScale;
WarpOutputTile const output = finalize(acc, accRowSum, xvScale, lane);
bool const isMultiBlockMode = (nbSubSeq != 1);
static_assert(PartialResult::nbRowsPerChunk == warpTile.y);
auto& dst = isMultiBlockMode
? args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq].chunks[tileIdx.y].data
: reinterpret_cast<Vec<OutputHead, warpTile.y>&>(args.output[headGrpSize * idxInputTokenGlobal + tileBase.y]);
assert(warpRank < nbMathWarps);
WarpOutSwizzleBuf& swizzleBuf
= reinterpret_cast<Vec<WarpOutSwizzleBuf, nbWarpOutSwizzleBuf>&>(smem.xv[idxXVBufLast])[warpRank];
// make sure all math warps have finished using XVBuffer.
smem.mathWarpsBar.wait_parity(false);
storeOutput(dst, gemm1V * idxConsumer() + tileBase.x, output, swizzleBuf, lane);
if (isMultiBlockMode && tileIdx.x == 0)
{
ThrdRegRowMax const accRowMaxLog2e
= loadShmRowMax<warpTile.y>(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane);
auto& chunk = args.partialResults[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq].chunks[tileIdx.y];
#pragma unroll
for (uint32_t i = 0; i < ThrdRegRowMax::size; i++)
{
chunk.rowMaxLog2e[warp_size * i + lane] = accRowMaxLog2e[i];
chunk.rowSum[warp_size * i + lane] = accRowSum[i];
}
}
smem.xBars[idxXVBufLast].consumed.arrive();
smem.vBars[idxXVBufLast].consumed.arrive();
}
__device__ inline void Consumer::loadX()
{
#pragma unroll 1
for (uint32_t iter = 0; true; iter++)
{
uint32_t const idxTile = iterToTile(iter);
if (idxTile >= nbTiles())
{
break;
}
// @todo: merge these two barriers.
uint32_t const idxScratchXBuf = iter % nbProducerCtasPerCga;
auto& srcProducedBar = smem.cgaXBufProduced[idxScratchXBuf];
srcProducedBar.wait_parity(toParity<nbProducerCtasPerCga>(iter));
uint32_t const idxXBuf = iter % SharedMemB::nbXBufs;
auto& xBar = smem.xBars[idxXBuf];
xBar.consumed.wait_parity(toParity<SharedMemB::nbXBufs>(iter));
if (warpElectSync())
{
auto& src = args.cgaXBuf[maxNbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf];
auto& dst = smem.xv[idxXBuf].x;
tma::loadLinearAsync(&dst, &src.x, sizeof(CgaXBuffer), xBar.produced);
xBar.produced.arrive_tx(sizeof(CgaXBuffer));
xBar.produced.wait_parity(toParity<SharedMemB::nbXBufs>(iter));
uint32_t const idxProducer = idxScratchXBuf;
// @fixme: check if this works. If it doesn't, randomly pick some data from dstX and dstRowSum and use
// STAS + arrive_tx to avoid fence.
getProducerShm(idxProducer).cgaXBufConsumed.arrive<Scope::CGA, ArriveOrder::RELAXED>();
}
}
}
__device__ inline void Consumer::loadV()
{
KVTilePartLoader loader(args.cacheList, idxReq, args.tensorMapV
#if USE_PAGED_KV_CACHE
,
divUp(seqLen, tokensPerPage)
#endif
);
for (uint32_t iter = 0; true; iter++)
{
uint32_t const idxTile = iterToTile(iter);
if (idxTile >= nbTiles())
{
break;
}
uint32_t const idxPageBuf = iter % KVTilePartLoader::nbPageBuffers;
loader.loadPages(idxTile, idxPageBuf);
uint32_t const idxVBuf = iter % SharedMemB::nbVBufs;
auto& vBar = smem.vBars[idxVBuf];
vBar.consumed.wait_parity(toParity<SharedMemB::nbVBufs>(iter));
#pragma unroll
for (uint32_t idxPart = 0; idxPart < SharedMemB::VBuffer::size; idxPart++)
{
loader.loadData(smem.v(idxVBuf)[idxPart], idxTile,
gemm1V * idxConsumer() + exactDiv(gemm1V, SharedMemB::VBuffer::size) * idxPart, vBar.produced,
idxPageBuf);
}
if (warpElectSync())
{
vBar.produced.arrive_tx(sizeof(SharedMemB::VBuffer));
}
}
}
__device__ inline Array2D<uint32_t, InstAcc::rows * Consumer::WarpAcc::rows, Consumer::WarpAcc::cols>
Consumer::finalize(WarpAcc const& acc, ThrdRegRowMax const& accRowSum, float const xvScale, uint32_t const lane)
{
ThrdRegRowMax const scaleVec = 1.F / accRowSum * xvScale;
WarpOutputTile ret;
#pragma unroll
for (uint32_t m = 0; m < WarpAcc::rows; m++)
{
#pragma unroll
for (uint32_t i = 0; i < InstAcc::rows; i++)
{
uint32_t retRow = m * InstAcc::rows + i;
float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4);
#pragma unroll
for (uint32_t n = 0; n < WarpAcc::cols; n++)
{
float data[InstAcc::cols];
#pragma unroll
for (uint32_t j = 0; j < InstAcc::cols; j++)
{
data[j] = acc(m, n)(i, j) * scale;
}
assert(InstAcc::cols == 2);
reinterpret_cast<__nv_bfloat162&>(ret(retRow, n)) = __float22bfloat162_rn(float2{data[0], data[1]});
}
}
}
return ret;
}
__device__ inline void Consumer::storeOutput(Vec<OutputHead, warpTile.y>& dst, uint32_t dstBaseCol,
WarpOutputTile const& src, WarpOutSwizzleBuf& swizzleBuf, uint32_t lane)
{
using Dst = mha::decay_t<decltype(dst)>;
static_assert(Dst::size == WarpOutputTile::rows * 8 && Dst::size % WarpOutSwizzleBuf::rows == 0);
uint32_t const nbIters = exactDiv(Dst::size, WarpOutSwizzleBuf::rows);
uint32_t const rS = lane % 8;
uint32_t const cS = lane / 8;
uint32_t const thrdsPerRow = exactDiv(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols, grainBytes);
static_assert(thrdsPerRow <= 32);
uint32_t const rL = lane / thrdsPerRow;
uint32_t const cL = lane % thrdsPerRow;
#pragma unroll
for (uint32_t iter = 0; iter < nbIters; iter++)
{
#pragma unroll
for (uint32_t j = 0; j < WarpOutputTile::cols; j += 4)
{
auto const baseSwzPtr = &swizzleBuf.template at<true>(rS, j + cS);
constexpr uint32_t srcRowsPerIter = exactDiv(WarpOutputTile::rows, nbIters);
#pragma unroll
for (uint32_t i = 0; i < srcRowsPerIter; i++)
{
static_assert(sizeof(WarpOutSwizzleBuf::Elem) * WarpOutSwizzleBuf::cols * 8 % 1024 == 0);
auto const swzPtr = checkedVal(
baseSwzPtr + WarpOutputTile::cols * 8 * i, &swizzleBuf.template at<true>(8 * i + rS, j + cS));
stmatrix<false, 4>(
swzPtr, reinterpret_cast<Vec<uint32_t, 4> const&>(src(srcRowsPerIter * iter + i, j)));
}
}
__syncwarp();
uint32_t const dstRowsPerIter = WarpOutSwizzleBuf::rows;
uint32_t const rowsPerOp = exactDiv(warp_size, thrdsPerRow);
LdGrain* const baseDstPtr = reinterpret_cast<LdGrain*>(
&dst[dstRowsPerIter * iter + rL][dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL]);
#pragma unroll
for (uint32_t i = 0; i < dstRowsPerIter; i += rowsPerOp)
{
LdGrain* const dstPtr = checkedVal(baseDstPtr + i * exactDiv(sizeof(OutputHead), grainBytes),
reinterpret_cast<LdGrain*>(
&dst[dstRowsPerIter * iter + i + rL][dstBaseCol + exactDiv(grainBytes, sizeof(OutputElem)) * cL]));
LdGrain* const srcPtr = &swizzleBuf.template at<true>(i + rL, cL);
*dstPtr = *srcPtr;
}
__syncwarp();
}
}
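// Multi-block reduction (flash-decoding style). Each partial result i stores the sub-sequence output O_i
// already normalized by its local row sum, together with rowSum_i and the log2e-scaled rowMax_i. The last
// CGA to finish (detected via an atomic inc on the per-token semaphore) combines them as
//   out = sum_i(O_i * rowSum_i * 2^(rowMax_i - M)) / sum_i(rowSum_i * 2^(rowMax_i - M)),  M = max_i rowMax_i,
// which the loop below evaluates with a running max instead of a separate pass for M. Each CTA of the CGA
// merges one chunk of PartialResult::nbRowsPerChunk rows: one IO warp streams the chunks into shared memory
// while the 8 math warps accumulate.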
__device__ inline void mergePartialOutputs(uint32_t& semaphore, Vec<OutputHead, PartialResult::nbRowsPerChunk>& dst,
PartialResult const* reqPartialResults, uint32_t nbSubSeq, uint32_t ctaRank, uint32_t warpRank, uint2 warpIdx,
void* sharedMem)
{
assert(nbSubSeq > 1);
clusterBarArrive();
clusterBarWait();
bool const isProducer = (ctaRank < nbProducerCtasPerCga);
bool& shmIsLastSubSeq = isProducer ? static_cast<SharedMemA*>(sharedMem)->isLastSubSeq
: static_cast<SharedMemB*>(sharedMem)->isLastSubSeq;
if (ctaRank == 3 && threadIdx.x == 0)
{
uint32_t old;
uint32_t const lastOld = nbSubSeq - 1;
asm volatile("atom.relaxed.gpu.global.inc.u32 %0, [%1], %2;\n" : "=r"(old) : "l"(&semaphore), "r"(lastOld));
bool const isLastSubSeq = (old == lastOld);
#pragma unroll
for (uint32_t i = 0; i < nbProducerCtasPerCga; i++)
{
static_cast<SharedMemA*>(mapa(sharedMem, i))->isLastSubSeq = isLastSubSeq;
}
mapa(shmIsLastSubSeq, 2) = isLastSubSeq;
shmIsLastSubSeq = isLastSubSeq;
}
clusterBarArrive();
clusterBarWait();
bool const isLastCga = shmIsLastSubSeq;
if (!isLastCga)
{
return;
}
CtaBarrierPair(&bars)[nbMultiBlockBufs] = isProducer ? static_cast<SharedMemA*>(sharedMem)->multiBlockBars
: static_cast<SharedMemB*>(sharedMem)->multiBlockBars;
Vec<PartialResult::Chunk, nbMultiBlockBufs>& shmBufs = isProducer
? static_cast<SharedMemA*>(sharedMem)->getMultiBlockBufs()
: static_cast<SharedMemB*>(sharedMem)->getMultiBlockBufs();
constexpr uint32_t nbShmBufs = nbMultiBlockBufs;
if (warpIdx.y == 2)
{
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" ::"n"(nbRegsForIOWarps));
if (warpIdx.x == 0)
{
#pragma unroll 1
for (uint32_t idxSubSeq = 0; idxSubSeq < nbSubSeq; idxSubSeq++)
{
uint32_t const idxBuf = idxSubSeq % nbShmBufs;
auto& bar = bars[idxBuf];
bar.consumed.wait_parity(toParity<nbShmBufs>(idxSubSeq));
if (warpElectSync())
{
tma::loadLinearAsync(&shmBufs[idxBuf], &reqPartialResults[idxSubSeq].chunks[ctaRank],
sizeof(PartialResult::Chunk), bar.produced);
bar.produced.arrive_tx(sizeof(PartialResult::Chunk));
}
}
}
}
else
{
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(nbRegsForMathWarps));
constexpr uint32_t nbMathWarps = 8;
constexpr uint32_t rowsPerWarp = exactDiv(PartialResult::nbRowsPerChunk, nbMathWarps);
constexpr uint32_t regGrainsPerRow = exactDiv(sizeof(OutputHead), grainBytes * warp_size);
constexpr uint32_t grainOutElems = exactDiv(grainBytes, sizeof(OutputElem));
uint32_t const lane = laneId();
uint32_t const tileRowBase = rowsPerWarp * warpRank;
using RowWise = Vec<float, rowsPerWarp>;
using RegChunk = Array2D<Vec<OutputElem, grainOutElems>, rowsPerWarp, regGrainsPerRow>;
auto loadBuf = [&](RowWise& rowMaxLog2e, RowWise& rowSum, RegChunk& regChunk, PartialResult::Chunk const& chunk)
{
auto loadRowWise = [&](Vec<float, PartialResult::nbRowsPerChunk> const& src)
{ return reinterpret_cast<RowWise const&>(src[tileRowBase]); };
rowMaxLog2e = loadRowWise(chunk.rowMaxLog2e);
rowSum = loadRowWise(chunk.rowSum);
#pragma unroll
for (uint32_t i = 0; i < rowsPerWarp; i++)
{
#pragma unroll
for (uint32_t j = 0; j < regGrainsPerRow; j++)
{
regChunk(i, j) = reinterpret_cast<Vec<OutputElem, grainOutElems> const&>(
chunk.data[tileRowBase + i][grainOutElems * (warp_size * j + lane)]);
}
}
};
uint32_t const idxSubSeqInit = 0;
uint32_t const idxBufInit = idxSubSeqInit % nbShmBufs;
bars[idxBufInit].produced.wait_parity(toParity<nbShmBufs>(idxSubSeqInit));
RowWise accRowMaxLog2e;
RowWise accRowSum;
RegChunk chunk;
loadBuf(accRowMaxLog2e, accRowSum, chunk, shmBufs[idxBufInit]);
bars[idxBufInit].consumed.arrive();
using Acc = Array2D<Vec<float, grainOutElems>, rowsPerWarp, regGrainsPerRow>;
Acc acc;
#pragma unroll
for (uint32_t i = 0; i < rowsPerWarp; i++)
{
#pragma unroll
for (uint32_t j = 0; j < regGrainsPerRow; j++)
{
acc(i, j) = convert<float>(chunk(i, j)) * accRowSum[i];
}
}
#pragma unroll 1
for (uint32_t idxSubSeq = idxSubSeqInit + 1; idxSubSeq < nbSubSeq; idxSubSeq++)
{
uint32_t const idxBuf = idxSubSeq % nbShmBufs;
auto& bar = bars[idxBuf];
bar.produced.wait_parity(toParity<nbShmBufs>(idxSubSeq));
RowWise chunkRowMaxLog2e;
RowWise chunkRowSum;
loadBuf(chunkRowMaxLog2e, chunkRowSum, chunk, shmBufs[idxBuf]);
bar.consumed.arrive();
#pragma unroll
for (uint32_t i = 0; i < rowsPerWarp; i++)
{
bool const newChunkGreater = (chunkRowMaxLog2e[i] > accRowMaxLog2e[i]);
if (newChunkGreater)
{
float const scale = exp2f(accRowMaxLog2e[i] - chunkRowMaxLog2e[i]);
#pragma unroll
for (uint32_t j = 0; j < regGrainsPerRow; j++)
{
acc(i, j) = acc(i, j) * scale + convert<float>(chunk(i, j)) * chunkRowSum[i];
}
accRowSum[i] = accRowSum[i] * scale + chunkRowSum[i];
accRowMaxLog2e[i] = chunkRowMaxLog2e[i];
}
else
{
float const scale = exp2f(chunkRowMaxLog2e[i] - accRowMaxLog2e[i]);
float const fusedScale = scale * chunkRowSum[i];
#pragma unroll
for (uint32_t j = 0; j < regGrainsPerRow; j++)
{
acc(i, j) = acc(i, j) + convert<float>(chunk(i, j)) * fusedScale;
}
accRowSum[i] = accRowSum[i] + chunkRowSum[i] * scale;
}
}
}
#pragma unroll
for (uint32_t i = 0; i < rowsPerWarp; i++)
{
float const scale = 1.F / accRowSum[i];
auto const dstHead = reinterpret_cast<Vec<OutputElem, grainOutElems>*>(&dst[tileRowBase + i]);
#pragma unroll
for (uint32_t j = 0; j < regGrainsPerRow; j++)
{
dstHead[warp_size * j + lane] = convert<OutputElem>(acc(i, j) * scale);
}
}
}
}
inline constexpr uint32_t cgaSize = nbProducerCtasPerCga + nbVSplit;
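// One CGA of cgaSize CTAs handles one (input token, sub-sequence) pair: cluster ranks 0..1 run the Producer
// (Q*K^T + softmax), ranks 2..3 run the Consumer (X*V), each consumer owning half of the value head.
// gridDim = {cgaSize * inputSeqLen, maxNbSubSeq, nbKHeads * batchSize} with nbKHeads == 1 for MLA. CGAs whose
// idxSubSeq is beyond the nbSubSeq needed for this sequence exit early; when nbSubSeq > 1 the partial results
// are merged in mergePartialOutputs().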
CUBIN_EXPORT __global__ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha(
__grid_constant__ CUtensorMap const tensorMapQ, // MhaIOHead[nbQHeads * totalNbInputTokens],
__grid_constant__ CUtensorMap const tensorMapK, // with box=64 for the least significant dim
__grid_constant__ CUtensorMap const tensorMapV, // with box=128 for the least significant dim
float const qScale,
OutputHead* __restrict__ const output, // [totalNbInputTokens][nbQHeads]
KVCacheList<usePagedKVCache> const cacheList, uint32_t const batchSize,
float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
// int8/fp8 KV cache.
Vec<CgaXBuffer, nbProducerCtasPerCga>* __restrict__ const cgaXBuf, // [totalNbInputTokens][maxNbSubSeq]
uint32_t* __restrict__ const semaphores = nullptr, // [totalNbInputTokens]
PartialResult* __restrict__ const partialResults = nullptr) // [totalNbInputTokens][maxNbSubSeq]
{
assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1);
extern __shared__ char smemBuf[];
uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size);
uint2 const warpIdx = {warpRank % 4, warpRank / 4};
uint3 const& cgaId = clusterId();
uint32_t const& idxReq = cgaId.z;
uint32_t const& maxNbSubSeq = nbClusters().y;
uint32_t const& idxSubSeq = cgaId.y;
uint32_t const inputSeqLen
= (allowMultipleInputTokens ? exactDiv(gridDim.x, cgaSize) : checkedVal(1U, exactDiv(gridDim.x, cgaSize)));
uint32_t const reqIdxInputToken
= (allowMultipleInputTokens ? blockIdx.x / cgaSize : checkedVal(0U, blockIdx.x / cgaSize));
uint32_t const idxInputTokenGlobal = inputSeqLen * idxReq + reqIdxInputToken;
uint32_t const cacheSeqLen = cacheList.seqLenList[idxReq] - (inputSeqLen - 1) + reqIdxInputToken;
assert(beamWidth == 1);
uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tokensPerTile) : 0;
bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTiles >= multiBlockMinNbTiles);
uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTiles / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
assert(isMultiBlockMode == (nbSubSeq > 1));
if (idxSubSeq >= nbSubSeq)
{
return;
}
uint32_t const ctaRank = clusterCtaRank();
bool const isProducer = (ctaRank < nbProducerCtasPerCga);
KernelArgs const args{tensorMapQ, tensorMapK, tensorMapV, qScale, output, cacheList, batchSize, kvCacheScale,
cgaXBuf, semaphores, partialResults};
if (isProducer)
{
Producer{args, *reinterpret_cast<SharedMemA*>(smemBuf), maxNbSubSeq, idxReq, idxInputTokenGlobal, cacheSeqLen,
nbSubSeq, idxSubSeq, ctaRank, warpRank, warpIdx}
.run();
}
else
{
Consumer{args, *reinterpret_cast<SharedMemB*>(smemBuf), maxNbSubSeq, idxReq, idxInputTokenGlobal, cacheSeqLen,
nbSubSeq, idxSubSeq, ctaRank, warpRank, warpIdx}
.run();
}
}
__constant__ constexpr uint32_t smemSize = mha::max(sizeof(SharedMemA), sizeof(SharedMemB));
static_assert(smemSize <= 99 * 1024, "Shared memory size exceeded");
#endif // IS_MLA
#ifndef GENERATE_CUBIN
#if IS_MLA
CUtensorMap makeTensorMapForQ(
void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems)
{
CUtensorMap tensorMap{};
uint64_t const globalDims[] = {headElems, totalNbHeads};
uint32_t elemBytes = getElemBytes(dataType);
uint32_t const headBytes = elemBytes * headElems;
uint64_t const globalStrides[] = {headBytes};
uint32_t const boxDims[] = {partElems, headGrpSize};
uint32_t const elemStrides[] = {1, 1};
auto const swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 2, const_cast<void*>(addr), globalDims, globalStrides, boxDims,
elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensorMap;
}
#endif // IS_MLA
void launchMLA(cudaDeviceProp const& prop,
uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
float qScale, OutputHead* output, InputHead const* q,
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
GMemCacheHead* kCacheVLLM, // K cache pool for VLLM layout
GMemCacheHead* vCacheVLLM, // V cache pool for VLLM layout
#else
GMemCacheHead* pool, // global pool of pages
#endif
KVCachePageIndex const*
kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] (Layout 0) or
// [batchSize][maxNbPagesPerSeq] (Layout 1)
#else
GMemKVCacheHead* kvCacheData,
#endif
uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
// int8/fp8 KV cache.
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
#if IS_MLA
static_assert(
SLIDING_WINDOW == 0 && LOW_PREC_OUTPUT == 0 && USE_INPUT_KV == 0 && USE_BEAM_SEARCH == 0, "not implemented");
if (beamWidth != 1)
{
throw std::runtime_error("not implemented");
}
static uint32_t const hostSmemSize = [&]()
{
// printf("smemSize = %u\n", smemSize);
uint32_t size;
checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize)));
checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size));
return size;
}();
uint32_t const nbKHeads = 1;
uint32_t const nbVHeads = nbKHeads;
uint32_t const nbQHeads = nbKHeads * headGrpSize;
uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
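// Number of multi-block sub-sequences per sequence: can be forced via the XQA_NB_SUB_SEQ environment
// variable; otherwise a heuristic based on how many CGAs (4 CTAs each) fit on the device relative to
// batchSize * nbKHeads, capped so that each sub-sequence covers at least about two tiles of work.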
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
float const factor = 4.f;
return mha::min<uint32_t>(
mha::max<uint32_t>(1U, (uint32_t) round(prop.multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, tokensPerTile * 2));
}();
// printf("nbSubSeqPerSeq = %u\n", nbSubSeqPerSeq);
// gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == cgaSize * inputSeqLen
dim3 const dimGrid{4 * inputSeqLen, nbSubSeqPerSeq, nbKHeads * batchSize};
dim3 const dimCta{warp_size * 4 * 3, 1, 1};
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
#if USE_PAGED_KV_CACHE
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
#if PAGED_KV_CACHE_LAYOUT == 1
KVCacheList<true> const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq};
#else
KVCacheList<true> const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq};
#endif
auto const dtype = []
{
if (std::is_same_v<CacheElem, half>)
{
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
}
else if (std::is_same_v<CacheElem, __nv_bfloat16>)
{
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
}
else if (std::is_same_v<CacheElem, __nv_fp8_e4m3>)
{
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
}
throw std::runtime_error("unsupported cache element type");
}();
auto const tensorMapQ
= makeTensorMapForQ(q, dtype, validElemsPerHead, headGrpSize * inputSeqLen * batchSize, partElemsK);
#if PAGED_KV_CACHE_LAYOUT == 1
auto const tensorMapK = makeTensorMapForPagedKVCache(
kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
auto const tensorMapV = makeTensorMapForPagedKVCache(
vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
#else
auto const tensorMapK = makeTensorMapForPagedKVCache(
pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsK, tokensPerTile);
auto const tensorMapV = makeTensorMapForPagedKVCache(
pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, partElemsV, tokensPerTile);
#endif
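// Scratch layout: one Vec<CgaXBuffer, nbProducerCtasPerCga> per CGA, followed by one PartialResult per
// (input token, sub-sequence) slot.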
uint32_t const nbCgas = exactDiv(dimGrid.x, 4) * dimGrid.y * dimGrid.z;
auto const cgaXBuf = static_cast<Vec<CgaXBuffer, nbProducerCtasPerCga>*>(scratch);
auto const partialResults = reinterpret_cast<PartialResult*>(cgaXBuf + nbCgas);
cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, tensorMapQ, tensorMapK, tensorMapV, qScale,
output, cacheList, batchSize, kvCacheScale, cgaXBuf, semaphores, partialResults);
#else
KVCacheList<false> const cacheList{kvCacheData, seqLen, maxSeqLen};
static_assert(!usePagedKVCache);
assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens);
auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8,
validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, gemm0CtaTileNbTokens);
cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads,
#if SLIDING_WINDOW
slidingWinSize,
#endif
qScale, output,
#if LOW_PREC_OUTPUT
rcpOutScale,
#endif
#if USE_INPUT_KV
qkv,
#if ROPE_STYLE != 0
ropeCosSin,
#endif
#else
q,
#endif
cacheList,
#if USE_BEAM_SEARCH
beamSearchParams,
#endif
batchSize, kvCacheScale, tensorMap, semaphores, scratch);
#endif
checkCuda(err);
#endif
}
#endif