/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "mha_stdheaders.cuh"

// Compile-time configuration. Every knob wrapped in #ifndef below may be
// overridden from the compiler command line (-DNAME=value); the values given
// here are only the defaults.

// When non-zero, the number of K heads is fixed at compile time via NB_K_HEADS.
#define STATIC_NB_K_HEADS 0
#if STATIC_NB_K_HEADS
#define NB_K_HEADS 2
#endif

// Elements per attention head. Allowed values are multiples of 16 in range
// [16, 256]; the MLA configuration (see IS_MLA below) uses 576.
#ifndef HEAD_ELEMS
#define HEAD_ELEMS 128
#endif

// nbQHeads / nbKHeads for MQA/GQA
#ifndef HEAD_GRP_SIZE
#define HEAD_GRP_SIZE 8
#endif

// MLA mode is selected solely by this (group size, head size) combination.
#define IS_MLA (HEAD_GRP_SIZE == 128 && HEAD_ELEMS == 576)

#if IS_MLA
// MLA: fp8 (e4m3) input, and V heads (HEAD_ELEMS_V) are narrower than K heads.
#define INPUT_ELEM __nv_fp8_e4m3
#define INPUT_ELEM2 __nv_fp8x2_e4m3
#define HEAD_ELEMS_V 512
#else
// 1 means fp16 and 0 means bf16 input/output
#ifndef INPUT_FP16
#define INPUT_FP16 1
#endif

// Don't modify
#if INPUT_FP16
#define INPUT_ELEM half
#define INPUT_ELEM2 half2
#else
#define INPUT_ELEM __nv_bfloat16
#define INPUT_ELEM2 __nv_bfloat162
#endif
#endif

// For beam search. Allowed values: 1, 4
#ifndef BEAM_WIDTH
#define BEAM_WIDTH 1
#endif

#ifndef SPEC_DEC
#define SPEC_DEC 0
#endif

#if SPEC_DEC
using MaskType = uint32_t;
#ifndef M_TILESIZE
#define M_TILESIZE 32
#endif
#endif

// Enables SWAP AB optimization for speculative decoding when using a small, fixed Q_SEQ_LEN.
// NOTE: Requires a uniform input sequence length for the entire batch.
#ifdef SPEC_Q_SEQ_LEN
static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is enabled.");
#endif

// 0: half/bf16 based on INPUT_FP16; 1: int8_t; 2: __nv_fp8_e4m3
#ifndef CACHE_ELEM_ENUM
#define CACHE_ELEM_ENUM 2
#endif

// don't modify
#define USE_KV_CACHE true

// don't modify
#ifndef ALLOW_MULTI_BLOCK_MODE
#define ALLOW_MULTI_BLOCK_MODE true
#endif

// For paged KV cache. Allowed values: 0, 16, 32, 64, 128
// 0 means contiguous KV cache (non-paged).
#ifndef TOKENS_PER_PAGE
#define TOKENS_PER_PAGE 32
#endif

// don't modify
#ifndef USE_PAGED_KV_CACHE
#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0)
#endif

// Paged KV Cache Format
// 0 - XQA Original
// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for VLLM/SGLang
// Defined unconditionally: USE_PAGED_KV_CACHE is always defined at this point
// (either by the user or by the default just above).
#ifndef PAGED_KV_CACHE_LAYOUT
#define PAGED_KV_CACHE_LAYOUT 0
#endif

// don't modify
#define USE_BEAM_SEARCH (BEAM_WIDTH > 1)

// Request unrolling only for a 16-bit (half/bf16) cache; for int8/fp8 caches
// the annotated loops are kept rolled (unroll factor 1).
#if CACHE_ELEM_ENUM == 0
#define PRAGMA_UNROLL_FP16_ONLY _Pragma("unroll")
#else
#define PRAGMA_UNROLL_FP16_ONLY _Pragma("unroll(1)")
#endif

// good for short sequence length but bad for long sequence length. Only for mha.cu.
#ifndef SHORT_SEQ_OPT
#define SHORT_SEQ_OPT 1
#endif

#ifndef SLIDING_WINDOW
#define SLIDING_WINDOW 0
#endif

// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
#ifndef ENABLE_PDL
#define ENABLE_PDL 2
#endif

#ifndef USE_INPUT_KV
#define USE_INPUT_KV 0
#endif

#if USE_INPUT_KV
// 0 - no RoPE
// 1 - NEOX style
// 2 - GPTJ style
#ifndef ROPE_STYLE
#define ROPE_STYLE 0
#endif

#if SPEC_DEC
#error "SPEC_DEC is not supported for USE_INPUT_KV"
#endif
#endif

// Output element type:
// 0 - input element type
// 1 - KV cache element type
#ifndef LOW_PREC_OUTPUT
#define LOW_PREC_OUTPUT 0
#endif

#if LOW_PREC_OUTPUT
static_assert(CACHE_ELEM_ENUM != 0,
    "LOW_PREC_OUTPUT requires a quantized KV cache (CACHE_ELEM_ENUM must be 1 (int8) or 2 (fp8)).");
#endif

// true should be better if warpTile.x * cacheElemSize < 128. otherwise use false.
// Fully parenthesized so the macro acts as a single operand wherever it is
// expanded (e.g. `!GRP_LOAD_V` or `GRP_LOAD_V && cond` in #if or C++ code).
#define GRP_LOAD_V ((CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && USE_PAGED_KV_CACHE && BEAM_WIDTH > 1))

// use custom barrier for NVRTC to avoid pulling in many headers
#ifndef USE_CUSTOM_BARRIER
#define USE_CUSTOM_BARRIER 1
#endif

#ifndef OPTIMIZE_FOR_LATENCY
#define OPTIMIZE_FOR_LATENCY 1
#endif

#ifndef IS_SPEC_DEC_TREE
#define IS_SPEC_DEC_TREE 1 // by default SPEC_DEC expects a tree-based draft token structure
#endif

// Debug sizes. DBG_SEQ_LEN is parenthesized so it always means 1027; without the
// parentheses `2 * DBG_SEQ_LEN` would expand to 2 * 256 * 4 + 3 == 2051.
#define DBG_BATCH_SIZE 2
#define DBG_SEQ_LEN (256 * 4 + 3)
#define DBG_NB_CTAS_PER_SEQ 8

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Maps an element-type enum (same encoding as CACHE_ELEM_ENUM) to a C++ type:
//   0 -> INPUT_ELEM, 1 -> int8_t, 2 -> __nv_fp8_e4m3, any other value -> void.
// NOTE(review): confirm the include list and the void fallback against the
// canonical header.
template <int32_t elemTypeEnum>
using ElemType = mha::conditional_t<elemTypeEnum == 0, INPUT_ELEM,
    mha::conditional_t<elemTypeEnum == 1, int8_t, mha::conditional_t<elemTypeEnum == 2, __nv_fp8_e4m3, void>>>;