[None][feat] support JIT mha.cu for SPEC_DEC in runtime (#6078)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
This commit is contained in:
Jhao-Ting Chen 2025-09-23 14:56:17 -07:00 committed by GitHub
parent e3c1a9409f
commit 220dc01372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 107 additions and 30 deletions

View File

@ -505,7 +505,7 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
bool const maskFlag = col + actualQSeqLen < nbValidCols
? true
: packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY;
acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
}
}
}

View File

@ -20,13 +20,14 @@
namespace tensorrt_llm::kernels
{
uint32_t getKernelMTileSize(uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit)
uint32_t getKernelMTileSize(
uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit, bool supportQGMMA, bool supportMLA)
{
if (!isSpecDec)
{
return headGrpSize;
}
if (isXqaJit)
if (isXqaJit && (supportQGMMA || supportMLA)) // HMMA (mha.cu) goes to the heuristic below
{
return 64;
}
@ -34,7 +35,7 @@ uint32_t getKernelMTileSize(uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqL
return gemmM < 16 ? 16 : 32;
}
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit)
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit, int SM)
{
unsigned int head_size = xqaParams.head_size;
unsigned int num_q_heads = xqaParams.num_q_heads;
@ -43,14 +44,13 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
unsigned int beam_width = xqaParams.beam_width;
// Use mTileSize = 16 kernels when qSeqLen <= 16.
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
// MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
unsigned int kernel_m_tilesize
= getKernelMTileSize(num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit);
bool supportQGMMA = jit::supportConfigQGMMA(xqaParams, SM, true);
bool supportMLA = jit::supportConfigMLA(xqaParams, SM, true);
unsigned int kernel_m_tilesize = getKernelMTileSize(
num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit, supportQGMMA, supportMLA);
// precompiled XQA does not use is_fp8_output as hashing key
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "xqaParams.h"
@ -79,9 +80,10 @@ struct XQAKernelRuntimeHashKey
}
};
uint32_t getKernelMTileSize(uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit);
uint32_t getKernelMTileSize(
uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit, bool supportQGMMA, bool supportMLA);
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit);
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit, int SM);
struct XQAKernelRuntimeHasher
{

View File

@ -55,6 +55,14 @@ DecoderXQAImplJIT::DecoderXQAImplJIT(DecoderXQARunner* runner)
{
}
//! Returns true when speculative decoding (multi_query_tokens) must be served by the
//! Ampere-style HMMA kernel (mha.cu): the config is supported by the HMMA path but by
//! neither the Hopper QGMMA path nor the MLA path.
//! Per the header declaration in this commit, this also means the JIT impl needs to
//! compile 2 sets of kernels (tilesize = 16 and 32) for spec-dec.
//! NOTE(review): supportConfig{QGMMA,HMMA,MLA} semantics live in kernelUtils.h (not
//! visible here); mSM is presumably the member SM version — confirm against that header.
bool DecoderXQAImplJIT::needHMMASpecDec(XQAParams const& xqaParams, bool forConfigurePlugin) const
{
return xqaParams.multi_query_tokens && !jit::supportConfigQGMMA(xqaParams, mSM, forConfigurePlugin)
&& jit::supportConfigHMMA(xqaParams, mSM, forConfigurePlugin)
&& !jit::supportConfigMLA(xqaParams, mSM, forConfigurePlugin);
}
bool DecoderXQAImplJIT::supportConfig(XQAParams const& xqaParams, bool forConfigurePlugin) const
{
@ -129,7 +137,7 @@ jit::CubinObjKey DecoderXQAImplJIT::getCubinObjKeyFromXQAParams(XQAParams const&
loadKey.data_type = xqaParams.data_type;
loadKey.sm = mSM;
XQAKernelRuntimeHashKey runtimeKey = getRuntimeHashKeyFromXQAParams(xqaParams, true);
XQAKernelRuntimeHashKey runtimeKey = getRuntimeHashKeyFromXQAParams(xqaParams, true, mSM);
return {loadKey, runtimeKey};
}
@ -150,12 +158,16 @@ void DecoderXQAImplJIT::prepareForActualXQAParams(XQAParams const& xqaParams)
//! Pre-builds JIT cubins for every beam width from 1 up to the umbrella params'
//! beam_width, by delegating to prepareForActualXQAParams once per beam width.
void DecoderXQAImplJIT::prepare(XQAParams const& umbrellaXQAParams)
{
for (int beam_width = 1; beam_width <= umbrellaXQAParams.beam_width; ++beam_width)
{
// Work on a copy so the caller's umbrella params are not mutated.
XQAParams actualXQAParams = umbrellaXQAParams;
actualXQAParams.beam_width = beam_width;
prepareForActualXQAParams(actualXQAParams);
// HMMA spec-dec needs a second cubin variant: overriding generation_input_length
// presumably steers getKernelMTileSize's heuristic to the tileSize=32 kernel
// (TODO confirm against the gemmM heuristic in getKernelMTileSize).
if (needHMMASpecDec(umbrellaXQAParams, true))
{
actualXQAParams.generation_input_length = 16; // a WAR to generate tileSize=32 JIT cubin
prepareForActualXQAParams(actualXQAParams);
}
}
}
@ -209,10 +221,11 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key);
TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized());
bool const isSpecDec = xqaParams.multi_query_tokens;
bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED);
bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA);
TLLM_CHECK_WITH_INFO(
!isSpecDec || isGMMAKernel || (isMLAKernel && !xqaParams.spec_decoding_is_generation_length_variable),
TLLM_CHECK_WITH_INFO(!isSpecDec || isGMMAKernel || isHMMAKernel
|| (isMLAKernel && !xqaParams.spec_decoding_is_generation_length_variable),
"speculative decoding is available for GMMA/MLA kernel only in JIT path for now. For MLA, the input sequence "
"length must be uniform and draft tokens must be linear.");
TLLM_CHECK_DEBUG(isGMMAKernel == jit::supportConfigQGMMA(xqaParams, mSM, false));
@ -390,6 +403,59 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
dim3 const blockDim(128 * 3, 1, 1);
cubinObj->launch(dimGrid, blockDim, stream, kernelParams);
}
else if (isSpecDec && isHMMAKernel)
{
// MultiQueryTokens (generation_input_length > 1) need extra parameters (like qSeqLen, headGrpSize, and
// mask). Input parameters for MultiQueryTokens kernels.
unsigned int headGrpSize = num_q_heads_over_kv;
// Use mTileSize = 16 kernels when qSeqLen <= 16.
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize);
unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable ? // true for ReDrafter
xqaParams.spec_decoding_max_generation_length
: qSeqLen;
appendParam(&maxQSeqLen);
appendParam(&launchParams.num_k_heads);
appendParam(&headGrpSize);
appendParam(&launchParams.cu_seq_lens);
bool const allowSlidingWindow
= !(isSpecDec && xqaParams.is_spec_dec_tree); // sliding windows does not support spec dec with tree-based
// token, only chained tokens
if (allowSlidingWindow)
{
appendParam(&launchParams.slidingWindowSize);
}
appendParam(&launchParams.qScale);
appendParam(&launchParams.output);
if (isFp8Out && !needOutputCvt)
{
appendParam(&launchParams.rcpOutScale);
}
appendParam(&kernel_input_tokens);
appendParam(&xqaParams.spec_decoding_packed_mask);
appendParam(&xqaParams.attention_sinks);
appendParam(&launchParams.kvCacheParams);
if (xqaParams.beam_width > 1)
{
appendParam(&launchParams.beamSearchParams.value());
}
appendParam(&launchParams.batch_size);
appendParam(&launchParams.kv_scale_quant_orig);
appendParam(&launchParams.semaphores);
appendParam(&launchParams.scratch);
uint32_t multi_block = 1;
// if (xqaParams.multi_block_mode)
// {
// multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
// }
auto const gridDim = (dim3{multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp, xqaParams.batch_size});
dim3 const blockDim(128, 1, 2);
cubinObj->launch(gridDim, blockDim, stream, kernelParams);
}
else
{
appendParam(&launchParams.num_k_heads);

View File

@ -50,6 +50,9 @@ private:
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
std::shared_ptr<DecoderXQARunnerResource> mResource;
//! Whether DecoderXQAImplJIT needs to compile 2 sets (tilesize = 16, 32) kernels for spec-dec
bool needHMMASpecDec(XQAParams const& xqaParams, bool forConfigurePlugin) const;
//! Whether DecoderXQAImplJIT supports xqaParams.
bool supportConfig(XQAParams const& xqaParams, bool forConfigurePlugin) const;
//! Whether DecoderXQAImplJIT has perf gain over the default (non-XQA-optimized) implementation.

View File

@ -258,7 +258,7 @@ public:
invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParams, stream);
sync_check_cuda_error(stream);
XQAKernelRuntimeHashKey hash_key = getRuntimeHashKeyFromXQAParams(xqaParams, false);
XQAKernelRuntimeHashKey hash_key = getRuntimeHashKeyFromXQAParams(xqaParams, false, mSM);
auto const findIter = mFunctions.find(hash_key);
TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "XQAKernelFunc not found.");

