TensorRT-LLMs/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
Bo Li 515dd0d78f
feat: Add support for FP8 MLA on Hopper and Blackwell. (#3190)
* FP8 KV cache + BF16 context MLA + FP8 generation MLA

Use BF16 for context MLA.
mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.

Allow mSM==90 for mFP8GenerationMLA==true.
For FMHA, dataTypeKv should be FP8.

For FP8 MLA generation, the output is still in BF16.
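
As a rough sketch of the intended type configuration (the enum value names and the assertion below are illustrative, not the exact attentionOp code):

    // Context MLA runs in BF16; FP8 generation MLA reads an FP8 KV cache but still writes BF16 output.
    MHARunnerFixedParams fixedParams{};
    fixedParams.dataType = DATA_TYPE_E4M3;    // FP8 input for generation MLA (assumed enum name)
    fixedParams.dataTypeKv = DATA_TYPE_E4M3;  // KV cache stays FP8
    fixedParams.dataTypeOut = DATA_TYPE_BF16; // output remains BF16
    // The two FP8 paths are mutually exclusive.
    TLLM_CHECK_WITH_INFO(!(mFP8GenerationMLA && mFP8ContextFMHA),
        "mFP8GenerationMLA and mFP8ContextFMHA shouldn't be enabled together.");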

Refine debug info for FMHA kernel metadata.

Use inputType, outputType, and SM together to hash the kernel list.
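
A possible shape of such a composite key (a sketch only; the actual packing in fused_multihead_attention_v2 may differ):

    // Sketch: include the output type and SM in the lookup key so that FP8-in/BF16-out
    // MLA kernels do not collide with FP8-in/FP8-out kernels on the same SM.
    uint64_t kernelKey(uint32_t inputType, uint32_t outputType, uint32_t sm)
    {
        return (static_cast<uint64_t>(sm) << 32) | (static_cast<uint64_t>(outputType) << 16)
            | static_cast<uint64_t>(inputType);
    }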

Add FP8 MLA generation FMHA kernel.

Special workaround (WAR) for NUM_COMPUTE_GROUPS in the MLA generation kernel.

Split the implementation of fused_multihead_attention_v2.h into a .cpp file, and print some debug info if checkIfKernelExist fails.

Refine debug info in fused_multihead_attention_v2.cpp

Correct FP8 MLA metadata.

New kernel provided by Yuxin, which outputs BF16.

The smem size was not set correctly, which led to illegal memory access.

Yuxin fixed the error in the FMHA MLA kernel: previously the BF16 output wasn't written correctly; some parts were written repeatedly while others were left untouched.

There are two bmm1 scales that should be set correctly.

New kernel generated by Yuxin.

Modifications to common/attentionOp for FP8 MLA on Hopper using FMHA.

Not necessary: if mFP8GenerationMLA is set, is_fp8_out is false, so mFP8ContextFMHA is also false.

Skip a check in fmhaDispatcher.

Modifications in fmhaRunner:
- Debug dump.
- Skip most of the flag setting when isFP8GenerationMLA is true (guarded by if (!isFP8GenerationMLA)); see the sketch after this list.
- TMA descriptor modification for qo (by Yuxin).
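
A minimal sketch of that guard (the helper and flag names are placeholders, not the actual fmhaRunner members):

    // Placeholder sketch: the FP8 generation MLA path skips most of the per-run flag setup.
    void setupRunFlags(bool isFP8GenerationMLA)
    {
        if (!isFP8GenerationMLA)
        {
            // Legacy path: configure scale, mask, and related flags as before.
        }
        // FP8 generation MLA only needs the q/o TMA descriptor adjustments.
    }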

Cleanup debug output.

Clean up the o TMA descriptor modifications.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Resolve conflicts.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Apply the patch of FP8 FlashMLA and resolve conflicts.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Fix compilation error.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Fix compile error.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Pick Blackwell support.

Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>

* Add copyright notice to fused_multihead_attention_v2.cpp.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Add license.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Add missing license.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Exclude building flashMLA kernels under sm90.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Revert "Exclude building flashMLA kernels under sm90."

    This reverts commit f0c859d459.

Signed-off-by: Bo Li <bobboli0202@gmail.com>

* Use a macro to skip compiling FlashMLA for non-sm90 targets (see the sketch below).

Signed-off-by: Bo Li <bobboli0202@gmail.com>
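
A minimal sketch of that pattern (the guard below is illustrative of the approach, not the exact macro used for the FlashMLA sources):

    // Illustrative pattern: only emit FlashMLA device code when compiling for sm90.
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)
    // FlashMLA kernel definitions go here.
    #endif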

---------

Signed-off-by: Bo Li <bobboli0202@gmail.com>
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: Dylan Chen <ziqingc@nvidia.com>
Co-authored-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
2025-04-07 15:14:13 +08:00


/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "fmhaDispatcher.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
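// Maps the dispatcher's AttentionInputLayout to the QkvLayout enum used by the TRTLLM-GEN runner parameters.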
QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)
{
    if (layout == AttentionInputLayout::PACKED_QKV)
    {
        return QkvLayout::PackedQkv;
    }
    else if (layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
    {
        return QkvLayout::ContiguousKv;
    }
    else if (layout == AttentionInputLayout::Q_PAGED_KV)
    {
        return QkvLayout::PagedKv;
    }
    TLLM_CHECK_WITH_INFO(false, "Unexpected AttentionInputLayout");
    return QkvLayout::SeparateQkv;
}

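// The TRTLLM-GEN path is only taken on SM 100 (Blackwell); other architectures use FusedMHARunnerV2.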
FmhaDispatcher::FmhaDispatcher(MHARunnerFixedParams fixedParams)
    : mFixedParams(fixedParams)
    , mUseTllmGen(tensorrt_llm::common::getSMVersion() == 100)
{
    if (mUseTllmGen)
    {
        mTllmGenFMHARunner.reset(
            new TllmGenFmhaRunner(mFixedParams.dataType, mFixedParams.dataTypeKv, mFixedParams.dataTypeOut));
        if (!isSupported())
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support the requested kernels.");
        }
    }
    else
    {
        TLLM_CHECK_WITH_INFO(mFixedParams.dataType == mFixedParams.dataTypeKv,
            "KV cache data type should be the same as input data type.");
        // For FP8 MLA generation, the output type is BF16, which could be different from the input type.
        // So we shouldn't do this check anymore.
        // TLLM_CHECK_WITH_INFO(mFixedParams.dataType == mFixedParams.dataTypeOut,
        //     "Output data type should be the same as input data type.");
        mFMHARunner.reset(new FusedMHARunnerV2(fixedParams));
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
bool FmhaDispatcher::isSupported()
{
    bool foundKernels = false;
    if (mUseTllmGen)
    {
        if (mFixedParams.attentionMaskType == ContextAttentionMaskType::CUSTOM_MASK)
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support custom mask.");
            return false;
        }
        if (mFixedParams.hasAlibi)
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support ALiBi.");
            return false;
        }
        if (mFixedParams.isSPadded)
        {
            TLLM_LOG_WARNING("TRTLLM-GEN does not support padded inputs.");
            return false;
        }
        auto qkvLayout = AttentionInputLayoutToQkvLayout(mFixedParams.attentionInputLayout);
        // Create TllmGenFmhaRunnerParams based on MHARunnerFixedParams. Only fill necessary
        // attributes for kernel selection.
        TllmGenFmhaRunnerParams tllmRunnerParams;
        memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
        tllmRunnerParams.mQkvLayout = qkvLayout;
        tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
        tllmRunnerParams.mKernelType = FmhaKernelType::Context;
        tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
        tllmRunnerParams.mMultiCtasKvMode = false;
        // Assume same headDim for Qk and V here.
        tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
        tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
        tllmRunnerParams.mNumTokensPerPage = mFixedParams.numTokensPerBlock;
        tllmRunnerParams.mNumHeadsQPerKv = mFixedParams.numQHeads / mFixedParams.numKvHeads;
        foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
    }
    else
    {
        foundKernels = mFMHARunner->isFmhaSupported();
    }
    if (!foundKernels)
    {
        TLLM_LOG_WARNING("Fall back to unfused MHA for %s in sm_%d.", mFixedParams.convertToStrOutput().c_str(),
            tensorrt_llm::common::getSMVersion());
    }
    return foundKernels;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void FmhaDispatcher::run(MHARunnerParams runnerParams)
{
    if (mUseTllmGen)
    {
        TLLM_LOG_DEBUG("Running TRTLLM-GEN context FMHA kernel.");
        // Convert from MHAFixedParams + MHARunnerParams to TllmGenFmhaRunnerParams
        void const* kvPoolPtr = nullptr;
        void const* kvPageIdxPtr = nullptr;
        auto qkvLayout = kernels::QkvLayout::PackedQkv;
        int32_t maxBlocksPerSeq = 0;
        int32_t numTokensPerBlock = 0;
        if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
        {
            qkvLayout = kernels::QkvLayout::PagedKv;
            auto pagedKvCache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA();
            kvPoolPtr = pagedKvCache.mPrimaryPoolPtr;
            kvPageIdxPtr = reinterpret_cast<int const*>(pagedKvCache.data);
            maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
            numTokensPerBlock = pagedKvCache.mTokensPerBlock;
        }
        TllmGenFmhaRunnerParams tllmRunnerParams;
        memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
        // Parameters to select kernels.
        tllmRunnerParams.mQkvLayout = qkvLayout;
        tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
        tllmRunnerParams.mKernelType = FmhaKernelType::Context;
        // Always use persistent scheduler for better performance.
        tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
        tllmRunnerParams.mMultiCtasKvMode = false;
        tllmRunnerParams.qPtr = runnerParams.qPtr;
        tllmRunnerParams.kPtr = nullptr;
        tllmRunnerParams.vPtr = nullptr;
        tllmRunnerParams.kvPtr = kvPoolPtr;
        tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
        tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
        tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);
        tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(runnerParams.scaleBmm2Ptr);
        // TRTLLM-GEN kernels always use the Log2 scale
        tllmRunnerParams.scaleSoftmaxLog2Ptr
            = reinterpret_cast<float const*>(runnerParams.scaleBmm1Ptr + kIdxScaleSoftmaxLog2Ptr);
        tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<int const*>(kvPageIdxPtr);
        tllmRunnerParams.oSfScalePtr = runnerParams.oSfScalePtr;
        tllmRunnerParams.oPtr = runnerParams.outputPtr;
        tllmRunnerParams.oSfPtr = runnerParams.outputSfPtr;
        // The sequence lengths for K/V.
        tllmRunnerParams.seqLensKvPtr = reinterpret_cast<int const*>(runnerParams.kvSeqLenPtr);
        // Assume same headDim for Qk and V here.
        tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
        tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
        tllmRunnerParams.mNumHeadsQ = mFixedParams.numQHeads;
        tllmRunnerParams.mNumHeadsKv = mFixedParams.numKvHeads;
        tllmRunnerParams.mNumHeadsQPerKv = tllmRunnerParams.mNumHeadsQ / tllmRunnerParams.mNumHeadsKv;
        tllmRunnerParams.mBatchSize = runnerParams.b;
        // It is used to construct contiguous kv cache TMA descriptors.
        tllmRunnerParams.mMaxSeqLenCacheKv = runnerParams.slidingWindowSize;
        tllmRunnerParams.mMaxSeqLenQ = runnerParams.qSeqLen;
        tllmRunnerParams.mMaxSeqLenKv = runnerParams.kvSeqLen;
        tllmRunnerParams.mAttentionWindowSize = runnerParams.slidingWindowSize;
        tllmRunnerParams.mSumOfSeqLensQ = runnerParams.totalQSeqLen;
        tllmRunnerParams.mSumOfSeqLensKv = runnerParams.totalKvSeqLen;
        tllmRunnerParams.mMaxNumPagesPerSeqKv = maxBlocksPerSeq;
        tllmRunnerParams.mNumTokensPerPage = numTokensPerBlock;
        tllmRunnerParams.mScaleQ = mFixedParams.qScaling;
        if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
        {
            auto const [freeMemory, totalMemory] = tensorrt_llm::common::getDeviceMemoryInfo(false);
            // The kv cache should be based on the maximum headDim of K and V due to paddings.
            int maxHeadDimKv = std::max(tllmRunnerParams.mHeadDimQk, tllmRunnerParams.mHeadDimV);
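            // One KV page holds mNumHeadsKv * mNumTokensPerPage * maxHeadDimKv elements, so dividing the
            // total device memory by that page size in bytes bounds the number of pages in the memory pool.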
            tllmRunnerParams.mNumPagesInMemPool = totalMemory
                / (tllmRunnerParams.mNumHeadsKv * tllmRunnerParams.mNumTokensPerPage * maxHeadDimKv
                    * get_size_in_bytes(mFixedParams.dataType));
        }
        tllmRunnerParams.mSfStartTokenIdx = 0;
        tllmRunnerParams.stream = runnerParams.stream;
        mTllmGenFMHARunner->run(tllmRunnerParams);
    }
    else
    {
        mFMHARunner->run(runnerParams);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace tensorrt_llm::kernels