/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#ifndef __CUDACC__
#include <cuda_runtime.h>
#endif

#include "defines.h"
#include "utils.h"
#if SPEC_DEC
#include "specDec.h"
#endif

using CacheElem = ElemType;

constexpr uint32_t validElemsPerHead = HEAD_ELEMS;
constexpr bool isMLA = IS_MLA;
static_assert((isMLA || validElemsPerHead <= 256) && (sizeof(CacheElem) * validElemsPerHead) % 16 == 0);
constexpr uint32_t headElems = validElemsPerHead <= 64 ? 64 : (validElemsPerHead <= 128 ? 128 : (isMLA ? 576 : 256));
static_assert(headElems == 64 || headElems == 128 || headElems == 256 || headElems == 576, "not implemented");
constexpr uint32_t beamWidth = BEAM_WIDTH;
constexpr uint32_t headGrpSize = HEAD_GRP_SIZE;
#if SPEC_DEC
__device__ constexpr uint32_t rowsPerBlock = M_TILESIZE;
#endif
inline constexpr bool useSpecDec = SPEC_DEC;

using InputElem = INPUT_ELEM;
using InputElem2 = INPUT_ELEM2;

#if !(SPEC_DEC)
constexpr uint32_t inputSeqLen = 1; // speculative decoding if > 1
#endif

constexpr bool useKVCache = USE_KV_CACHE;
using SeqLenDataType = uint32_t;
constexpr bool usePagedKVCache = USE_PAGED_KV_CACHE;
constexpr uint32_t tokensPerPage = TOKENS_PER_PAGE;

using IOHead = Vec<InputElem, validElemsPerHead>;
using InputHead = IOHead;
using GMemCacheHead = Vec<CacheElem, validElemsPerHead>;
constexpr uint32_t validElemsPerKHead = validElemsPerHead;
constexpr bool lowPrecOutput = LOW_PREC_OUTPUT;
#if IS_MLA
constexpr uint32_t validElemsPerVHead = 512;
static_assert(lowPrecOutput == false);
using OutputHead = Vec<__nv_bfloat16, validElemsPerVHead>;
#else
constexpr uint32_t validElemsPerVHead = validElemsPerHead;
using OutputHead = mha::conditional_t<lowPrecOutput, Vec<CacheElem, validElemsPerVHead>, IOHead>;
#endif
using OutputElem = OutputHead::Elem;
using PaddedInputHead = Vec<InputElem, headElems>;
using PaddedCacheHead = Vec<CacheElem, headElems>;

// impl detail, may be moved to mha.cu/mha_sm90.cu
constexpr bool isHeadPadded = (validElemsPerHead != headElems);

constexpr bool useInputKV = USE_INPUT_KV;
using GMemKVCacheHead = mha::conditional_t<useInputKV, GMemCacheHead, GMemCacheHead const>;

using KVCachePageIndex = int32_t; // page index in the global pool of pages; page shape: KVCacheHead[nbKHeads][tokensPerPage]

constexpr bool allowSlidingWindow = SLIDING_WINDOW;

struct BeamSearchParams
{
    uint32_t const* __restrict__ indices; // shape: [batchSize][beamWidth][capacity]
    uint32_t capacity;
    uint32_t const* __restrict__ ctxLenList; // shape: [batchSize][beamWidth]. Should be [batchSize] but we have to
                                             // match trt-llm API.
};
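// Illustrative sketch only (not part of this header's API): how the shape comments above map to
// flat indexing for BeamSearchParams. The helper name and the row-major flattening below are
// assumptions made for documentation purposes; the kernels define the authoritative layout.
#if 0
__device__ inline uint32_t beamGatherIndex(
    BeamSearchParams const& params, uint32_t batch, uint32_t beam, uint32_t pos)
{
    // indices has shape [batchSize][beamWidth][capacity]; flatten it row-major.
    return params.indices[(batch * beamWidth + beam) * params.capacity + pos];
}
#endif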
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
    uint32_t slidingWinSize,
#endif
    float qScale, OutputHead* output,
#if LOW_PREC_OUTPUT
    float const* rcpOutScale,
#endif
#if USE_INPUT_KV
    InputHead const* qkv,
#if ROPE_STYLE != 0
    Vec<float, validElemsPerKHead> const* ropeCosSin,
#endif
#else
    InputHead const* q,
#endif
    float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
#else
    GMemCacheHead* pool, // global pool of pages
#endif
    KVCachePageIndex const* kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]
#else
    GMemKVCacheHead* kvCacheData,
#endif
    uint32_t maxSeqLen, uint32_t const* seqLen,
#if BEAM_WIDTH > 1
    BeamSearchParams const& beamSearchParams,
#endif
    uint32_t batchSize,
    float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
                                            // int8/fp8 KV cache.
#if SPEC_DEC
    SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
    float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
    uint32_t* semaphores, void* scratch, cudaStream_t stream);

uint32_t computeNbSubSeqPerSeqHopperF8MHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
    uint32_t slidingWinSize,
#endif
    float qScale, OutputHead* output,
#if LOW_PREC_OUTPUT
    float const* rcpOutScale,
#endif
#if USE_INPUT_KV
    InputHead const* qkv,
#if ROPE_STYLE != 0
    Vec<float, validElemsPerKHead> const* ropeCosSin,
#endif
#else
    InputHead const* q,
#endif
    float const* attentionSinks, // [headGrpSize]
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
#else
    GMemCacheHead* pool, // global pool of pages
#endif
    KVCachePageIndex const* kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]
#else
    GMemKVCacheHead* kvCacheData,
#endif
    uint32_t maxSeqLen, uint32_t const* seqLen,
#if BEAM_WIDTH > 1
    BeamSearchParams const& beamSearchParams,
#endif
    uint32_t batchSize,
    float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
                                            // int8/fp8 KV cache.
#if SPEC_DEC
    SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
    float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
    uint32_t* semaphores, void* scratch, cudaStream_t stream);
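// Hypothetical host-side usage sketch (not compiled): one way launchMHA might be invoked when this
// header is configured with SLIDING_WINDOW=0, LOW_PREC_OUTPUT=0, USE_INPUT_KV=0,
// USE_PAGED_KV_CACHE=1, PAGED_KV_CACHE_LAYOUT=0, BEAM_WIDTH=1, SPEC_DEC=0 and SKIP_SOFTMAX_ATTN=0.
// All sizes and pointer variables below are placeholders assumed to be allocated and populated by
// the caller; only the argument order follows the declaration above.
#if 0
void exampleLaunchMHA(cudaDeviceProp const& prop, cudaStream_t stream)
{
    uint32_t const nbKHeads = 8, batchSize = 4, maxSeqLen = 4096;
    float const qScale = 1.f;
    OutputHead* output = nullptr;                       // device output buffer
    InputHead const* q = nullptr;                       // device query heads
    float const* attentionSinks = nullptr;              // [headGrpSize], or nullptr if unused
    GMemCacheHead* pool = nullptr;                      // global pool of KV cache pages
    KVCachePageIndex const* kvCachePageList = nullptr;  // [batchSize][beamWidth][2][maxNbPagesPerSeq]
    uint32_t const* seqLen = nullptr;                   // per-sequence lengths
    float const* kvCacheScale = nullptr;                // device scalar, int8/fp8 KV cache only
    uint32_t* semaphores = nullptr;
    void* scratch = nullptr;
    launchMHA(prop, nbKHeads, qScale, output, q, attentionSinks, pool, kvCachePageList,
        maxSeqLen, seqLen, batchSize, kvCacheScale, semaphores, scratch, stream);
}
#endif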
void launchMLA(cudaDeviceProp const& prop,
    uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
    float qScale, OutputHead* output, InputHead const* q,
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
    GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM,
#else
    GMemCacheHead* pool, // global pool of pages
#endif
    KVCachePageIndex const* kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]
                                             // (Layout 0) or [batchSize][maxNbPagesPerSeq] (Layout 1)
#else
    GMemKVCacheHead* kvCacheData,
#endif
    uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
    float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for
                                            // int8/fp8 KV cache.
    uint32_t* semaphores, void* scratch, cudaStream_t stream);

#if STATIC_NB_K_HEADS
constexpr uint32_t nbKHeads = NB_K_HEADS;
constexpr uint32_t nbVHeads = nbKHeads;
constexpr uint32_t nbQHeads = nbKHeads * headGrpSize;
constexpr uint32_t nbQKVHeads = nbQHeads + nbKHeads + nbVHeads;
#endif

constexpr uint32_t cacheElemSize = sizeof(CacheElem);
constexpr uint32_t inputElemSize = sizeof(InputElem);
constexpr uint32_t outputElemSize = sizeof(OutputElem);
constexpr uint32_t ioHeadBytes = sizeof(IOHead);
constexpr uint32_t gmemCacheHeadBytes = sizeof(GMemCacheHead);
constexpr uint32_t paddedInputHeadBytes = sizeof(PaddedInputHead);
constexpr uint32_t paddedCacheHeadBytes = sizeof(PaddedCacheHead);

constexpr bool allowMultiBlockMode = ALLOW_MULTI_BLOCK_MODE;

enum class XQAKernelType : int32_t
{
    kAMPERE_WARP_SPECIALIZED = 0,
    kHOPPER_WARP_SPECIALIZED = 1,
    kSM120_MLA = 2
};

#ifdef GENERATE_CUBIN
#define CUBIN_EXPORT extern "C"
#else
#define CUBIN_EXPORT static
#endif
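// Illustrative sizing sketch (an assumption, not a guarantee of the actual pool layout): the page
// shape comment above (KVCacheHead[nbKHeads][tokensPerPage]) suggests that, for the non-VLLM paged
// layout, a global pool of totalNbPages pages occupies roughly the byte count computed below.
// totalNbPages and nbKHeadsRuntime are hypothetical caller-provided values.
#if 0
constexpr size_t pagedKVCachePoolBytes(size_t totalNbPages, size_t nbKHeadsRuntime)
{
    // One page stores nbKHeadsRuntime * tokensPerPage cache heads of gmemCacheHeadBytes bytes each.
    return totalNbPages * nbKHeadsRuntime * tokensPerPage * gmemCacheHeadBytes;
}
#endif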