View File

@ -75,28 +75,34 @@ constexpr inline T roundUp(T a, T b)
//! Selects the JIT or precompiled XQA implementation for the given params.
//! NOTE(review): this span is a commit-diff capture with the +/- markers stripped, so
//! BOTH the pre-change body (spec-dec check at the top, ending at the first early
//! returns) and the post-change body (spec-dec check nested under the env-var branch)
//! appear interleaved below. It is not compilable as-is; e.g. the statements after the
//! first `return mJITImpl.get();` in the else-branch are unreachable in this rendering.
//! Recover the real function from the repository before editing.
DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams, bool for_configure_plugin)
{
int const smVersion = tensorrt_llm::common::getSMVersion();
// --- pre-change (deleted) spec-dec routing: JIT only for Hopper XQA / SM120 MLA ---
if (xqaParams.multi_query_tokens)
{
auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
// Ampere XQA supports spec dec with pre-compiled cubins (may also work with JIT but not implemented yet)
// Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
// now.
bool const supportedByHopperXqa
= (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
bool const supportedBySm120Mla
= (smVersion == 120 && xqaParams.isMLA() && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
return (supportedByHopperXqa || supportedBySm120Mla) ? mJITImpl.get() : mPrecompiledImpl.get();
}
// Environment override (getEnvEnableXQAJIT) forces JIT on or off when set.
std::optional<bool> envEnableXQAJIT = tensorrt_llm::common::getEnvEnableXQAJIT();
if (envEnableXQAJIT.has_value())
{
return envEnableXQAJIT.value() ? mJITImpl.get() : mPrecompiledImpl.get();
}
else
{
// (deleted line from the old body; the new body below replaces it)
return mJITImpl.get();
// --- post-change (added) spec-dec routing: additionally allows JIT for the
// Ampere HMMA path when the head-group size divides 64 and it is not MLA ---
if (xqaParams.multi_query_tokens)
{
// Some multi_query kernels are not ported to JIT yet.
auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
// Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
// now.
bool const supportedByHopperXqa
= (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
&& xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));
return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
: mPrecompiledImpl.get();
}
else
{
// Regular (non-spec-dec) decoding kernels use JIT by default.
return mJITImpl.get();
}
}
}