TensorRT-LLMs/cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
qixiang-99 0d4d50a745
feat: no-cache attention in PyTorch workflow (#3085)
* init trtllm attn no cache

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* fix: fix the seq_len issue and attention metadata preparation for the Qwen reward model test

fix: fix minor bugs after rebase
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: remove unnecessary debug logs and clean up commented code

refactor: update max_seq_len documentation and remove max_seq_len from the decoder model constructor in PyTorchModelEngine
Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: update the calculate_ref_result function to accept tensor inputs and a mask type, and enhance test_attention_no_cache to support FULL and CAUSAL masks

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: remove unused BERT attention metadata conversion method and add a type assertion for no-cache attention in PyTorchModelEngine

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: remove use_kv_cache parameter from attention function and related classes, update documentation for KV cache handling

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: implement setAttentionMaskType method for better mask type handling and remove unused conversion function

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: streamline KV cache handling by replacing direct member access with the useKVCache method and simplify the tokens-per-block assignment

Remove debug code.

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: Resolve comments for Python code

Simplify no-cache attention metadata preparation and streamline related attributes in TrtllmAttentionMetadata.

Removed the private method for converting to no-cache attention metadata and integrated its logic into the prepare method. Updated the BERT sequence classification test to reflect these changes and ensure proper handling of attention metadata.

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* docs: Add is_dummy_attention field to attention metadata for simulation operations

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* refactor: add KVCacheParams to attention backend interface and import relevant metadata classes

Updated the attention backend interface to include KVCacheParams, and imported TrtllmAttentionMetadata and VanillaAttentionMetadata in model_engine.py.

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* fix: fix rebase format issue

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>

* fix: extend attention mask type handling in MHARunnerFixedParams

Added support for additional attention mask types (BIDIRECTIONAL, BIDIRECTIONALGLM, BLOCKSPARSE) in the MHARunnerFixedParams structure to fix the mapping issue between ContextAttentionMaskType and AttentionMaskType.

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
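
For illustration only, an extended mapping might look along the lines below; the setter shape, the enum values, and the context mask chosen for each newly added type are assumptions rather than the actual tensorrt_llm implementation.

    // Sketch only: enum values and mapping targets are placeholders for illustration.
    void MHARunnerFixedParams::setAttentionMaskType(AttentionMaskType maskType)
    {
        switch (maskType)
        {
        case AttentionMaskType::PADDING: attentionMaskType = ContextAttentionMaskType::PADDING; break;
        case AttentionMaskType::CAUSAL: attentionMaskType = ContextAttentionMaskType::CAUSAL; break;
        case AttentionMaskType::BIDIRECTIONAL:
        case AttentionMaskType::BIDIRECTIONALGLM:
        case AttentionMaskType::BLOCKSPARSE:
            // Placeholder target: the real change maps each of these to its corresponding
            // ContextAttentionMaskType value.
            attentionMaskType = ContextAttentionMaskType::PADDING;
            break;
        default: TLLM_CHECK_WITH_INFO(false, "Unsupported AttentionMaskType"); break;
        }
    }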

* fix: enhance attention mask type handling in TllmGenFmhaRunnerParams

Updated the setAttentionMaskType method to include a switch-case structure for better handling of attention mask types, ensuring proper mapping and error handling for invalid types.

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
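
For illustration only, such a switch-case might look like the sketch below; TrtllmGenAttentionMaskType, the case labels, and the mMaskType member are assumptions rather than the exact names in the codebase.

    // Sketch only: the TRTLLM-GEN mask enum and member names are placeholders for illustration.
    void TllmGenFmhaRunnerParams::setAttentionMaskType(std::int8_t maskType)
    {
        switch (static_cast<ContextAttentionMaskType>(maskType))
        {
        case ContextAttentionMaskType::PADDING: mMaskType = TrtllmGenAttentionMaskType::Dense; break;
        case ContextAttentionMaskType::CAUSAL: mMaskType = TrtllmGenAttentionMaskType::Causal; break;
        case ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL:
            mMaskType = TrtllmGenAttentionMaskType::SlidingWindowCausal;
            break;
        default: TLLM_CHECK_WITH_INFO(false, "Invalid attention mask type for the TRTLLM-GEN FMHA runner."); break;
        }
    }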

---------

Signed-off-by: Qixiang Lin <qixiangl@nvidia.com>
2025-04-05 01:54:32 +08:00


/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fmhaDispatcher.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm::kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)
{
if (layout == AttentionInputLayout::PACKED_QKV)
{
return QkvLayout::PackedQkv;
}
else if (layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
{
return QkvLayout::ContiguousKv;
}
else if (layout == AttentionInputLayout::Q_PAGED_KV)
{
return QkvLayout::PagedKv;
}
TLLM_CHECK_WITH_INFO(false, "Unexpected AttentionInputLayout");
return QkvLayout::SeparateQkv;
}
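
// The TRTLLM-GEN path is used on SM 100 (Blackwell); on other architectures the dispatcher falls back
// to the fused MHA runner, which requires the KV-cache and output data types to match the input type.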
FmhaDispatcher::FmhaDispatcher(MHARunnerFixedParams fixedParams)
: mFixedParams(fixedParams)
, mUseTllmGen(tensorrt_llm::common::getSMVersion() == 100)
{
if (mUseTllmGen)
{
mTllmGenFMHARunner.reset(
new TllmGenFmhaRunner(mFixedParams.dataType, mFixedParams.dataTypeKv, mFixedParams.dataTypeOut));
if (!isSupported())
{
TLLM_LOG_WARNING("TRTLLM-GEN does not support the requested kernels.");
}
}
else
{
TLLM_CHECK_WITH_INFO(mFixedParams.dataType == mFixedParams.dataTypeKv,
"KV cache data type should be the same as input data type.");
TLLM_CHECK_WITH_INFO(mFixedParams.dataType == mFixedParams.dataTypeOut,
"Output data type should be the same as input data type.");
mFMHARunner.reset(new FusedMHARunnerV2(fixedParams));
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
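// Returns whether a fused kernel is available for the fixed parameters. On the TRTLLM-GEN path this
// first rules out features the kernels do not support (custom masks, ALiBi, padded inputs) and then
// probes the kernel library with a minimally populated TllmGenFmhaRunnerParams.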
bool FmhaDispatcher::isSupported()
{
bool foundKernels = false;
if (mUseTllmGen)
{
if (mFixedParams.attentionMaskType == ContextAttentionMaskType::CUSTOM_MASK)
{
TLLM_LOG_WARNING("TRTLLM-GEN does not support custom mask.");
return false;
}
if (mFixedParams.hasAlibi)
{
TLLM_LOG_WARNING("TRTLLM-GEN does not support ALiBi.");
return false;
}
if (mFixedParams.isSPadded)
{
TLLM_LOG_WARNING("TRTLLM-GEN does not support padded inputs.");
return false;
}
auto qkvLayout = AttentionInputLayoutToQkvLayout(mFixedParams.attentionInputLayout);
// Create TllmGenFmhaRunnerParams based on MHARunnerFixedParams. Only fill necessary
// attributes for kernel selection.
TllmGenFmhaRunnerParams tllmRunnerParams;
memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
tllmRunnerParams.mQkvLayout = qkvLayout;
tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
tllmRunnerParams.mKernelType = FmhaKernelType::Context;
tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
tllmRunnerParams.mMultiCtasKvMode = false;
// Assume same headDim for Qk and V here.
tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
tllmRunnerParams.mNumTokensPerPage = mFixedParams.numTokensPerBlock;
tllmRunnerParams.mNumHeadsQPerKv = mFixedParams.numQHeads / mFixedParams.numKvHeads;
foundKernels = mTllmGenFMHARunner->isSupported(tllmRunnerParams);
}
else
{
foundKernels = mFMHARunner->isFmhaSupported();
}
if (!foundKernels)
{
TLLM_LOG_WARNING("Fall back to unfused MHA for %s in sm_%d.", mFixedParams.convertToStrOutput().c_str(),
tensorrt_llm::common::getSMVersion());
}
return foundKernels;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
void FmhaDispatcher::run(MHARunnerParams runnerParams)
{
if (mUseTllmGen)
{
TLLM_LOG_DEBUG("Running TRTLLM-GEN context FMHA kernel.");
// Convert from MHAFixedParams + MHARunnerParams to TllmGenFmhaRunnerParams
void const* kvPoolPtr = nullptr;
void const* kvPageIdxPtr = nullptr;
auto qkvLayout = kernels::QkvLayout::PackedQkv;
int32_t maxBlocksPerSeq = 0;
int32_t numTokensPerBlock = 0;
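// For the paged KV-cache layout, pull the pool pointer, the page index table, and the paging
// geometry (max pages per sequence, tokens per page) from the runtime KV block array.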
if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
{
qkvLayout = kernels::QkvLayout::PagedKv;
auto pagedKvCache = runnerParams.pagedKvCache.copyKVBlockArrayForContextFMHA();
kvPoolPtr = pagedKvCache.mPrimaryPoolPtr;
kvPageIdxPtr = reinterpret_cast<int const*>(pagedKvCache.data);
maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
numTokensPerBlock = pagedKvCache.mTokensPerBlock;
}
TllmGenFmhaRunnerParams tllmRunnerParams;
memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
// Parameters to select kernels.
tllmRunnerParams.mQkvLayout = qkvLayout;
tllmRunnerParams.setAttentionMaskType(static_cast<std::int8_t>(mFixedParams.attentionMaskType));
tllmRunnerParams.mKernelType = FmhaKernelType::Context;
// Always use persistent scheduler for better performance.
tllmRunnerParams.mTileScheduler = TileScheduler::Persistent;
tllmRunnerParams.mMultiCtasKvMode = false;
tllmRunnerParams.qPtr = runnerParams.qPtr;
tllmRunnerParams.kPtr = nullptr;
tllmRunnerParams.vPtr = nullptr;
tllmRunnerParams.kvPtr = kvPoolPtr;
tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
tllmRunnerParams.cumSeqLensQPtr = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
tllmRunnerParams.cumSeqLensKvPtr = reinterpret_cast<int const*>(runnerParams.cuKvSeqLenPtr);
tllmRunnerParams.outputScalePtr = reinterpret_cast<float const*>(runnerParams.scaleBmm2Ptr);
// TRTLLM-GEN kernels always use the Log2 scale
tllmRunnerParams.scaleSoftmaxLog2Ptr
= reinterpret_cast<float const*>(runnerParams.scaleBmm1Ptr + kIdxScaleSoftmaxLog2Ptr);
tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<int const*>(kvPageIdxPtr);
tllmRunnerParams.oSfScalePtr = runnerParams.oSfScalePtr;
tllmRunnerParams.oPtr = runnerParams.outputPtr;
tllmRunnerParams.oSfPtr = runnerParams.outputSfPtr;
// The sequence lengths for K/V.
tllmRunnerParams.seqLensKvPtr = reinterpret_cast<int const*>(runnerParams.kvSeqLenPtr);
// Assume same headDim for Qk and V here.
tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
tllmRunnerParams.mNumHeadsQ = mFixedParams.numQHeads;
tllmRunnerParams.mNumHeadsKv = mFixedParams.numKvHeads;
tllmRunnerParams.mNumHeadsQPerKv = tllmRunnerParams.mNumHeadsQ / tllmRunnerParams.mNumHeadsKv;
tllmRunnerParams.mBatchSize = runnerParams.b;
// It is used to construct contiguous kv cache TMA descriptors.
tllmRunnerParams.mMaxSeqLenCacheKv = runnerParams.slidingWindowSize;
tllmRunnerParams.mMaxSeqLenQ = runnerParams.qSeqLen;
tllmRunnerParams.mMaxSeqLenKv = runnerParams.kvSeqLen;
tllmRunnerParams.mAttentionWindowSize = runnerParams.slidingWindowSize;
tllmRunnerParams.mSumOfSeqLensQ = runnerParams.totalQSeqLen;
tllmRunnerParams.mSumOfSeqLensKv = runnerParams.totalKvSeqLen;
tllmRunnerParams.mMaxNumPagesPerSeqKv = maxBlocksPerSeq;
tllmRunnerParams.mNumTokensPerPage = numTokensPerBlock;
tllmRunnerParams.mScaleQ = mFixedParams.qScaling;
if (mFixedParams.attentionInputLayout == AttentionInputLayout::Q_PAGED_KV)
{
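// Estimate an upper bound on the number of KV pages the pool could hold: total device memory
// divided by the footprint of one page (numKvHeads * tokensPerPage * maxHeadDimKv elements).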
auto const [freeMemory, totalMemory] = tensorrt_llm::common::getDeviceMemoryInfo(false);
// The kv cache should be based on the maximum headDim of K and V due to paddings.
int maxHeadDimKv = std::max(tllmRunnerParams.mHeadDimQk, tllmRunnerParams.mHeadDimV);
tllmRunnerParams.mNumPagesInMemPool = totalMemory
/ (tllmRunnerParams.mNumHeadsKv * tllmRunnerParams.mNumTokensPerPage * maxHeadDimKv
* get_size_in_bytes(mFixedParams.dataType));
}
tllmRunnerParams.mSfStartTokenIdx = 0;
tllmRunnerParams.stream = runnerParams.stream;
mTllmGenFMHARunner->run(tllmRunnerParams);
}
else
{
mFMHARunner->run(runnerParams);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace tensorrt_llm::kernels