// TensorRT-LLM: cpp/kernels/xqa/mhaUtils.cuh

/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "ldgsts.cuh"
#include "mha.h"
#include "utils.cuh"
// for beam search
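// IndexedHeadPtrImpl resolves a token index i to a KV head in the paged pool:
// indices[i] selects the beam that token i belongs to, pageIndices[beam] selects the
// page holding the token, and the head lives at
//   pool + tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage.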
template <typename Head, uint32_t tokensPerPage, uint32_t nbPages>
struct IndexedHeadPtrImpl
{
static_assert(tokensPerPage != 0 && nbPages != 0);
uint32_t const* indices; // values are in range [0, beamWidth)
Head* pool;
Vec<KVCachePageIndex, nbPages> const* pageIndices;
uint32_t nbKHeads;
uint32_t offset; // applied on top of the page base resolved from pool and pageIndices
__device__ inline Head& operator[](uint32_t i) const
{
return *(*this + i);
}
__device__ inline Head* operator+(uint32_t i) const
{
assert(indices[i] < beamWidth);
assert(nbPages == 1 || offset % tokensPerPage == 0);
auto const pageIdx = pageIndices[indices[i]][nbPages == 1 ? 0U : i / tokensPerPage];
return pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage);
}
};
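// Specialization for the contiguous (non-paged) KV cache: the page indirection is
// replaced by a per-beam stride into a linear buffer.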
template <typename Head>
struct IndexedHeadPtrImpl<Head, 0, 0>
{
uint32_t const* indices; // values are in range [0, beamWidth)
Head* pointer;
uint32_t offset;
uint32_t beamStride;
__device__ inline Head& operator[](uint32_t i) const
{
return *(*this + i);
}
__device__ inline Head* operator+(uint32_t i) const
{
assert(indices[i] < beamWidth);
return pointer + (beamStride * indices[i] + offset + i);
}
};
template <typename Head, uint32_t tokensPerPage, uint32_t nbPages = 0>
using IndexedHeadPtr = IndexedHeadPtrImpl<Head, tokensPerPage, nbPages>;
// for beamWidth = 1
template <typename Head, uint32_t tokensPerPage, uint32_t nbPages>
struct HeadPtr
{
static_assert(tokensPerPage != 0 && nbPages != 0);
Head* pool;
Vec<KVCachePageIndex, nbPages> pageIndices;
uint32_t nbKHeads;
uint32_t offset; // offset inside the first page.
__device__ inline Head& operator[](uint32_t i) const
{
return *(*this + i);
}
__device__ inline Head* operator+(uint32_t i) const
{
assert(nbPages == 1 || offset % tokensPerPage == 0);
auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage];
return (pageIdx & (1U << 31)) ? nullptr
: pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage);
}
};
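// Contiguous (non-paged) KV cache with beamWidth == 1: falls back to a plain TinyPtr.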
template <typename Head>
struct HeadPtr<Head, 0, 0> : TinyPtr<Head>
{
};
// template <typename Head>
// #if BEAM_WIDTH == 1
// using SrcHeadPtr = TinyPtr<Head const>;
// #else
// using SrcHeadPtr = IndexedHeadPtr<Head>;
// #endif
// @fixme: give evict first hint for last part.
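// Copies up to maxNbCopiedHeads partial heads (part idxPart out of nbPartsPerHead per
// head) from src into the shared memory tile dst with cp.async. Each group of
// thrdsPerSeg consecutive lanes (a "segment") loads one partial head, one grain per
// lane. Out-of-bound heads/grains and invalid pages are issued with a copy size of 0
// so nothing is read out of bounds.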
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbPartsPerHead, bool swizzle, bool isFull,
uint32_t dstNbHeads, typename SrcHeadPtr, typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
__device__ inline void copyPartialHeadsAsync(
Warp const& warp, Array2D<LdGrain, dstNbHeads, exactDiv(exactDiv(sizeof(Head), nbPartsPerHead), grainBytes)>& dst,
uint32_t dstHeadOffset, SrcHeadPtr const& src, uint32_t idxPart, uint32_t nbAvailHeads = maxNbCopiedHeads,
LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; })
{
static_assert(maxNbCopiedHeads <= dstNbHeads);
assert(idxPart < nbPartsPerHead);
assert(dstHeadOffset + maxNbCopiedHeads <= dstNbHeads);
assert(sizeof(Head) * (src.offset + maxNbCopiedHeads) <= (1UL << 32));
assert(!isFull || nbAvailHeads >= maxNbCopiedHeads);
constexpr uint32_t headBytes = sizeof(Head);
constexpr uint32_t partBytes = exactDiv(headBytes, nbPartsPerHead);
constexpr uint32_t warpLdBytes = partBytes * maxNbCopiedHeads;
constexpr uint32_t thrdLdBytes = exactDiv(warpLdBytes, warp_size);
assertIsPowerOf2<thrdLdBytes>();
static_assert(thrdLdBytes >= grainBytes);
// a segment is responsible for loading one partial head collaboratively
constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytes);
static_assert(thrdsPerSeg > 0 && thrdsPerSeg <= warp_size);
assertIsPowerOf2<thrdsPerSeg>();
assert(__shfl_sync(0xFU << (laneId() / 4 * 4), src.offset, 0, 4) == src.offset);
auto const warpLane = laneId();
uint32_t const segIdx = warpLane / thrdsPerSeg;
uint32_t const segLane = warpLane % thrdsPerSeg;
constexpr uint32_t partsPerWarpInst = exactDiv(grainBytes * warp_size, partBytes);
#pragma unroll
for (uint32_t i = 0; i < thrdLdBytes / grainBytes; i++)
{
uint32_t const idxHeadLocal = partsPerWarpInst * i + segIdx;
assert(idxHeadLocal < maxNbCopiedHeads);
bool const isHeadInBound = isFull || (idxHeadLocal < nbAvailHeads);
constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytes);
using SrcHead = mha::decay_t<decltype(src[0])>;
constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytes);
uint32_t const idxGrainInsideHead = grainsPerPart * idxPart + segLane;
bool const isGrainInBound = (!isHeadPadded || idxGrainInsideHead < nbValidGrains);
SrcHead const* const pSrcHead = src + localHeadIdxMap(idxHeadLocal);
bool const isValidPage = (pSrcHead != nullptr);
LdGrain const* const pSrc = reinterpret_cast<LdGrain const*>(pSrcHead) + idxGrainInsideHead;
LdGrain* const pDst = &dst.template at<swizzle>(dstHeadOffset + idxHeadLocal, segLane);
assert(!hasBankConflict(pDst));
ldgsts::copyAsync<grainBytes>(pDst, pSrc, isValidPage && isHeadInBound && isGrainInBound ? grainBytes : 0u);
}
}
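// Warp-cooperative wrapper: splits maxNbCopiedHeads full heads evenly across nbWarps
// warps and forwards each warp's share to copyPartialHeadsAsync with nbPartsPerHead == 1.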
template <typename Head, uint32_t maxNbCopiedHeads, uint32_t nbWarps, bool swizzle, bool isFull, uint32_t dstNbHeads,
typename SrcHeadPtr, typename LocalHeadIdxMap = uint32_t (*)(uint32_t)>
__device__ inline void copyHeadsAsync(
uint32_t idxWarp, Array2D<LdGrain, dstNbHeads, exactDiv(sizeof(Head), grainBytes)>& dst, SrcHeadPtr const& src,
uint32_t nbAvailHeads = maxNbCopiedHeads, LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; })
{
assert(idxWarp < nbWarps);
Warp const& warp = this_warp();
constexpr uint32_t maxNbHeadsPerWarp = exactDiv(maxNbCopiedHeads, nbWarps);
uint32_t const dstHeadOffset = maxNbHeadsPerWarp * idxWarp;
uint32_t const warpNbAvailHeads = (dstHeadOffset < nbAvailHeads ? nbAvailHeads - dstHeadOffset : 0);
constexpr uint32_t idxPart = 0;
copyPartialHeadsAsync<Head, maxNbHeadsPerWarp, 1, swizzle, isFull, dstNbHeads>(warp, dst, dstHeadOffset, src,
idxPart, warpNbAvailHeads, [&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); });
}
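// Copies up to maxTotalNbGrains LdGrain-sized grains from src to dst, spread across
// nbWarps warps, either with cp.async (isAsync == true) or with plain loads/stores.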
template <bool isAsync, uint32_t maxTotalNbGrains, uint32_t nbWarps, bool isFull = true>
__device__ inline void copyGrains(
uint32_t idxWarp, LdGrain* dst, LdGrain const* src, uint32_t totalNbGrains = maxTotalNbGrains)
{
assert((isFull && totalNbGrains == maxTotalNbGrains) || (!isFull && totalNbGrains <= maxTotalNbGrains));
constexpr uint32_t nbThrds = warp_size * nbWarps;
uint32_t const tid = warp_size * idxWarp + laneId();
// copy grains from src to dst (e.g. output to scratch)
#pragma unroll
for (uint32_t i = 0; i < divUp(maxTotalNbGrains, nbThrds); i++)
{
uint32_t const idx = nbThrds * i + tid;
if (!(isFull && maxTotalNbGrains % nbThrds == 0) && idx >= totalNbGrains)
{
break;
}
if constexpr (isAsync)
{
ldgsts::copyAsync<grainBytes>(&dst[idx], &src[idx], grainBytes);
}
else
{
dst[idx] = src[idx];
}
}
}
// With ldmatrix, what we load for the fp8 cache is T0:{e0,e1,e2,e3}; T1:{e4,e5,e6,e7}; T2:{e8,e9,e10,e11};
// T3:{e12,e13,e14,e15}. When cast to fp16, it becomes T0:{e0,e1}; T1:{e4,e5}; ... | T0:{e2,e3}; T1:{e6,e7}; ...
// We need to reorder Q to match that order. Use isFwd=false to revert the reordering.
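// Concretely, for each pair of adjacent grains holding 32-bit words {s0..s3}, {s4..s7},
// the forward pass writes {s0,s2,s4,s6} and {s1,s3,s5,s7}; the isFwd=false path applies
// the inverse permutation, so a forward reorder followed by a reverse reorder restores
// the original layout.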
template <uint32_t nbWarps, bool swizzled, bool isFwd, uint32_t cols, uint32_t rows>
__device__ inline void reorder16bQHeadsToMatch8bKCache(uint32_t idxWarp, Array2D<LdGrain, rows, cols>& qHeads)
{
assert(idxWarp < nbWarps);
constexpr uint32_t nbWarpIters = exactDiv(exactDiv(cols, 2) * rows, warp_size); // warps * iters
constexpr uint32_t nbWorkingWarps = mha::min(nbWarps, nbWarpIters);
if (idxWarp >= nbWorkingWarps)
{
return;
}
static_assert(cols % 2 == 0);
uint32_t const tid = warp_size * idxWarp + laneId();
constexpr uint32_t iterCols = exactDiv(warp_size * nbWorkingWarps, rows) * 2;
static_assert(cols % iterCols == 0, "fix this by reducing nbWorkingWarps, or use divUp and add runtime check");
constexpr uint32_t nbIters = exactDiv(cols, iterCols);
static_assert(nbIters == exactDiv(nbWarpIters, nbWorkingWarps));
uint32_t const r = tid % rows;
uint32_t const cInit = tid / rows * 2;
#pragma unroll
for (uint32_t n = 0; n < nbIters; n++)
{
uint32_t const c = cInit + iterCols * n;
LdGrain const src[2] = {
qHeads.template at<swizzled>(r, c),
qHeads.template at<swizzled>(r, c + 1),
};
auto const& s = reinterpret_cast<Vec<uint32_t, LdGrain::size * 2> const&>(src);
if constexpr (isFwd)
{
qHeads.template at<swizzled>(r, c) = LdGrain{s[0], s[2], s[4], s[6]};
qHeads.template at<swizzled>(r, c + 1) = LdGrain{s[1], s[3], s[5], s[7]};
}
else
{
qHeads.template at<swizzled>(r, c) = LdGrain{s[0], s[4], s[1], s[5]};
qHeads.template at<swizzled>(r, c + 1) = LdGrain{s[2], s[6], s[3], s[7]};
}
}
}
template <bool usePagedKVCache>
struct KVCacheList;
template <>
struct KVCacheList<true>
{
GMemKVCacheHead* pool;
KVCachePageIndex const* kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility)
uint32_t maxNbPagesPerSeq;
};
template <>
struct KVCacheList<false>
{
GMemKVCacheHead* data; // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity]
SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility)
uint32_t capacity;
};
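// Loads the sequence length of request idxReq (beam 0) with a non-coherent load and an
// evict_last L2 cache policy. All beams of a request are expected to share the same
// length, which the assert loop below checks.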
__device__ inline uint32_t getSeqLen(uint32_t const* seqLenList, uint32_t idxReq)
{
uint64_t cachePolicy;
asm("createpolicy.fractional.L2::evict_last.b64 %0;\n" : "=l"(cachePolicy));
uint32_t len;
asm("ld.global.nc.L1::evict_last.L2::cache_hint.L2::256B.b32 %0, [%1], %2;\n"
: "=r"(len)
: "l"(&seqLenList[idxReq * beamWidth]), "l"(cachePolicy));
for (uint32_t i = 0; i < beamWidth; i++)
{
assert(len == seqLenList[idxReq * beamWidth + i]);
}
return len;
}
template <bool isPaged>
__device__ inline uint32_t getCacheSeqLen(KVCacheList<isPaged> const& cacheList, uint32_t idxReq)
{
return getSeqLen(cacheList.seqLenList, idxReq);
}
__device__ inline uint32_t getCtxCacheSeqLen(BeamSearchParams const& beamSearchParams, uint32_t idxReq)
{
return getSeqLen(beamSearchParams.ctxLenList, idxReq);
}
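// Gathers nbLoadedPages page indices starting at idxPageBeg from the page list, laid out
// as [batchSize][beamWidth][2][maxNbPagesPerSeq] with K pages at index 0 and V pages at
// index 1 of the third dimension; out-of-range slots are filled with kBAD_PAGE_INDEX.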
template <uint32_t nbLoadedPages>
__device__ inline Vec<KVCachePageIndex, nbLoadedPages> getPage(KVCacheList<true> const& cacheList, bool isK,
uint32_t idxReq, uint32_t idxBeam, uint32_t idxPageBeg, uint32_t nbPages)
{
auto const maxNbPagesPerSeq = cacheList.maxNbPagesPerSeq;
Vec<KVCachePageIndex, nbLoadedPages> ret;
#pragma unroll
for (uint32_t i = 0; i < nbLoadedPages; i++)
{
uint32_t const idxPage = idxPageBeg + i;
ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq
+ 2 * maxNbPagesPerSeq * idxBeam + maxNbPagesPerSeq * (isK ? 0U : 1U) + idxPage]
: kBAD_PAGE_INDEX);
}
return ret;
}
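// Asynchronous variant for beam search: each (beam, page slot) pair is handled by one
// thread, which cp.async-copies the page index into shared memory (copy size 0 for
// out-of-range pages).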
template <uint32_t nbWarps, uint32_t nbLoadedPages>
__device__ inline void loadPagesForBeamSearchAsync(uint32_t idxWarp,
Vec<Vec<KVCachePageIndex, nbLoadedPages>, beamWidth>& dst, KVCacheList<true> const& cacheList, bool isK,
uint32_t idxReq, uint32_t idxPageBeg, uint32_t nbPages)
{
assert(idxWarp < nbWarps);
auto const maxNbPagesPerSeq = cacheList.maxNbPagesPerSeq;
static_assert(beamWidth < warp_size);
auto const tid = warp_size * idxWarp + laneId();
auto const idxBeam = tid / nbLoadedPages;
auto const idxLoadedPage = tid % nbLoadedPages;
static_assert(warp_size * nbWarps >= beamWidth * nbLoadedPages);
if (idxBeam < beamWidth)
{
constexpr uint32_t nbBytes = sizeof(KVCachePageIndex);
uint32_t const idxPage = idxPageBeg + idxLoadedPage;
ldgsts::copyAsync<nbBytes>(&dst[idxBeam][idxLoadedPage],
&cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq + 2 * maxNbPagesPerSeq * idxBeam
+ (isK ? 0U : maxNbPagesPerSeq) + idxPage],
idxPage < nbPages ? nbBytes : 0U);
}
}
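// Loads the per-token beam indices for (idxReq, idxBeam) starting at uniformSeqOffset
// into dst with cp.async, one index per instruction to tolerate unaligned offsets;
// positions at or beyond seqLen are issued with a copy size of 0.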
template <uint32_t nbWarps, uint32_t length, bool isFullTile = false>
__device__ inline void loadIndicesForBeamSearchAsync(uint32_t idxWarp, Vec<uint32_t, length>& dst,
BeamSearchParams const& params, uint32_t idxReq, uint32_t idxBeam, uint32_t uniformSeqOffset, uint32_t seqLen)
{
constexpr uint32_t nbThreads = warp_size * nbWarps;
// constexpr uint32_t indicesPerInst = mha::min(exactDiv(grainBytes, sizeof(uint32_t)), divUp(length, nbThreads));
// // @fixme: std::bit_ceil on length
constexpr uint32_t indicesPerInst = 1U; // to handle unaligned case.
constexpr uint32_t bytesPerInst = sizeof(uint32_t) * indicesPerInst;
assertIsPowerOf2<indicesPerInst>();
uint32_t const capacity = params.capacity;
uint32_t const srcOffset = (idxReq * beamWidth + idxBeam) * capacity + uniformSeqOffset;
uint32_t const tid = warp_size * idxWarp + laneId();
constexpr uint32_t indicesPerIter = indicesPerInst * nbThreads;
#pragma unroll
for (uint32_t i = 0; i < length / indicesPerIter; i++)
{
uint32_t const idx = indicesPerIter * i + indicesPerInst * tid;
ldgsts::copyAsync<bytesPerInst>(&dst[idx], &params.indices[srcOffset + idx],
(isFullTile || uniformSeqOffset + idx < seqLen) ? bytesPerInst : 0);
}
if constexpr (length % indicesPerIter != 0)
{
uint32_t const idx = indicesPerIter * (length / indicesPerIter) + indicesPerInst * tid;
if (idx < length)
{
ldgsts::copyAsync<bytesPerInst>(&dst[idx], &params.indices[srcOffset + idx],
(isFullTile || uniformSeqOffset + idx < seqLen) ? bytesPerInst : 0);
}
}
}
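// Converts a float2 into a packed pair of input elements (fp16, bf16 or fp8 e4m3,
// depending on InputElem2); any other instantiation traps.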
__device__ inline InputElem2 float2ToInputElem2(float2 src)
{
InputElem2 dst;
if constexpr (mha::is_same_v<InputElem2, half2>)
{
reinterpret_cast<half2&>(dst) = __float22half2_rn(src);
return dst;
}
else if constexpr (mha::is_same_v<InputElem2, nv_bfloat162>)
{
reinterpret_cast<nv_bfloat162&>(dst) = __float22bfloat162_rn(src);
return dst;
}
else if constexpr (mha::is_same_v<InputElem2, __nv_fp8x2_e4m3>)
{
reinterpret_cast<__nv_fp8x2_e4m3&>(dst) = __nv_fp8x2_e4m3{src};
return dst;
}
else
{
trap();
}
}
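// The barrier helpers below hide mbarrier operations behind a `real` flag: with
// real == true they forward to the CtaBarrier, and with real == false they degenerate
// to a no-op token, a __syncwarp(), or `false`, so callers can keep a single code path
// whether or not a CTA barrier is actually needed.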
template <bool real>
using TokenOrNone = RealTypeOrNone<real, CtaBarrier::arrival_token>;
template <bool real>
__device__ inline TokenOrNone<real> arrive(CtaBarrier* pBarrier)
{
if constexpr (real)
{
return pBarrier->arrive();
}
else
{
assert(pBarrier == nullptr);
return None{};
}
}
template <bool real>
__device__ inline void wait(CtaBarrier* pBarrier, TokenOrNone<real>&& token)
{
if constexpr (real)
{
pBarrier->wait(mha::move(token));
}
else
{
assert(pBarrier == nullptr);
__syncwarp();
}
}
template <bool real>
__device__ inline bool test_wait(CtaBarrier* pBarrier, TokenOrNone<real>&& token)
{
if constexpr (real)
{
uint32_t complete;
asm volatile(
"{\n"
".reg .pred complete;\n"
"mbarrier.test_wait.acquire.cta.shared::cta.b64 complete, [%1], %2;\n"
"selp.b32 %0, 1, 0, complete;\n}\n"
: "=r"(complete)
: "l"(__cvta_generic_to_shared(pBarrier)), "l"(token));
return bool(complete);
}
else
{
return false;
}
}
template <bool real>
using ParityOrNone = RealTypeOrNone<real, bool>;
template <bool real>
__device__ inline void wait_parity(CtaBarrier* pBarrier, ParityOrNone<real> parity)
{
assert(real == (pBarrier != nullptr));
if constexpr (real)
{
pBarrier->wait_parity(parity);
}
else
{
__syncwarp();
}
}
template <bool real>
__device__ inline bool test_wait_parity(CtaBarrier* pBarrier, ParityOrNone<real> parity)
{
assert(real == (pBarrier != nullptr));
if constexpr (real)
{
#if USE_CUSTOM_BARRIER
return pBarrier->test_wait_parity(parity);
#else
return pBarrier->try_wait_parity_for(parity, cuda::std::chrono::nanoseconds(0));
#endif
}
else
{
return false;
}
}
template <bool real = true>
__device__ inline ParityOrNone<real>& flip(ParityOrNone<real>& parity)
{
if constexpr (real)
{
parity = !parity;
}
return parity;
}
template <bool real = true>
__device__ inline ParityOrNone<real> getAndFlip(ParityOrNone<real>& flag)
{
ParityOrNone<real> const ret = flag;
if constexpr (real)
{
flag = !flag;
}
return ret;
}
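// Illustrative use of the parity helpers with a single producer/consumer barrier
// (`bar`, `parity` and `nbIters` are hypothetical caller-side names):
//   bool parity = false;
//   for (uint32_t iter = 0; iter < nbIters; iter++)
//   {
//       wait_parity<true>(&bar, getAndFlip<true>(parity)); // wait for the producer's arrive()
//       // ... consume the data guarded by bar ...
//   }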