TensorRT-LLM/cpp/kernels/xqa/mha.h
Ming Wei ed887940d4
infra: open source XQA kernels (#3762)
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which
consists of two parts:

1. NVRTC glue code
2. XQA kernel code

During the TensorRT-LLM build, XQA kernel code is embedded as C++ arrays via
gen_cpp_header.py and passed to NVRTC for JIT compilation.

Signed-off-by: Ming Wei <2345434+ming-wei@users.noreply.github.com>
2025-04-30 18:05:15 +08:00
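
For reference, a minimal sketch of the NVRTC JIT path described above (not the actual glue code; the source-array name xqa_mha_source, the helper name jitCompileXqa, and the option handling are placeholders assumed for illustration, with error checking omitted):

#include <nvrtc.h>
#include <vector>

// Hypothetical source array emitted by gen_cpp_header.py from the kernel sources.
extern char const xqa_mha_source[];

// JIT-compile the embedded kernel source to a cubin.
inline std::vector<char> jitCompileXqa(std::vector<char const*> const& options)
{
    nvrtcProgram prog{};
    nvrtcCreateProgram(&prog, xqa_mha_source, "mha.cu", 0, nullptr, nullptr);
    // options would carry the -D defines (e.g. HEAD_ELEMS, BEAM_WIDTH) and --gpu-architecture.
    nvrtcCompileProgram(prog, static_cast<int>(options.size()), options.data());
    size_t cubinSize = 0;
    nvrtcGetCUBINSize(prog, &cubinSize);
    std::vector<char> cubin(cubinSize);
    nvrtcGetCUBIN(prog, cubin.data());
    nvrtcDestroyProgram(&prog);
    return cubin;
}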

/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#ifndef __CUDACC__
#include <cuda_runtime_api.h>
#endif
#include "defines.h"
#include "utils.h"
#if SPEC_DEC
#include "specDec.h"
#endif
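// Compile-time kernel configuration. Macros such as CACHE_ELEM_ENUM, HEAD_ELEMS and BEAM_WIDTH are expected to be
// supplied as preprocessor defines at build/JIT-compile time.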
using CacheElem = ElemType<CACHE_ELEM_ENUM>;
constexpr uint32_t validElemsPerHead = HEAD_ELEMS;
static_assert(validElemsPerHead <= 256 && (sizeof(CacheElem) * validElemsPerHead) % 16 == 0);
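// Head size padded up to the nearest supported size (64, 128, or 256 elements).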
constexpr uint32_t headElems = validElemsPerHead <= 64 ? 64 : (validElemsPerHead <= 128 ? 128 : 256);
static_assert(headElems == 64 || headElems == 128 || headElems == 256, "not implemented");
constexpr uint32_t beamWidth = BEAM_WIDTH;
constexpr uint32_t headGrpSize = HEAD_GRP_SIZE;
#if SPEC_DEC
__device__ constexpr uint32_t rowsPerBlock = M_TILESIZE;
#endif
inline constexpr bool useSpecDec = SPEC_DEC;
using InputElem = INPUT_ELEM;
using InputElem2 = INPUT_ELEM2;
constexpr uint32_t inputSeqLen = 1; // speculative decoding if > 1
constexpr bool useKVCache = USE_KV_CACHE;
using SeqLenDataType = uint32_t;
constexpr bool usePagedKVCache = USE_PAGED_KV_CACHE;
constexpr uint32_t tokensPerPage = TOKENS_PER_PAGE;
using IOHead = Vec<InputElem, validElemsPerHead>;
using InputHead = IOHead;
using GMemCacheHead = Vec<CacheElem, validElemsPerHead>;
constexpr bool lowPrecOutput = LOW_PREC_OUTPUT;
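// With LOW_PREC_OUTPUT, output heads use the (low-precision) cache element type; otherwise they match the input type.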
using OutputHead = mha::conditional_t<lowPrecOutput, GMemCacheHead, InputHead>;
using OutputElem = OutputHead::Elem;
using PaddedInputHead = Vec<InputElem, headElems>;
using PaddedCacheHead = Vec<CacheElem, headElems>;
// impl detail, may be moved to mha.cu/mha_sm90.cu
constexpr bool isHeadPadded = (validElemsPerHead != headElems);
constexpr bool useInputKV = USE_INPUT_KV;
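// KV cache heads are writable when the kernel takes fused QKV input (USE_INPUT_KV); read-only otherwise.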
using GMemKVCacheHead = mha::conditional_t<useInputKV, GMemCacheHead, GMemCacheHead const>;
using KVCachePageIndex = int32_t; // Index of a page in the global pool. Each page stores KVCacheHead[nbKHeads][tokensPerPage].
constexpr bool allowSlidingWindow = SLIDING_WINDOW;
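// Parameters used only when beam search is enabled (BEAM_WIDTH > 1).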
struct BeamSearchParams
{
uint32_t const* __restrict__ indices; // shape: [batchSize][beamWidth][capacity]
uint32_t capacity;
    uint32_t const* __restrict__ ctxLenList; // shape: [batchSize][beamWidth]. Conceptually [batchSize], but kept
                                             // per-beam to match the TRT-LLM API.
};
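// Launcher for the XQA attention kernel (implementation expected in mha.cu).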
void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
#endif
float qScale, OutputHead* output,
#if LOW_PREC_OUTPUT
float const* rcpOutScale,
#endif
#if USE_INPUT_KV
InputHead const* qkv,
#if ROPE_STYLE != 0
Vec<float, validElemsPerHead> const* ropeCosSin,
#endif
#else
InputHead const* q,
#endif
#if USE_PAGED_KV_CACHE
GMemCacheHead* pool, // global pool of pages
KVCachePageIndex const*
        kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
#else
GMemKVCacheHead* kvCacheData,
#endif
uint32_t maxSeqLen, uint32_t const* seqLen,
#if BEAM_WIDTH > 1
BeamSearchParams const& beamSearchParams,
#endif
uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
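// Hopper (SM90) FP8 variant of the launcher above (implementation expected in mha_sm90.cu); takes the same
// parameter list as launchMHA.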
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
#endif
float qScale, OutputHead* output,
#if LOW_PREC_OUTPUT
float const* rcpOutScale,
#endif
#if USE_INPUT_KV
InputHead const* qkv,
#if ROPE_STYLE != 0
Vec<float, validElemsPerHead> const* ropeCosSin,
#endif
#else
InputHead const* q,
#endif
#if USE_PAGED_KV_CACHE
GMemCacheHead* pool, // global pool of pages
KVCachePageIndex const*
kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq].
#else
GMemKVCacheHead* kvCacheData,
#endif
uint32_t maxSeqLen, uint32_t const* seqLen,
#if BEAM_WIDTH > 1
BeamSearchParams const& beamSearchParams,
#endif
uint32_t batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);
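// When the number of K heads is fixed at compile time, derive the remaining head counts from the grouped-query
// layout (headGrpSize Q heads per K/V head).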
#if STATIC_NB_K_HEADS
constexpr uint32_t nbKHeads = NB_K_HEADS;
constexpr uint32_t nbVHeads = nbKHeads;
constexpr uint32_t nbQHeads = nbKHeads * headGrpSize;
constexpr uint32_t nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
#endif
constexpr uint32_t cacheElemSize = sizeof(CacheElem);
constexpr uint32_t inputElemSize = sizeof(InputElem);
constexpr uint32_t outputElemSize = sizeof(OutputElem);
constexpr uint32_t ioHeadBytes = sizeof(IOHead);
constexpr uint32_t gmemCacheHeadBytes = sizeof(GMemCacheHead);
constexpr uint32_t paddedInputHeadBytes = sizeof(PaddedInputHead);
constexpr uint32_t paddedCacheHeadBytes = sizeof(PaddedCacheHead);
constexpr bool allowMultiBlockMode = ALLOW_MULTI_BLOCK_MODE;
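// Selects which warp-specialized kernel implementation to use.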
enum class XQAKernelType : int32_t
{
kAMPERE_WARP_SPECIALIZED = 0,
kHOPPER_WARP_SPECIALIZED = 1
};
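// When generating cubins, kernel entry points are exported with C linkage so they can be located by name;
// otherwise they stay local to the translation unit.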
#ifdef GENERATE_CUBIN
#define CUBIN_EXPORT extern "C"
#else
#define CUBIN_EXPORT static
#endif