Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] support JIT mha.cu for SPEC_DEC in runtime (#6078)
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
This commit is contained in:
parent e3c1a9409f
commit 220dc01372
@@ -505,7 +505,7 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
     bool const maskFlag = col + actualQSeqLen < nbValidCols
         ? true
         : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
-    acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY;
+    acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
     }
     }
     }

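The switch from -INFINITY to safeInitRowMax presumably guards the fully-masked-row case: if every score in a row is set to -INFINITY, the running row max is also -INFINITY, and (-inf) - (-inf) is NaN, which then propagates through the softmax. A minimal host-side sketch of that failure mode (illustrative only; safeInitRowMax below is a stand-in value, not the kernel's constant):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    // Stand-in for the kernel's finite row-max initializer; the real value lives in mha.cu.
    float const safeInitRowMax = -1e+30f;

    // A fully masked row initialized with -INFINITY: the running row max is also -INFINITY,
    // and (-inf) - (-inf) is NaN, so exp() returns NaN and poisons the accumulator.
    float const rowMaxInf = std::max(-INFINITY, -INFINITY);
    std::printf("masked with -INFINITY: exp = %f\n", std::exp(-INFINITY - rowMaxInf)); // nan

    // The same row masked with a large finite floor stays finite (exp(0) == 1), so later
    // stages never see NaN.
    float const rowMaxSafe = std::max(safeInitRowMax, safeInitRowMax);
    std::printf("masked with finite floor: exp = %f\n", std::exp(safeInitRowMax - rowMaxSafe)); // 1.0
    return 0;
}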
@@ -20,13 +20,14 @@
 namespace tensorrt_llm::kernels
 {

-uint32_t getKernelMTileSize(uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit)
+uint32_t getKernelMTileSize(
+    uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit, bool supportQGMMA, bool supportMLA)
 {
     if (!isSpecDec)
     {
         return headGrpSize;
     }
-    if (isXqaJit)
+    if (isXqaJit && (supportQGMMA || supportMLA)) // HMMA (mha.cu) goes to the heuristic below
     {
         return 64;
     }

@@ -34,7 +35,7 @@ uint32_t getKernelMTileSize(uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqL
     return gemmM < 16 ? 16 : 32;
 }

-XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit)
+XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit, int SM)
 {
     unsigned int head_size = xqaParams.head_size;
     unsigned int num_q_heads = xqaParams.num_q_heads;

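For reference, a minimal standalone sketch of the tile-size selection after this change. It is illustrative only: the function is a local copy, not the TensorRT-LLM one, and the gemmM expression is an assumed stand-in for the M extent computed between the two hunks above (that line is not part of this diff).

#include <cstdint>
#include <cstdio>

// Local copy of the selection logic shown above. For spec-dec, QGMMA/MLA JIT kernels use a fixed
// 64-row M tile, while the HMMA (mha.cu) path falls through to the 16-vs-32 heuristic.
static uint32_t kernelMTileSizeSketch(
    uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit, bool supportQGMMA, bool supportMLA)
{
    if (!isSpecDec)
    {
        return headGrpSize;
    }
    if (isXqaJit && (supportQGMMA || supportMLA))
    {
        return 64;
    }
    uint32_t const gemmM = qSeqLen * headGrpSize; // assumed M extent; the exact expression is outside this diff
    return gemmM < 16 ? 16 : 32;
}

int main()
{
    // Hopper-style JIT with QGMMA support: fixed 64. HMMA-only JIT: heuristic picks 32 here.
    std::printf("%u %u\n",
        kernelMTileSizeSketch(8, true, 4, true, /*supportQGMMA=*/true, /*supportMLA=*/false),
        kernelMTileSizeSketch(8, true, 4, true, /*supportQGMMA=*/false, /*supportMLA=*/false));
    return 0;
}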
@@ -43,14 +44,13 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
     unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
     unsigned int beam_width = xqaParams.beam_width;

-    // Use mTileSize = 16 kernels when qSeqLen <= 16.
     unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
-    unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
     // MultiQueryToken kernels can support any num_q_heads_over_kv that is power of 2.
     unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
-    // MultiQueryToken kernels can handle either 16/32 for M direction per CTA.
-    unsigned int kernel_m_tilesize
-        = getKernelMTileSize(num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit);
+    bool supportQGMMA = jit::supportConfigQGMMA(xqaParams, SM, true);
+    bool supportMLA = jit::supportConfigMLA(xqaParams, SM, true);
+    unsigned int kernel_m_tilesize = getKernelMTileSize(
+        num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit, supportQGMMA, supportMLA);

     // precompiled XQA does not use is_fp8_output as hashing key
     return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,

@@ -22,6 +22,7 @@
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/workspace.h"
+#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.h"
 #include "tensorrt_llm/kernels/kvCacheUtils.h"
 #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
 #include "xqaParams.h"

@@ -79,9 +80,10 @@ struct XQAKernelRuntimeHashKey
     }
 };

-uint32_t getKernelMTileSize(uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit);
+uint32_t getKernelMTileSize(
+    uint32_t headGrpSize, bool isSpecDec, uint32_t qSeqLen, bool isXqaJit, bool supportQGMMA, bool supportMLA);

-XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit);
+XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams, bool isXqaJit, int SM);

 struct XQAKernelRuntimeHasher
 {

@@ -55,6 +55,14 @@ DecoderXQAImplJIT::DecoderXQAImplJIT(DecoderXQARunner* runner)
 {
 }

+bool DecoderXQAImplJIT::needHMMASpecDec(XQAParams const& xqaParams, bool forConfigurePlugin) const
+{
+    return xqaParams.multi_query_tokens && !jit::supportConfigQGMMA(xqaParams, mSM, forConfigurePlugin)
+        && jit::supportConfigHMMA(xqaParams, mSM, forConfigurePlugin)
+        && !jit::supportConfigMLA(xqaParams, mSM, forConfigurePlugin);
+}
+
 bool DecoderXQAImplJIT::supportConfig(XQAParams const& xqaParams, bool forConfigurePlugin) const
 {

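The new predicate reads as a fallback rule: a spec-dec request takes the HMMA (mha.cu) JIT kernel only when neither the QGMMA nor the MLA JIT path can serve it but the HMMA one can. A standalone restatement of that rule (the flags are stand-ins for the jit::supportConfig* checks):

#include <cstdio>

static bool needHMMASpecDecSketch(bool multiQueryTokens, bool supportsQGMMA, bool supportsHMMA, bool supportsMLA)
{
    return multiQueryTokens && !supportsQGMMA && supportsHMMA && !supportsMLA;
}

int main()
{
    // E.g. a non-E4M3 KV cache on a pre-Hopper SM: no QGMMA and no SM120 MLA, so HMMA handles spec-dec.
    std::printf("%d\n", needHMMASpecDecSketch(true, false, true, false)); // 1
    // When the QGMMA path is available it wins and the HMMA fallback is not needed.
    std::printf("%d\n", needHMMASpecDecSketch(true, true, true, false));  // 0
    return 0;
}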
@@ -129,7 +137,7 @@ jit::CubinObjKey DecoderXQAImplJIT::getCubinObjKeyFromXQAParams(XQAParams const&
     loadKey.data_type = xqaParams.data_type;
     loadKey.sm = mSM;

-    XQAKernelRuntimeHashKey runtimeKey = getRuntimeHashKeyFromXQAParams(xqaParams, true);
+    XQAKernelRuntimeHashKey runtimeKey = getRuntimeHashKeyFromXQAParams(xqaParams, true, mSM);
     return {loadKey, runtimeKey};
 }

@@ -150,12 +158,16 @@ void DecoderXQAImplJIT::prepareForActualXQAParams(XQAParams const& xqaParams)

 void DecoderXQAImplJIT::prepare(XQAParams const& umbrellaXQAParams)
 {
     for (int beam_width = 1; beam_width <= umbrellaXQAParams.beam_width; ++beam_width)
     {
         XQAParams actualXQAParams = umbrellaXQAParams;
         actualXQAParams.beam_width = beam_width;
         prepareForActualXQAParams(actualXQAParams);
+        if (needHMMASpecDec(umbrellaXQAParams, true))
+        {
+            actualXQAParams.generation_input_length = 16; // a WAR to generate tileSize=32 JIT cubin
+            prepareForActualXQAParams(actualXQAParams);
+        }
     }
 }

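The extra prepare pass presumably exists so that both M-tile variants of the HMMA spec-dec cubin are compiled before generation starts, whatever draft length shows up at run time (the header change later in this commit describes it as compiling "2 sets (tilesize = 16, 32)" of kernels). A standalone sketch of that compile-both-up-front, look-up-at-launch pattern (the registry type and tile-size values are stand-ins; the real cache is keyed by XQAKernelRuntimeHashKey):

#include <cstdint>
#include <cstdio>
#include <set>

// Stand-in for the cubin registry, keyed here only by the M-tile size.
struct CubinRegistrySketch
{
    std::set<uint32_t> compiled;

    void prepare(uint32_t mTileSize) // JIT-compile once per distinct tile size
    {
        if (compiled.insert(mTileSize).second)
        {
            std::printf("compiling spec-dec cubin, mTileSize=%u\n", mTileSize);
        }
    }

    bool canLaunch(uint32_t mTileSize) const // looked up at launch time
    {
        return compiled.count(mTileSize) != 0;
    }
};

int main()
{
    CubinRegistrySketch registry;
    registry.prepare(16); // cubin for the configured generation_input_length
    registry.prepare(32); // WAR pass: also build the other tile size for longer draft lengths

    std::printf("tileSize 32 ready: %d\n", registry.canLaunch(32)); // 1
    return 0;
}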
@@ -209,10 +221,11 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
     jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key);
     TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized());
     bool const isSpecDec = xqaParams.multi_query_tokens;
+    bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED);
     bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
     bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA);
-    TLLM_CHECK_WITH_INFO(
-        !isSpecDec || isGMMAKernel || (isMLAKernel && !xqaParams.spec_decoding_is_generation_length_variable),
+    TLLM_CHECK_WITH_INFO(!isSpecDec || isGMMAKernel || isHMMAKernel
+            || (isMLAKernel && !xqaParams.spec_decoding_is_generation_length_variable),
         "speculative decoding is available for GMMA/MLA kernel only in JIT path for now. For MLA, the input sequence "
         "length must be uniform and draft tokens must be linear.");
     TLLM_CHECK_DEBUG(isGMMAKernel == jit::supportConfigQGMMA(xqaParams, mSM, false));

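In effect the run-time check now admits three kernel families for speculative decoding instead of two. A standalone restatement of the admissibility rule (the enum and helper are local stand-ins for XQAKernelType and the TLLM_CHECK_WITH_INFO condition):

#include <cstdio>

enum class KernelKind
{
    kAmpereHMMA, // mha.cu, newly admitted for spec-dec by this change
    kHopperGMMA,
    kSm120MLA
};

static bool specDecAdmissible(KernelKind kind, bool generationLengthVariable)
{
    switch (kind)
    {
    case KernelKind::kHopperGMMA: return true;
    case KernelKind::kAmpereHMMA: return true;
    case KernelKind::kSm120MLA: return !generationLengthVariable; // MLA needs uniform input lengths
    }
    return false;
}

int main()
{
    std::printf("HMMA, variable draft length: %d\n", specDecAdmissible(KernelKind::kAmpereHMMA, true)); // 1
    std::printf("MLA,  variable draft length: %d\n", specDecAdmissible(KernelKind::kSm120MLA, true));   // 0
    return 0;
}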
@@ -390,6 +403,59 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
         dim3 const blockDim(128 * 3, 1, 1);
         cubinObj->launch(dimGrid, blockDim, stream, kernelParams);
     }
+    else if (isSpecDec && isHMMAKernel)
+    {
+        // MultiQueryTokens (generation_input_length > 1) need extra parameters (like qSeqLen, headGrpSize, and
+        // mask). Input parameters for MultiQueryTokens kernels.
+        unsigned int headGrpSize = num_q_heads_over_kv;
+        // Use mTileSize = 16 kernels when qSeqLen <= 16.
+        unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
+        unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
+        unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize);
+        unsigned int maxQSeqLen = xqaParams.spec_decoding_is_generation_length_variable ? // true for ReDrafter
+            xqaParams.spec_decoding_max_generation_length
+            : qSeqLen;
+
+        appendParam(&maxQSeqLen);
+        appendParam(&launchParams.num_k_heads);
+        appendParam(&headGrpSize);
+        appendParam(&launchParams.cu_seq_lens);
+        bool const allowSlidingWindow
+            = !(isSpecDec && xqaParams.is_spec_dec_tree); // sliding windows does not support spec dec with tree-based
+                                                          // token, only chained tokens
+        if (allowSlidingWindow)
+        {
+            appendParam(&launchParams.slidingWindowSize);
+        }
+        appendParam(&launchParams.qScale);
+        appendParam(&launchParams.output);
+        if (isFp8Out && !needOutputCvt)
+        {
+            appendParam(&launchParams.rcpOutScale);
+        }
+        appendParam(&kernel_input_tokens);
+        appendParam(&xqaParams.spec_decoding_packed_mask);
+        appendParam(&xqaParams.attention_sinks);
+        appendParam(&launchParams.kvCacheParams);
+        if (xqaParams.beam_width > 1)
+        {
+            appendParam(&launchParams.beamSearchParams.value());
+        }
+        appendParam(&launchParams.batch_size);
+        appendParam(&launchParams.kv_scale_quant_orig);
+        appendParam(&launchParams.semaphores);
+        appendParam(&launchParams.scratch);
+
+        uint32_t multi_block = 1;
+        // if (xqaParams.multi_block_mode)
+        // {
+        //     multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
+        // }
+        auto const gridDim = (dim3{multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp, xqaParams.batch_size});
+        dim3 const blockDim(128, 1, 2);
+
+        cubinObj->launch(gridDim, blockDim, stream, kernelParams);
+    }
     else
     {
         appendParam(&launchParams.num_k_heads);

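The launch geometry of the new HMMA spec-dec branch can be sanity-checked in isolation: each CTA covers one mTileSize-row slice of the qSeqLen x headGrpSize query block, replicated per KV head and per batch entry. A standalone sketch of that arithmetic (divUp and the example sizes are assumptions, not values from the diff):

#include <cstdint>
#include <cstdio>

static uint32_t divUp(uint32_t a, uint32_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    // Example sizes (assumed): 4 draft tokens per step, 8 query heads per KV head, 8 KV heads, batch 2.
    uint32_t const qSeqLen = 4, headGrpSize = 8, numKvHeads = 8, batchSize = 2;

    uint32_t const mTileSize = qSeqLen <= 16 ? 16u : 32u;                         // 16
    uint32_t const nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, mTileSize); // ceil(32/16) = 2
    uint32_t const multiBlock = 1;                                                // multi-block mode is disabled above

    // gridDim = {multi_block, num_kv_heads * nbTokenBlocksPerGrp, batch_size}, blockDim = (128, 1, 2).
    std::printf("gridDim = {%u, %u, %u}, CTAs = %u\n", multiBlock, numKvHeads * nbTokenBlocksPerGrp, batchSize,
        multiBlock * numKvHeads * nbTokenBlocksPerGrp * batchSize);               // {1, 16, 2}, 32 CTAs
    return 0;
}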
@@ -50,6 +50,9 @@ private:
     std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
     std::shared_ptr<DecoderXQARunnerResource> mResource;

+    //! Whether DecoderXQAImplJIT needs to compile 2 sets (tilesize = 16, 32) kernels for spec-dec
+    bool needHMMASpecDec(XQAParams const& xqaParams, bool forConfigurePlugin) const;
+
     //! Whether DecoderXQAImplJIT supports xqaParams.
     bool supportConfig(XQAParams const& xqaParams, bool forConfigurePlugin) const;
     //! Whether DecoderXQAImplJIT has perf gain over the default (non-XQA-optimized) implementation.

@@ -258,7 +258,7 @@
     invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParams, stream);
     sync_check_cuda_error(stream);

-    XQAKernelRuntimeHashKey hash_key = getRuntimeHashKeyFromXQAParams(xqaParams, false);
+    XQAKernelRuntimeHashKey hash_key = getRuntimeHashKeyFromXQAParams(xqaParams, false, mSM);
     auto const findIter = mFunctions.find(hash_key);

     TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "XQAKernelFunc not found.");

@@ -75,28 +75,34 @@ constexpr inline T roundUp(T a, T b)
 DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams, bool for_configure_plugin)
 {
     int const smVersion = tensorrt_llm::common::getSMVersion();
-    if (xqaParams.multi_query_tokens)
-    {
-        auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
-        // Ampere XQA supports spec dec with pre-compiled cubins (may also work with JIT but not implemented yet)
-        // Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
-        // now.
-        bool const supportedByHopperXqa
-            = (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
-        bool const supportedBySm120Mla
-            = (smVersion == 120 && xqaParams.isMLA() && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
-        return (supportedByHopperXqa || supportedBySm120Mla) ? mJITImpl.get() : mPrecompiledImpl.get();
-    }

     std::optional<bool> envEnableXQAJIT = tensorrt_llm::common::getEnvEnableXQAJIT();

     if (envEnableXQAJIT.has_value())
     {
         return envEnableXQAJIT.value() ? mJITImpl.get() : mPrecompiledImpl.get();
     }
     else
     {
-        return mJITImpl.get();
+        if (xqaParams.multi_query_tokens)
+        {
+            // Some multi_query kernels are not ported to JIT yet.
+            auto const grpSize = xqaParams.num_q_heads / xqaParams.num_kv_heads;
+            // Hopper XQA supports spec dec with JIT, but only for E4M3 kv cache data type. Only allow 64%grpSize==0 for
+            // now.
+            bool const supportedByHopperXqa
+                = (smVersion == 90 && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3 && grpSize <= 64);
+            bool const supportedBySm120Mla = (smVersion == 120 && xqaParams.isMLA()
+                && xqaParams.kv_cache_data_type == XQADataType::DATA_TYPE_E4M3);
+            bool const supportedByAmpereXqa = (!xqaParams.isMLA() && (64 % grpSize == 0));
+
+            return (supportedByHopperXqa || supportedBySm120Mla || supportedByAmpereXqa) ? mJITImpl.get()
+                                                                                         : mPrecompiledImpl.get();
+        }
+        else
+        {
+            // regular decoding kernels uses JIT by default
+            return mJITImpl.get();
+        }
     }
 }

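A standalone restatement of the implementation-selection logic after this change (the enum, struct, and field names below are local stand-ins for DecoderXQAImpl, XQAParams, and the env override; the SM and group-size conditions come from the diff):

#include <cstdio>
#include <optional>

// Which implementation the runner would pick: JIT-compiled or precompiled cubins.
enum class Impl { kJIT, kPrecompiled };

struct Params // minimal stand-in for the XQAParams fields used above
{
    bool multiQueryTokens; // spec-dec request
    bool isMLA;
    bool kvCacheIsE4M3;
    unsigned grpSize;      // num_q_heads / num_kv_heads
};

static Impl pickImpl(Params const& p, int smVersion, std::optional<bool> envEnableXQAJIT)
{
    if (envEnableXQAJIT.has_value())
    {
        return *envEnableXQAJIT ? Impl::kJIT : Impl::kPrecompiled;
    }
    if (p.multiQueryTokens)
    {
        bool const hopper = (smVersion == 90 && p.kvCacheIsE4M3 && p.grpSize <= 64);
        bool const sm120Mla = (smVersion == 120 && p.isMLA && p.kvCacheIsE4M3);
        bool const ampere = (!p.isMLA && (64 % p.grpSize == 0)); // newly JIT-able via mha.cu
        return (hopper || sm120Mla || ampere) ? Impl::kJIT : Impl::kPrecompiled;
    }
    return Impl::kJIT; // regular decoding uses JIT by default
}

int main()
{
    // Spec-dec on SM86 with a non-E4M3 KV cache and 8 query heads per KV head: now served by the JIT HMMA path.
    Params const p{true, false, false, 8};
    std::printf("%s\n", pickImpl(p, 86, std::nullopt) == Impl::kJIT ? "JIT" : "precompiled");
    return 0;
}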