/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fmhaDispatcher.h"

#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

////////////////////////////////////////////////////////////////////////////////////////////////////

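// Maps the MHA runner's AttentionInputLayout to the QkvLayout enum expected by the TRTLLM-GEN kernels.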
QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)
{
    if (layout == AttentionInputLayout::PACKED_QKV)
    {
        return QkvLayout::PackedQkv;
    }
    else if (layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
    {
        return QkvLayout::ContiguousKv;
    }
    else if (layout == AttentionInputLayout::Q_PAGED_KV)
    {
        return QkvLayout::PagedKv;
    }
    TLLM_CHECK_WITH_INFO(false, "Unexpected AttentionInputLayout");
    return QkvLayout::SeparateQkv;
}

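// The dispatcher selects the TRTLLM-GEN runner on SM 100 and falls back to FusedMHARunnerV2 on all
// other architectures.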
FmhaDispatcher::FmhaDispatcher(MHARunnerFixedParams fixedParams)
    : mFixedParams(fixedParams)
    , mUseTllmGen(tensorrt_llm::common::getSMVersion() == 100)
{
    if (mUseTllmGen)
    {
        mTllmGenFMHARunner.reset(
            new TllmGenFmhaRunner(mFixedParams.dataType, mFixedParams.dataTypeKv, mFixedParams.dataTypeOut));
        if (!isSupported())
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support the requested kernels.");
        }
    }
    else
    {
        TLLM_CHECK_WITH_INFO(mFixedParams.dataType == mFixedParams.dataTypeKv,
            "KV cache data type should be the same as input data type.");

        // For FP8 MLA generation, the output type is BF16, which may differ from the input type,
        // so this check is no longer performed.
        // TLLM_CHECK_WITH_INFO(mFixedParams.dataType == mFixedParams.dataTypeOut,
        //     "Output data type should be the same as input data type.");

        mFMHARunner.reset(new FusedMHARunnerV2(fixedParams));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

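// Returns true if a fused context FMHA kernel is available for the configured fixed parameters;
// otherwise the caller falls back to the unfused MHA path.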
bool FmhaDispatcher::isSupported()
{
    bool foundKernels = false;
    if (mUseTllmGen)
    {
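        // Reject configurations that the TRTLLM-GEN kernels do not support.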
        if (mFixedParams.attentionMaskType == ContextAttentionMaskType::CUSTOM_MASK)
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support custom mask.");
            return false;
        }
        if (mFixedParams.hasAlibi)
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support ALiBi.");
            return false;
        }
        if (mFixedParams.isSPadded)
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support padded inputs.");
            return false;
        }

        auto qkvLayout = AttentionInputLayoutToQkvLayout(mFixedParams.attentionInputLayout);
        // Create TllmGenFmhaRunnerParams based on MHARunnerFixedParams. Only fill necessary
        // attributes for kernel selection.
        TllmGenFmhaRunnerParams tllmRunnerParams;
        memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
        tllmRunnerParams.mQkvLayout = qkvLayout;
        tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
        tllmRunnerParams.mKernelType = FmhaKernelType::Context;
        tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
        tllmRunnerParams.mMultiCtasKvMode = false;
        // Assume same headDim for Qk and V here.
        tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
        tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
        tllmRunnerParams.mNumTokensPerPage = mFixedParams.numTokensPerBlock;
        tllmRunnerParams.mNumHeadsQPerKv = mFixedParams.numQHeads / mFixedParams.numKvHeads;
        // Set the chunked attention size and sliding window size to INT_MAX to disable them when
        // checking if the kernel is supported.
        tllmRunnerParams.mChunkedAttentionSize = INT_MAX;
        tllmRunnerParams.mAttentionWindowSize = INT_MAX;

        foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
    }
    else
    {
        foundKernels = mFMHARunner->isFmhaSupported();
    }
    if (!foundKernels)
    {
        TLLM_LOG_WARNING("Fall back to unfused MHA for %s in sm_%d.", mFixedParams.convertToStrOutput().c_str(),
            tensorrt_llm::common::getSMVersion());
    }
    return foundKernels;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

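// Launches the context FMHA kernel, dispatching either to the TRTLLM-GEN runner or to FusedMHARunnerV2.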
void FmhaDispatcher::run(MHARunnerParams runnerParams)
{
    if (mUseTllmGen)
    {
        TLLM_LOG_DEBUG("Running TRTLLM-GEN context FMHA kernel.");
        TLLM_CHECK_WITH_INFO(mTllmGenFMHARunner.get(), "mTllmGenFMHARunner not initialized.");
        // Convert from MHARunnerFixedParams + MHARunnerParams to TllmGenFmhaRunnerParams.
        void const* kvPoolPtr = nullptr;
        void const* kvPageIdxPtr = nullptr;
        auto qkvLayout = kernels::QkvLayout::PackedQkv;
        int32_t maxBlocksPerSeq = 0;
        int32_t numTokensPerBlock = 0;
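        // For the paged-KV layout, pull the pool pointer, page-index buffer, and paging geometry
        // out of the KV block array.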
        if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
        {
            qkvLayout = kernels::QkvLayout::PagedKv;
            auto pagedKvCache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA();
            kvPoolPtr = pagedKvCache.mPrimaryPoolPtr;
            kvPageIdxPtr = reinterpret_cast<int const*>(pagedKvCache.data);
            maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
            numTokensPerBlock = pagedKvCache.mTokensPerBlock;
        }

        TllmGenFmhaRunnerParams tllmRunnerParams;
        memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));

        // Parameters to select kernels.
        tllmRunnerParams.mQkvLayout = qkvLayout;
        tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
        tllmRunnerParams.mKernelType = FmhaKernelType::Context;
        // Always use persistent scheduler for better performance.
        tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
        tllmRunnerParams.mMultiCtasKvMode = false;

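        // Device buffer pointers for the Q/KV inputs, scales, and outputs.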
        tllmRunnerParams.qPtr = runnerParams.qPtr;
        tllmRunnerParams.kPtr = nullptr;
        tllmRunnerParams.vPtr = nullptr;
        tllmRunnerParams.kvPtr = kvPoolPtr;
        tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
        tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
        tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);
        tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(runnerParams.scaleBmm2Ptr);
        // TRTLLM-GEN kernels always use the Log2 scale.
        tllmRunnerParams.scaleSoftmaxLog2Ptr
            = reinterpret_cast<float const*>(runnerParams.scaleBmm1Ptr + kIdxScaleSoftmaxLog2Ptr);
        tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<int const*>(kvPageIdxPtr);
        tllmRunnerParams.oSfScalePtr = runnerParams.oSfScalePtr;
        tllmRunnerParams.oPtr = runnerParams.outputPtr;
        tllmRunnerParams.oSfPtr = runnerParams.outputSfPtr;
        // The sequence lengths for K/V.
        tllmRunnerParams.seqLensKvPtr = reinterpret_cast<int const*>(runnerParams.kvSeqLenPtr);
        // Assume same headDim for Qk and V here.
        tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
        tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
        tllmRunnerParams.mNumHeadsQ = mFixedParams.numQHeads;
        tllmRunnerParams.mNumHeadsKv = mFixedParams.numKvHeads;
        tllmRunnerParams.mNumHeadsQPerKv = tllmRunnerParams.mNumHeadsQ / tllmRunnerParams.mNumHeadsKv;
        tllmRunnerParams.mBatchSize = runnerParams.b;
        // It is used to construct contiguous kv cache TMA descriptors.
        tllmRunnerParams.mMaxSeqLenCacheKv = runnerParams.slidingWindowSize;
        tllmRunnerParams.mMaxSeqLenQ = runnerParams.qSeqLen;
        tllmRunnerParams.mMaxSeqLenKv = runnerParams.kvSeqLen;
        tllmRunnerParams.mAttentionWindowSize = runnerParams.slidingWindowSize;
        tllmRunnerParams.mChunkedAttentionSize = runnerParams.chunkedAttentionSize;
        tllmRunnerParams.mSumOfSeqLensQ = runnerParams.totalQSeqLen;
        tllmRunnerParams.mSumOfSeqLensKv = runnerParams.totalKvSeqLen;
        tllmRunnerParams.mMaxNumPagesPerSeqKv = maxBlocksPerSeq;
        tllmRunnerParams.mNumTokensPerPage = numTokensPerBlock;
        tllmRunnerParams.mScaleQ = mFixedParams.qScaling;
        // Set it to INT_MAX as the kv cache pageOffsets will ensure that there is no out-of-bounds access.
        tllmRunnerParams.mNumPagesInMemPool = INT_MAX;
        tllmRunnerParams.mSfStartTokenIdx = 0;
        tllmRunnerParams.stream = runnerParams.stream;
        mTllmGenFMHARunner->run(tllmRunnerParams);
    }
    else
    {
        mFMHARunner->run(runnerParams);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace tensorrt_llm::kernels