TensorRT-LLMs/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp
forrestl 9477661f4c
Support RingAttention in the BertAttention plugin and the DiT model (#3661)
support ring attn for bert_attention plugin and dit model

Signed-off-by: ChunhuanLin <lch_xdu@163.com>
2025-05-09 08:06:54 +08:00

1192 lines
56 KiB
C++

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bertAttentionPlugin.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/recoverFromRingAtten.h"
#include "tensorrt_llm/kernels/sageAttentionKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
using namespace nvinfer1;
using namespace tensorrt_llm::kernels;
namespace tc = tensorrt_llm::common;
using tensorrt_llm::plugins::BertAttentionPluginCreator;
using tensorrt_llm::plugins::BertAttentionPlugin;
// Version and name strings identifying the BertAttention plugin.
static char const* BERT_ATTENTION_PLUGIN_VERSION = "1";
static char const* BERT_ATTENTION_PLUGIN_NAME = "BertAttention";
// Out-of-class definitions of the creator's static members (field collection and its attribute list).
PluginFieldCollection BertAttentionPluginCreator::mFC = {};
std::vector<nvinfer1::PluginField> BertAttentionPluginCreator::mPluginAttributes{};
// Build-time constructor: stores the attention attributes and downgrades unsupported feature requests.
//  - Context FMHA is kept only for kHALF / kBF16 and only without relative attention bias; otherwise the
//    plugin falls back to unfused MHA (a warning is logged).
//  - SageAttention is kept only for a (Q, K, V) quant block-size combination listed in
//    mSageAttnSupportedBlockSizes and head sizes 72, 80 or 128; otherwise it is disabled (a warning is logged).
//  - A context-parallel group with more than one rank requires the fused path; an error is logged otherwise.
BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_scaling,
    ContextFMHAType context_fmha_type, nvinfer1::DataType type, bool do_relative_attention, int max_distance,
    bool remove_padding, bool sage_attn, int sage_attn_q_block_size, int sage_attn_k_block_size,
    int sage_attn_v_block_size, int cp_size, int cp_rank, std::set<int> cp_group)
    : mNumHeads(num_heads)
    , mHeadSize(head_size)
    , mQScaling(q_scaling)
    , mType(type)
    , mRelativeAttention(do_relative_attention)
    , mMaxDistance(max_distance)
    , mRemovePadding(remove_padding)
    , mEnableContextFMHA(context_fmha_type != ContextFMHAType::DISABLED)
    , mFMHAForceFP32Acc(context_fmha_type == ContextFMHAType::ENABLED_WITH_FP32_ACC)
    , mSageAttn(sage_attn)
    , mCpSize(cp_size)
    , mCpRank(cp_rank)
    , mCpGroup(std::move(cp_group))
{
    // pre-check whether FMHA is supported in order to save memory allocation
    if (mEnableContextFMHA)
    {
        mEnableContextFMHA = false;
        if (!(mType == DataType::kHALF || mType == DataType::kBF16))
        {
            TLLM_LOG_WARNING("Fall back to unfused MHA because of unsupported data type.");
        }
        else if (mRelativeAttention)
        {
            TLLM_LOG_WARNING("Fall back to unfused MHA because of relative position embedding.");
        }
        else
        {
            mEnableContextFMHA = true;
        }
    }
    if (mSageAttn)
    {
        mSageAttnQBlockSize = sage_attn_q_block_size;
        mSageAttnKBlockSize = sage_attn_k_block_size;
        mSageAttnVBlockSize = sage_attn_v_block_size;
        std::vector<int> blockSizeCombination
            = {sage_attn_q_block_size, sage_attn_k_block_size, sage_attn_v_block_size};
        if (mSageAttnSupportedBlockSizes.find(blockSizeCombination) == mSageAttnSupportedBlockSizes.end()
            || (head_size != 128 && head_size != 72 && head_size != 80))
        {
            TLLM_LOG_WARNING(" Q, k ,v quant block size not support. disable sage attention");
            mSageAttn = false;
        }
        else
        {
            TLLM_LOG_INFO("SageAttnQBlockSize: %d, SageAttnKBlockSize: %d, SageAttnVBlockSize: %d", mSageAttnQBlockSize,
                mSageAttnKBlockSize, mSageAttnVBlockSize);
        }
    }
    // `cp_group` was moved into mCpGroup by the initializer list; a moved-from std::set is in a valid but
    // unspecified (in practice empty) state, so the member must be consulted here, not the parameter.
    if (mCpGroup.size() > 1 && !mEnableContextFMHA)
    {
        TLLM_LOG_ERROR("Unfused MHA do not support context parallel now.");
    }
}
// Parameterized constructor
// Deserializing constructor: restores the plugin from a serialized engine buffer of `length` bytes.
// Fields are read in this fixed order: num heads, head size, q scaling, QK half-accum flag, context FMHA flag,
// FP32-accumulation flag, data type, relative-attention flag, max distance, remove-padding flag, sage-attention
// flag, sage Q/K/V block sizes, cp size, cp rank. Every remaining byte is interpreted as a sequence of ints,
// which form the context-parallel group.
BertAttentionPlugin::BertAttentionPlugin(void const* data, size_t length)
{
    char const *d = reinterpret_cast<char const*>(data), *a = d;
    read(d, mNumHeads);
    read(d, mHeadSize);
    read(d, mQScaling);
    read(d, mQKHalfAccum);
    read(d, mEnableContextFMHA);
    read(d, mFMHAForceFP32Acc);
    read(d, mType);
    read(d, mRelativeAttention);
    read(d, mMaxDistance);
    read(d, mRemovePadding);
    read(d, mSageAttn);
    read(d, mSageAttnQBlockSize);
    read(d, mSageAttnKBlockSize);
    read(d, mSageAttnVBlockSize);
    read(d, mCpSize);
    read(d, mCpRank);
    mCpGroup.clear();
    int groupItem = 0;
    // Read only whole ints that still fit in the buffer. Looping on `d != a + length` would step past the end
    // (and never terminate) when the tail is not a multiple of sizeof(int) or the fixed fields overshoot.
    while (static_cast<size_t>(d - a) + sizeof(groupItem) <= length)
    {
        read(d, groupItem);
        mCpGroup.insert(groupItem);
    }
    // With the bounded loop above this check is meaningful: it catches truncated or mismatched buffers.
    TLLM_CHECK_WITH_INFO(d == a + length,
        "Expected length (%d) != real length (%d). This is often "
        "caused by using different TensorRT-LLM version to build "
        "engine and run engine.",
        (int) length, (int) (d - a));
}
// IPluginV2DynamicExt Methods
// Returns a heap-allocated copy of this plugin. The copy receives the same plugin namespace and has
// initialize() called on it before being returned.
nvinfer1::IPluginV2DynamicExt* BertAttentionPlugin::clone() const noexcept
{
    BertAttentionPlugin* const duplicate = new BertAttentionPlugin(*this);
    char const* const pluginNamespace = mNamespace.c_str();
    duplicate->setPluginNamespace(pluginNamespace);
    duplicate->initialize();
    return duplicate;
}
// The single output has the same shape as the fused qkv input (inputs[0]), except that the hidden
// dimension (index 1 with remove-padding, index 2 otherwise) is divided by 3.
nvinfer1::DimsExprs BertAttentionPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    TLLM_CHECK(outputIndex == 0);
    int const hiddenDim = mRemovePadding ? 1 : 2;
    nvinfer1::DimsExprs outputDims = inputs[0];
    // The hidden dimension is read as a constant expression value.
    auto const fusedHidden = outputDims.d[hiddenDim]->getConstantValue();
    outputDims.d[hiddenDim] = exprBuilder.constant(fusedHidden / 3);
    return outputDims;
}
// Accepted I/O formats.
//   inputs:  [0] qkv, [1] input_lengths, [2] max_input_length (optional), [3] relative_attention_bias (optional)
//   outputs: [X] hidden_states
// With exactly two inputs (BERT) only position 1 is int32. With more inputs (encoder of an encoder-decoder
// model) positions 1 and 2 are int32. Every other position must use the plugin type in linear layout.
// Fewer than two inputs is rejected.
bool BertAttentionPlugin::supportsFormatCombination(
    int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
    if (nbInputs < 2)
    {
        return false;
    }
    bool const expectsInt32 = (pos == 1) || (nbInputs > 2 && pos == 2);
    if (expectsInt32)
    {
        return inOut[pos].type == nvinfer1::DataType::kINT32;
    }
    bool const typeMatches = inOut[pos].type == mType;
    bool const isLinear = inOut[pos].format == TensorFormat::kLINEAR;
    return typeMatches && isLinear;
}
// No-op: nothing is recorded at configuration time. The shape-dependent values used by getWorkspaceSize()
// and enqueueImpl() are read from the tensor descriptors passed to those calls instead.
void BertAttentionPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
// Returns the total scratch memory (bytes) needed by enqueueImpl().
// The per-buffer list mirrors the order in which enqueueImpl() carves the workspace with tc::nextWorkspacePtr.
// The list covers cuBLAS workspace, unfused-MHA buffers, SageAttention buffers and RingAttention buffers.
// Buffers not used by the configured path get size 0. calculateTotalWorkspaceSize() combines the entries.
size_t BertAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
    nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    // if remove padding, inputs[0] "qkv_hidden_states" dim is [num_tokens, 3*hidden_dim] which doesn't have shape
    // info should get max_batch_size and max_input_length from inputs[1] "input_lengths" and input[2]
    // "max_input_length"
    int const batch_size = mRemovePadding ? inputs[1].dims.d[0] : inputs[0].dims.d[0];
    int const input_seq_len = mRemovePadding ? inputs[2].dims.d[0] : inputs[0].dims.d[1];
    int const local_hidden_units_ = inputs[0].dims.d[mRemovePadding ? 1 : 2] / 3;
    auto const size = tensorrt_llm::runtime::BufferDataType(inputs[0].type).getSize();

    // Unfused MHA buffers (sized 0 when context FMHA is enabled).
    size_t const attention_mask_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_len * input_seq_len;
    size_t const cu_seqlens_size = sizeof(int) * (batch_size + 1);
    size_t const q_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_len * local_hidden_units_;
    size_t const k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_len * local_hidden_units_;
    size_t const v_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_len * local_hidden_units_;
    size_t const qk_buf_size = mEnableContextFMHA ? 0 : size * batch_size * mNumHeads * input_seq_len * input_seq_len;
    size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_len * local_hidden_units_;
    size_t const qk_buf_float_size
        = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_len * input_seq_len;
    size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * batch_size * input_seq_len;
    size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;

    // SageAttention buffers (sized 0 when SageAttention is disabled).
    int const paddedHeadSize = mSageAttn ? ((mHeadSize + 15) / 16) * 16 : mHeadSize;
    const size_t quanted_qkv_size
        = mSageAttn ? sizeof(__nv_fp8_e4m3) * batch_size * input_seq_len * mNumHeads * paddedHeadSize * 3 : 0;
    const size_t q_scale_size = mSageAttn
        ? sizeof(float) * batch_size * ((input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize) * mNumHeads
        : 0;
    const size_t k_scale_size = mSageAttn
        ? sizeof(float) * batch_size * ((input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize) * mNumHeads
        : 0;
    const size_t v_scale_size = mSageAttn
        ? sizeof(float) * batch_size * ((input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize) * mNumHeads
        : 0;
    const size_t scale_bmm1_device_size = mSageAttn ? sizeof(float) * 2 : 0;
    const size_t scale_bmm2_device_size = mSageAttn ? sizeof(float) : 0;
    size_t sage_quant_space_size = mSageAttn ? sizeof(float) * batch_size * mNumHeads * mHeadSize : 0;
    if (paddedHeadSize != mHeadSize)
    {
        // The sage quant space is also reused as the padded-head output buffer, so make it large enough for that.
        // sizeof() comes first so the product is evaluated in size_t; multiplying the int factors first can
        // overflow int for large batch_size * input_seq_len.
        size_t const padded_output_size
            = sizeof(__nv_bfloat16) * batch_size * input_seq_len * mNumHeads * paddedHeadSize;
        if (sage_quant_space_size < padded_output_size)
        {
            sage_quant_space_size = padded_output_size;
        }
    }

    // workspace for RingAttention ping-pong buffer
    bool const enableRingAttn = (mCpGroup.size() > 1);
    const size_t ring_q_buf_size = enableRingAttn ? size * batch_size * input_seq_len * local_hidden_units_ : 0;
    // Each KV ping-pong buffer holds K and V plus a trailing cu_seqlens array of (batch_size + 1) ints.
    const size_t ring_kv_buf_size = enableRingAttn
        ? 2 * size * batch_size * input_seq_len * local_hidden_units_ + sizeof(int) * (batch_size + 1)
        : 0;
    const size_t ring_softmax_stats_buf_size
        = enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
    const size_t ring_softmax_stats_accu_buf_size
        = enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
    const size_t ring_block_output_size = enableRingAttn ? size * batch_size * input_seq_len * local_hidden_units_ : 0;

    int const NUM_BUFFERS = 24;
    size_t workspaces[NUM_BUFFERS];
    workspaces[0] = CUBLAS_WORKSPACE_SIZE;
    workspaces[1] = attention_mask_size;
    workspaces[2] = cu_seqlens_size;
    workspaces[3] = q_buf_2_size;
    workspaces[4] = k_buf_2_size;
    workspaces[5] = v_buf_2_size;
    workspaces[6] = qk_buf_size;
    workspaces[7] = qkv_buf_2_size;
    workspaces[8] = qk_buf_float_size;
    workspaces[9] = padding_offset_size;
    workspaces[10] = fmha_scheduler_counter;
    workspaces[11] = quanted_qkv_size;
    // q, k, v scale order matches the pointer carving order in enqueueImpl().
    workspaces[12] = q_scale_size;
    workspaces[13] = k_scale_size;
    workspaces[14] = v_scale_size;
    workspaces[15] = scale_bmm1_device_size;
    workspaces[16] = scale_bmm2_device_size;
    workspaces[17] = sage_quant_space_size;
    workspaces[18] = ring_q_buf_size;
    workspaces[19] = ring_kv_buf_size; // kv1
    workspaces[20] = ring_kv_buf_size; // kv2
    workspaces[21] = ring_softmax_stats_buf_size;
    workspaces[22] = ring_softmax_stats_accu_buf_size;
    workspaces[23] = ring_block_output_size;
    return tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
template <typename T>
int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream)
{
// inputs
// input_tensor [batch_size, seq_len, local_hidden_size*3] or [num_tokens, local_hidden_size*3]
// input_lengths [batch_size]
// max_input_length [max_input_length] -- use shape dim to represent max value. If remove padding, this records
// the max input length among sequences; otherwise same as input_tensor's padded dim[1] relative_attention_bias
// [num_heads, num_buckets] (optional)
// outputs
// output_tensor [batch_size, seq_len, local_hidden_size] or [num_tokens, local_hidden_size]
// if remove padding, inputs[0] dim is [num_tokens] which doesn't have workspace info
// should get max_batch_size from inputs[1] and max_input_length from plugin attribute
int const batch_size = mRemovePadding ? inputDesc[1].dims.d[0] : inputDesc[0].dims.d[0];
int const input_seq_len = mRemovePadding ? inputDesc[2].dims.d[0] : inputDesc[0].dims.d[1];
int const num_tokens = mRemovePadding ? inputDesc[0].dims.d[0] : batch_size * input_seq_len;
int const request_batch_size = batch_size;
int const request_seq_len = input_seq_len;
int const local_hidden_units_ = inputDesc[0].dims.d[mRemovePadding ? 1 : 2] / 3;
float const q_scaling = mQScaling;
T const* attention_input = reinterpret_cast<T const*>(inputs[0]);
int const* input_lengths = reinterpret_cast<int const*>(inputs[1]);
T const* relative_attn_table = mRelativeAttention ? reinterpret_cast<T const*>(inputs[3]) : nullptr;
T* context_buf_ = (T*) (outputs[0]);
auto cublasHandle = mCublasWrapper->getCublasHandle();
TLLM_CUDA_CHECK(cublasSetStream(cublasHandle, stream));
mCublasWrapper->setStream(stream);
mCublasWrapper->setWorkspace(workspace);
if (inputDesc[0].type == DataType::kHALF)
{
mCublasWrapper->setFP16GemmConfig();
}
else if (inputDesc[0].type == DataType::kFLOAT)
{
mCublasWrapper->setFP32GemmConfig();
}
#ifdef ENABLE_BF16
else if constexpr (std::is_same_v<T, __nv_bfloat16>)
{
mCublasWrapper->setBF16GemmConfig();
}
#endif
size_t const attention_mask_size = mEnableContextFMHA ? 0 : sizeof(T) * batch_size * input_seq_len * input_seq_len;
size_t const cu_seqlens_size = sizeof(int) * (batch_size + 1);
size_t const q_buf_2_size = mEnableContextFMHA ? 0 : sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
size_t const k_buf_2_size = mEnableContextFMHA ? 0 : sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
size_t const v_buf_2_size = mEnableContextFMHA ? 0 : sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
size_t const qk_buf_size
= mEnableContextFMHA ? 0 : sizeof(T) * batch_size * mNumHeads * input_seq_len * input_seq_len;
size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : sizeof(T) * batch_size * input_seq_len * local_hidden_units_;
size_t const qk_buf_float_size
= mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_len * input_seq_len;
size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * batch_size * input_seq_len;
size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
int const paddedHeadSize = mSageAttn ? ((mHeadSize + 15) / 16) * 16 : mHeadSize;
const size_t quanted_qkv_size
= mSageAttn ? sizeof(__nv_fp8_e4m3) * batch_size * input_seq_len * mNumHeads * paddedHeadSize * 3 : 0;
const size_t q_scale_size = mSageAttn
? sizeof(float) * batch_size * ((input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize) * mNumHeads
: 0;
const size_t k_scale_size = mSageAttn
? sizeof(float) * batch_size * ((input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize) * mNumHeads
: 0;
const size_t v_scale_size = mSageAttn
? sizeof(float) * batch_size * ((input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize) * mNumHeads
: 0;
const size_t scale_bmm1_device_size = mSageAttn ? sizeof(float) * 2 : 0;
const size_t scale_bmm2_device_size = mSageAttn ? sizeof(float) : 0;
size_t sage_quant_space_size = mSageAttn ? sizeof(float) * batch_size * mNumHeads * mHeadSize : 0;
if (paddedHeadSize != mHeadSize)
sage_quant_space_size
= sage_quant_space_size < (batch_size * input_seq_len * mNumHeads * paddedHeadSize * sizeof(__nv_bfloat16))
? (batch_size * input_seq_len * mNumHeads * paddedHeadSize * sizeof(__nv_bfloat16))
: sage_quant_space_size;
bool const enableRingAttn = (mCpGroup.size() > 1);
const size_t ring_q_buf_size = enableRingAttn ? sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
const size_t ring_kv_buf_size
= enableRingAttn ? 2 * sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
const size_t ring_softmax_stats_buf_size
= enableRingAttn ? 2 * sizeof(float) * batch_size * input_seq_len * mNumHeads : 0;
const size_t ring_block_output_size
= enableRingAttn ? sizeof(T) * batch_size * input_seq_len * local_hidden_units_ : 0;
// Workspace pointer shift
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
size_t offset = CUBLAS_WORKSPACE_SIZE;
T* attention_mask = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, attention_mask_size));
int* cu_seqlens = reinterpret_cast<int*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
T* q_buf_2_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, q_buf_2_size));
T* k_buf_2_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, k_buf_2_size));
T* v_buf_2_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, v_buf_2_size));
T* qk_buf_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_size));
T* qkv_buf_2_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, qkv_buf_2_size));
float* qk_buf_float_
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_float_size));
int* padding_offset = reinterpret_cast<int*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
uint32_t* fmha_tile_counter_ptr
= reinterpret_cast<uint32_t*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, fmha_scheduler_counter));
__nv_fp8_e4m3* quanted_qkv_ptr
= reinterpret_cast<__nv_fp8_e4m3*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, quanted_qkv_size));
float* q_scale_ptr = reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, q_scale_size));
float* k_scale_ptr = reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, k_scale_size));
float* v_scale_ptr = reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, v_scale_size));
float* scale_bmm1_ptr
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, scale_bmm1_device_size));
float* scale_bmm2_ptr
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, scale_bmm2_device_size));
void* sage_quant_space_ptr
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, sage_quant_space_size));
T* ring_q_buf_ = reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_q_buf_size));
T* ring_kv_buf_1_ = reinterpret_cast<T*>(
tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_kv_buf_size + sizeof(int) * (batch_size + 1)));
T* ring_kv_buf_2_ = reinterpret_cast<T*>(
tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_kv_buf_size + sizeof(int) * (batch_size + 1)));
float* ring_softmax_stats_buf_
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_softmax_stats_buf_size));
float* ring_softmax_accu_stats_buf_
= reinterpret_cast<float*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_softmax_stats_buf_size));
T* ring_block_output_
= reinterpret_cast<T*>(tc::nextWorkspacePtr(workspace_byte_ptr, offset, ring_block_output_size));
// build attention_mask, cu_seqlens, and padding_offset tensors
BuildDecoderInfoParams<T> params{};
params.seqQOffsets = cu_seqlens;
params.paddingOffsets = padding_offset;
params.attentionMask = attention_mask;
params.seqQLengths = input_lengths;
params.batchSize = batch_size;
params.maxQSeqLength = input_seq_len;
params.numTokens = num_tokens;
params.attentionMaskType = AttentionMaskType::PADDING;
params.fmhaTileCounter = fmha_tile_counter_ptr;
if (mSageAttn)
{
params.fmhaHostBmm1Scale = 1.0f / (sqrtf(mHeadSize * 1.0f) * q_scaling);
params.fmhaBmm1Scale = scale_bmm1_ptr;
params.fmhaBmm2Scale = scale_bmm2_ptr;
}
invokeBuildDecoderInfo(params, stream);
sync_check_cuda_error(stream);
auto const gemm_data_type = tc::CudaDataType<T>::value;
int const attention_seq_len_1 = request_seq_len; // q length
int const attention_seq_len_2 = request_seq_len; // kv length
// If the model has relative attentiona bias, q scaling should be applied in QK gemm stage and use 1 in
// softamax stage (because to get softmax[scale(Q*K) + rel pos bias] here, q_scaling can't be applied during
// softmax phase by qk_scale); otherwise, use 1 in gemm stage and apply scaling in softmax stage
float const qk_scale
= 1.0f / (sqrtf(mHeadSize * 1.0f) * q_scaling); // q_scaling in denominator. by default q_scaling =1.0f
float const qk_scale_gemm = mRelativeAttention ? qk_scale : 1.0f;
T const qk_scale_softmax = static_cast<T>(mRelativeAttention ? 1.0f : qk_scale);
T* linear_bias_slopes = nullptr;
// FMHA doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
// We update mEnableContextFMHA in constructor to check this condition
if (mEnableContextFMHA)
{
if (enableRingAttn)
{
// make sure the padding part of key/value buffer is 0
cudaMemsetAsync(ring_kv_buf_1_, 0,
reinterpret_cast<int8_t*>(ring_kv_buf_2_) - reinterpret_cast<int8_t*>(ring_kv_buf_1_), stream);
cudaMemcpyAsync(ring_q_buf_, attention_input, ring_q_buf_size, cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(ring_kv_buf_1_,
const_cast<char*>(reinterpret_cast<char const*>(attention_input)) + ring_q_buf_size, ring_kv_buf_size,
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(reinterpret_cast<char*>(ring_kv_buf_1_) + ring_kv_buf_size, cu_seqlens,
sizeof(int) * (batch_size + 1), cudaMemcpyDeviceToDevice, stream);
// init softmax_stats
cudaMemsetAsync(ring_softmax_accu_stats_buf_, 0, ring_softmax_stats_buf_size, stream);
#if ENABLE_MULTI_DEVICE
// relative position of prev/next rank in cp group
int prev_rank = mCpRank > 0 ? mCpRank - 1 : mCpGroup.size() - 1;
int next_rank = (mCpRank == static_cast<int>(mCpGroup.size() - 1)) ? 0 : mCpRank + 1;
#endif // ENABLE_MULTI_DEVICE
common::check_cuda_error(cudaStreamCreate(&mNcclStream));
common::check_cuda_error(cudaStreamSynchronize(stream));
uint32_t* fmha_scheduler_counter_h = (uint32_t*) malloc(sizeof(uint32_t));
cudaMemcpyAsync(
fmha_scheduler_counter_h, fmha_tile_counter_ptr, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
for (size_t iter = 0; iter < mCpGroup.size(); ++iter)
{
// KV buffer used by fmha
T* ring_fmha_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_1_ : ring_kv_buf_2_;
#if ENABLE_MULTI_DEVICE
T* ring_send_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_1_ : ring_kv_buf_2_;
T* ring_recv_kv_buf_ = (iter % 2 == 0) ? ring_kv_buf_2_ : ring_kv_buf_1_;
if (iter < mCpGroup.size() - 1)
{
NCCLCHECK(ncclGroupStart());
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
NCCLCHECK(ncclSend(ring_send_kv_buf_,
ring_kv_buf_size / sizeof(T) + sizeof(int) / sizeof(T) * (batch_size + 1),
(*getDtypeMap())[inputDesc[0].type], next_rank, *mNcclComm, mNcclStream));
NCCLCHECK(ncclRecv(ring_recv_kv_buf_,
ring_kv_buf_size / sizeof(T) + sizeof(int) / sizeof(T) * (batch_size + 1),
(*getDtypeMap())[inputDesc[0].type], prev_rank, *mNcclComm, mNcclStream));
NCCLCHECK(ncclGroupEnd());
}
#else
TLLM_LOG_ERROR("Please set ENABLE_MULTI_DEVICE to enable RingAttention");
return 1;
#endif // ENABLE_MULTI_DEVICE
// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
fmhaParams.b = request_batch_size;
fmhaParams.qSeqLen = request_seq_len;
fmhaParams.kvSeqLen = request_seq_len;
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
// Device buffer pointers.
fmhaParams.qPtr = ring_q_buf_;
fmhaParams.kvPtr = ring_fmha_kv_buf_;
if (iter == 0)
{
fmhaParams.outputPtr = context_buf_;
fmhaParams.softmaxStatsPtr = ring_softmax_accu_stats_buf_;
}
else
{
cudaMemsetAsync(ring_softmax_stats_buf_, 0, ring_softmax_stats_buf_size, stream);
fmhaParams.outputPtr = ring_block_output_;
fmhaParams.softmaxStatsPtr = ring_softmax_stats_buf_;
}
fmhaParams.cuQSeqLenPtr = cu_seqlens;
fmhaParams.cuKvSeqLenPtr
= reinterpret_cast<int*>(reinterpret_cast<char*>(ring_fmha_kv_buf_) + ring_kv_buf_size);
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
fmhaParams.stream = stream;
// Run the fmha kernel.
cudaMemsetAsync(fmhaParams.outputPtr, 0, ring_block_output_size, stream);
cudaMemcpyAsync(fmhaParams.tileCounterPtr, fmha_scheduler_counter_h, sizeof(uint32_t),
cudaMemcpyHostToDevice, stream);
mFMHARunner->run(fmhaParams);
if (iter != 0)
{
invokeRecoverFromRA<T>((T*) context_buf_, (float*) ring_softmax_accu_stats_buf_,
(T*) ring_block_output_, (float*) ring_softmax_stats_buf_, fmhaParams.b, fmhaParams.qSeqLen,
mNumHeads, mHeadSize, cu_seqlens, stream);
}
cudaStreamSynchronize(stream);
cudaStreamSynchronize(mNcclStream);
}
common::check_cuda_error(cudaStreamDestroy(mNcclStream));
free(fmha_scheduler_counter_h);
}
else
{
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<72, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<80, 80, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 64
&& mSageAttnVBlockSize == 256)
{
sage_quant<128, 128, 64, 64, 256, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 128 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<128, 128, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 80 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<80, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
if (mSageAttn && mHeadSize == 72 && mSageAttnQBlockSize == 64 && mSageAttnKBlockSize == 32
&& mSageAttnVBlockSize == 32)
{
sage_quant<72, 80, 64, 32, 32, __nv_bfloat16, __nv_fp8_e4m3, float>(
// host var
batch_size, mNumHeads, input_seq_len, true, true,
// device var
// q k v
attention_input, attention_input + mNumHeads * mHeadSize,
attention_input + 2 * mNumHeads * mHeadSize,
// stride
3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, 3 * mNumHeads * mHeadSize, cu_seqlens,
cu_seqlens, sage_quant_space_ptr,
// quant q k v
quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * paddedHeadSize,
quanted_qkv_ptr + 2 * mNumHeads * paddedHeadSize,
// quanted_qkv_ptr, quanted_qkv_ptr + mNumHeads * mHeadSize, context,
3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize, 3 * mNumHeads * paddedHeadSize,
// scales
q_scale_ptr, k_scale_ptr, v_scale_ptr, stream);
sync_check_cuda_error(stream);
}
// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
fmhaParams.b = request_batch_size;
fmhaParams.qSeqLen = request_seq_len;
fmhaParams.kvSeqLen = request_seq_len;
fmhaParams.totalQSeqLen = request_batch_size * request_seq_len;
// Device buffer pointers.
fmhaParams.qkvPtr = attention_input;
fmhaParams.outputPtr = context_buf_;
fmhaParams.cuQSeqLenPtr = cu_seqlens;
fmhaParams.cuKvSeqLenPtr = cu_seqlens;
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
fmhaParams.stream = stream;
if (mSageAttn)
{
if (paddedHeadSize != mHeadSize)
fmhaParams.outputPtr = sage_quant_space_ptr;
fmhaParams.qkvPtr = quanted_qkv_ptr;
fmhaParams.scaleBmm1Ptr = scale_bmm1_ptr;
fmhaParams.scaleBmm2Ptr = scale_bmm2_ptr;
fmhaParams.qScalePtr = q_scale_ptr;
fmhaParams.kScalePtr = k_scale_ptr;
fmhaParams.vScalePtr = v_scale_ptr;
fmhaParams.qMaxNBlock = (input_seq_len + mSageAttnQBlockSize - 1) / mSageAttnQBlockSize;
fmhaParams.kMaxNBlock = (input_seq_len + mSageAttnKBlockSize - 1) / mSageAttnKBlockSize;
fmhaParams.vMaxNBlock = (input_seq_len + mSageAttnVBlockSize - 1) / mSageAttnVBlockSize;
}
// Run the fmha kernel.
mFMHARunner->run(fmhaParams);
sync_check_cuda_error(stream);
if (mSageAttn)
{
if (paddedHeadSize != mHeadSize && mHeadSize == 72)
{
unpadding<80, 72, __nv_bfloat16>(batch_size, mNumHeads, input_seq_len, sage_quant_space_ptr,
mNumHeads * 72, mNumHeads * 80, cu_seqlens, context_buf_, stream);
}
}
}
}
else
{
// FIXME: a temporary solution to make sure the padding part of key/value buffer is 0
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
// cudaMemsetAsync(k_buf_2_, 0, reinterpret_cast<int8_t*>(qk_buf_) - reinterpret_cast<int8_t*>(k_buf_2_),
// stream);
// FIXME: the final solution is to change the add_fusedQKV_bias_transpose_kernel to map CTAs corresponding to
// the output shape, and set the padding part to 0. Without zero-initialize guarantee, these workspace buffers
// may contain random NaN values when IFB workload is high.
cudaMemsetAsync(k_buf_2_, 0,
reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
// only non-FMHA path needs to split Q,K,V from QKV
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(attention_input), input_lengths,
mRemovePadding ? padding_offset : nullptr, batch_size, input_seq_len, num_tokens, mNumHeads, mNumHeads,
mHeadSize, 0, 0.0f, RotaryScalingType::kNONE, 0.0f, 0, PositionEmbeddingType::kLEARNED_ABSOLUTE,
(float*) nullptr, 0, stream);
if (!mQKHalfAccum && gemm_data_type != CUDA_R_32F)
{
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N,
attention_seq_len_2, // n
attention_seq_len_1, // m
mHeadSize, // k
qk_scale_gemm, k_buf_2_, gemm_data_type,
mHeadSize, // k
attention_seq_len_2 * mHeadSize, // n * k
q_buf_2_, gemm_data_type,
mHeadSize, // k
attention_seq_len_1 * mHeadSize, // m * k
0.0f, qk_buf_float_, CUDA_R_32F,
attention_seq_len_2, // n
attention_seq_len_2 * attention_seq_len_1,
request_batch_size * mNumHeads, // global batch size
CUDA_R_32F);
// add relative position bias
if (mRelativeAttention)
{
// add rel pos bias
// QK is (batch_size, local_head_num, q_length, k_length), rel pos bias is (1, local_head_num,
// max_output_len + 1, max_output_len + 1). broadcast along 1st dim. max_seq_len is already
// max_output_len + 1. In implicit mode, relative_attention_bias is rel attn table
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attn_table, request_batch_size,
mNumHeads, attention_seq_len_1, attention_seq_len_2, stream, mMaxDistance > 0,
inputDesc[3].dims.d[1], mMaxDistance, true /* bidirectional */);
}
MaskedSoftmaxParam<T, float> param;
param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length)
param.qk = qk_buf_float_; // (batch_size, head_num, q_length, k_length)
param.attention_mask = attention_mask; // (batch_size, q_length, k_length)
param.batch_size = request_batch_size;
param.q_length = attention_seq_len_1;
param.k_length = attention_seq_len_2;
param.num_heads = mNumHeads;
param.qk_scale = qk_scale_softmax;
param.linear_bias_slopes = const_cast<T*>(linear_bias_slopes); // (head_num,), optional
invokeMaskedSoftmax(param, stream);
}
else
{
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N, attention_seq_len_2, attention_seq_len_1,
mHeadSize, k_buf_2_, mHeadSize, attention_seq_len_2 * mHeadSize, q_buf_2_, mHeadSize,
attention_seq_len_1 * mHeadSize, qk_buf_, attention_seq_len_2,
attention_seq_len_2 * attention_seq_len_1, request_batch_size * mNumHeads, qk_scale_gemm,
0.0f); // alpha, beta
// add relative position bias
if (mRelativeAttention)
{
// add rel pos bias
// QK is (batch_size, local_head_num, q_length, k_length), rel pos bias is (1, local_head_num,
// max_output_len + 1, max_output_len + 1). broadcast along 1st dim. max_seq_len is already
// max_output_len + 1. In implicit mode, relative_attention_bias is rel attn table
// [num_heads, num_buckets], with necessary params (max_distance, num_buckets) passed at the end
invokeAddRelativeAttentionBiasUnaligned(qk_buf_, relative_attn_table, request_batch_size, mNumHeads,
attention_seq_len_1, attention_seq_len_2, stream, mMaxDistance > 0, inputDesc[3].dims.d[1],
mMaxDistance, true /* bidirectional */);
}
MaskedSoftmaxParam<T, T> param;
param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length)
param.qk = qk_buf_; // (batch_size, head_num, q_length, k_length)
param.attention_mask = attention_mask; // (batch_size, q_length, k_length)
param.batch_size = request_batch_size;
param.q_length = attention_seq_len_1;
param.k_length = attention_seq_len_2;
param.num_heads = mNumHeads;
param.qk_scale = qk_scale_softmax;
param.linear_bias_slopes = const_cast<T*>(linear_bias_slopes); // (head_num,), optional
invokeMaskedSoftmax(param, stream);
}
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N, mHeadSize, attention_seq_len_1,
attention_seq_len_2, v_buf_2_, mHeadSize, attention_seq_len_2 * mHeadSize, qk_buf_, attention_seq_len_2,
attention_seq_len_1 * attention_seq_len_2, qkv_buf_2_, mHeadSize, attention_seq_len_1 * mHeadSize,
request_batch_size * mNumHeads);
if (!mRemovePadding)
{
invokeTransposeQKV(context_buf_, qkv_buf_2_, request_batch_size, attention_seq_len_1, mNumHeads, mHeadSize,
(float*) nullptr, 0, stream);
}
else
{
invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, context_buf_, num_tokens, request_batch_size,
request_seq_len, mNumHeads, mHeadSize, padding_offset, (float*) nullptr, 0, stream);
}
}
sync_check_cuda_error(stream);
return 0;
}
// Explicit instantiations of enqueueImpl for every element type that enqueue() dispatches to.
// The bfloat16 instantiation is only compiled when ENABLE_BF16 is defined, mirroring the guard in enqueue().
template int BertAttentionPlugin::enqueueImpl<half>(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream);
template int BertAttentionPlugin::enqueueImpl<float>(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream);
#ifdef ENABLE_BF16
template int BertAttentionPlugin::enqueueImpl<__nv_bfloat16>(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream);
#endif
int BertAttentionPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    // Route the launch to the enqueueImpl instantiation matching the plugin's configured element type.
    // Any type without an instantiation falls out of the switch and reports 0 without launching work.
    switch (mType)
    {
    case DataType::kHALF: return enqueueImpl<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
    case DataType::kFLOAT: return enqueueImpl<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
#ifdef ENABLE_BF16
    case DataType::kBF16:
        return enqueueImpl<__nv_bfloat16>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
#endif
    default: break;
    }
    return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType BertAttentionPlugin::getOutputDataType(
    int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
    // Only output 0 exists; it carries the same element type as input 0.
    TLLM_CHECK(index == 0);
    auto const outputType = inputTypes[0];
    return outputType;
}
// IPluginV2 Methods
char const* BertAttentionPlugin::getPluginType() const noexcept
{
    // The plugin type reported to TensorRT is the file-level plugin name constant.
    char const* const pluginType = BERT_ATTENTION_PLUGIN_NAME;
    return pluginType;
}
char const* BertAttentionPlugin::getPluginVersion() const noexcept
{
    // Version string shared with the creator so serialized engines resolve to this implementation.
    char const* const pluginVersion = BERT_ATTENTION_PLUGIN_VERSION;
    return pluginVersion;
}
int BertAttentionPlugin::getNbOutputs() const noexcept
{
    // The attention plugin produces exactly one tensor (the context output).
    constexpr int kNumOutputs = 1;
    return kNumOutputs;
}
/// @brief Prepare per-instance runtime resources before the first enqueue.
///
/// Creates the cuBLAS wrapper used by the unfused attention path. When context FMHA is enabled, it configures
/// and builds the fused MHA runner:
///   - SageAttention selects the E4M3 input type and pads the head size up to a multiple of 16.
///   - Context parallelism (more than one rank in mCpGroup) switches to the Q_CONTIGUOUS_KV layout and keeps
///     the softmax statistics.
/// FMHA is disabled (unfused fallback) if the runner reports the configuration as unsupported. In multi-device
/// builds, it also acquires the NCCL communicator for the context-parallel group.
/// @return 0 on success.
int BertAttentionPlugin::initialize() noexcept
{
    auto cublasHandle = getCublasHandle();
    auto cublasLtHandle = getCublasLtHandle();
    mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
    if (mEnableContextFMHA)
    {
        // Pre-checked during constructing.
        Data_type data_type;
        if (mType == DataType::kHALF)
        {
            data_type = DATA_TYPE_FP16;
        }
        else if (mType == DataType::kBF16)
        {
            data_type = DATA_TYPE_BF16;
        }
        else
        {
            // Report the plugin that actually failed (this is the BERT attention plugin, not the GPT one).
            TLLM_CHECK_WITH_INFO(false, "BertAttentionPlugin received wrong data type.");
        }

        // Construct the fmha runner.
        MHARunnerFixedParams fmhaParams{};
        if (mSageAttn)
        {
            // SageAttention feeds quantized FP8 (E4M3) Q/K/V into the kernel; the output keeps the plugin type.
            fmhaParams.dataType = DATA_TYPE_E4M3;
        }
        else
        {
            fmhaParams.dataType = data_type;
        }
        fmhaParams.dataTypeOut = data_type;
        fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
        fmhaParams.attentionMaskType = ContextAttentionMaskType::PADDING;
        fmhaParams.isSPadded = !mRemovePadding;
        fmhaParams.numQHeads = mNumHeads;
        fmhaParams.numKvHeads = mNumHeads;
        fmhaParams.headSize = mHeadSize;
        fmhaParams.qScaling = mQScaling;
        fmhaParams.sageBlockSizeQ = mSageAttnQBlockSize;
        fmhaParams.sageBlockSizeK = mSageAttnKBlockSize;
        fmhaParams.sageBlockSizeV = mSageAttnVBlockSize;
        if (mSageAttn)
        {
            // The SageAttention kernels run on a head size rounded up to the next multiple of 16.
            int const paddedHeadSize = ((mHeadSize + 15) / 16) * 16;
            fmhaParams.headSize = paddedHeadSize;
        }
        if (mCpGroup.size() > 1)
        {
            // Context (ring) parallelism: Q is contiguous while K/V arrive separately, and the softmax
            // statistics are saved so partial results from different ranks can be combined.
            fmhaParams.attentionInputLayout = AttentionInputLayout::Q_CONTIGUOUS_KV;
            fmhaParams.saveSoftmax = true;
        }
        // Load kernels from the pre-compiled cubins.
        mFMHARunner.reset(new FusedMHARunnerV2(fmhaParams));
        // Fall back to unfused MHA kernels if not supported.
        mEnableContextFMHA = mFMHARunner->isFmhaSupported();
    }

#if ENABLE_MULTI_DEVICE
    if (mCpGroup.size() > 1 && COMM_SESSION.getSize() > 1)
    {
        TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
        mNcclComm = getComm(mCpGroup);
        TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, COMM_SESSION.getRank());
    }
#endif // ENABLE_MULTI_DEVICE
    return 0;
}
void BertAttentionPlugin::destroy() noexcept
{
    // Self-deletion hook invoked when the owning network/engine releases this plugin instance.
    // The object must not be touched after this returns.
    delete this;
}
size_t BertAttentionPlugin::getSerializationSize() const noexcept
{
    // Must match, byte for byte, the sequence of fields written by serialize().
    size_t totalBytes = sizeof(mNumHeads) + sizeof(mHeadSize) + sizeof(mQScaling) + sizeof(mQKHalfAccum);
    totalBytes += sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mType);
    totalBytes += sizeof(mRelativeAttention) + sizeof(mMaxDistance) + sizeof(mRemovePadding);
    totalBytes += sizeof(mSageAttn) + sizeof(mSageAttnQBlockSize) + sizeof(mSageAttnKBlockSize);
    totalBytes += sizeof(mSageAttnVBlockSize);
    totalBytes += sizeof(mCpSize) + sizeof(mCpRank);
    // Variable-length tail: one int32_t per entry of the context-parallel group.
    totalBytes += sizeof(int32_t) * mCpGroup.size();
    return totalBytes;
}
void BertAttentionPlugin::serialize(void* buffer) const noexcept
{
    // Emit every field in a fixed order; getSerializationSize() accounts for exactly these bytes.
    char* cursor = static_cast<char*>(buffer);
    char* const begin = cursor;

    // Core attention configuration.
    write(cursor, mNumHeads);
    write(cursor, mHeadSize);
    write(cursor, mQScaling);
    write(cursor, mQKHalfAccum);
    write(cursor, mEnableContextFMHA);
    write(cursor, mFMHAForceFP32Acc);
    write(cursor, mType);
    write(cursor, mRelativeAttention);
    write(cursor, mMaxDistance);
    write(cursor, mRemovePadding);

    // SageAttention settings.
    write(cursor, mSageAttn);
    write(cursor, mSageAttnQBlockSize);
    write(cursor, mSageAttnKBlockSize);
    write(cursor, mSageAttnVBlockSize);

    // Context-parallel settings, followed by each group member.
    write(cursor, mCpSize);
    write(cursor, mCpRank);
    for (auto const& groupRank : mCpGroup)
    {
        write(cursor, groupRank);
    }

    // Catch any drift between the write sequence and the size computation.
    TLLM_CHECK(cursor == begin + getSerializationSize());
}
void BertAttentionPlugin::terminate() noexcept
{
    // Intentionally empty: no explicit teardown is performed here; member-held resources are released
    // together with the plugin object.
}
///////////////
BertAttentionPluginCreator::BertAttentionPluginCreator()
{
    // Schema of the attributes accepted by createPlugin(): name and element type only. The data pointers
    // stay null because this collection describes the fields rather than carrying values.
    struct AttributeSpec
    {
        char const* name;
        PluginFieldType type;
    };

    AttributeSpec const attributeSpecs[] = {
        {"num_heads", PluginFieldType::kINT32},
        {"head_size", PluginFieldType::kINT32},
        {"q_scaling", PluginFieldType::kFLOAT32},
        {"context_fmha_type", PluginFieldType::kINT8},
        {"type_id", PluginFieldType::kINT32},
        {"do_relative_attention", PluginFieldType::kINT8},
        {"max_distance", PluginFieldType::kINT32},
        {"remove_padding", PluginFieldType::kINT8},
        {"sage_attn", PluginFieldType::kINT8},
        {"sage_attn_q_block_size", PluginFieldType::kINT32},
        {"sage_attn_k_block_size", PluginFieldType::kINT32},
        {"sage_attn_v_block_size", PluginFieldType::kINT32},
        {"cp_size", PluginFieldType::kINT32},
        {"cp_rank", PluginFieldType::kINT32},
        {"cp_group", PluginFieldType::kINT32},
    };

    mPluginAttributes.clear();
    for (auto const& spec : attributeSpecs)
    {
        mPluginAttributes.emplace_back(PluginField(spec.name, nullptr, spec.type));
    }

    // Publish the (static) attribute list through the field collection handed to TensorRT.
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
char const* BertAttentionPluginCreator::getPluginName() const noexcept
{
    // Same name the plugin reports from getPluginType(), so the registry can match them.
    char const* const pluginName = BERT_ATTENTION_PLUGIN_NAME;
    return pluginName;
}
char const* BertAttentionPluginCreator::getPluginVersion() const noexcept
{
    // Version string shared with BertAttentionPlugin::getPluginVersion().
    char const* const pluginVersion = BERT_ATTENTION_PLUGIN_VERSION;
    return pluginVersion;
}
PluginFieldCollection const* BertAttentionPluginCreator::getFieldNames() noexcept
{
    // Expose the attribute schema populated by the creator's constructor.
    PluginFieldCollection const* const schema = &mFC;
    return schema;
}
/// @brief Build a BertAttentionPlugin from the attribute collection supplied at network-build time.
///
/// Attributes not present in @p fc keep their value-initialized defaults. Attribute names that are not
/// recognized are ignored. Each recognized attribute must carry the element type listed in the creator's
/// constructor.
///
/// @param name Plugin instance name (unused here).
/// @param fc   Attribute collection to parse.
/// @return The new plugin, or nullptr if parsing or construction failed. Failures are reported through
///         caughtError().
IPluginV2* BertAttentionPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    // Parsing runs inside the try block as well: TLLM_CHECK raises an exception on a mismatched field type.
    // Letting that escape this noexcept function would call std::terminate instead of reporting the error.
    try
    {
        PluginField const* fields = fc->fields;
        int num_heads{};
        int head_size{};
        ContextFMHAType context_fmha_type{};
        float q_scaling{};
        nvinfer1::DataType type{};
        bool do_relative_attention{};
        int max_distance{};
        bool remove_padding{};
        bool sage_attn{};
        int sage_attn_q_block_size{};
        int sage_attn_k_block_size{};
        int sage_attn_v_block_size{};
        int cp_size{};
        int cp_rank{};
        std::set<int> cp_group{};
        // Read configurations from each fields
        for (int i = 0; i < fc->nbFields; ++i)
        {
            char const* attrName = fields[i].name;
            if (!strcmp(attrName, "num_heads"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                num_heads = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "head_size"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                head_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "q_scaling"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32);
                q_scaling = static_cast<float>(*(static_cast<float const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "context_fmha_type"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
                context_fmha_type = static_cast<ContextFMHAType>(*(static_cast<int8_t const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "type_id"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "do_relative_attention"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
                do_relative_attention = static_cast<bool>(*(static_cast<int8_t const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "max_distance"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                max_distance = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "remove_padding"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
                remove_padding = static_cast<bool>(*(static_cast<int8_t const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "sage_attn"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT8);
                sage_attn = static_cast<bool>(*(static_cast<int8_t const*>(fields[i].data)));
                if (sage_attn)
                {
                    // Route through the project logger instead of writing debug output to stdout.
                    TLLM_LOG_DEBUG("BertAttentionPlugin: sage attention enabled");
                }
            }
            else if (!strcmp(attrName, "sage_attn_q_block_size"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                sage_attn_q_block_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "sage_attn_k_block_size"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                sage_attn_k_block_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "sage_attn_v_block_size"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                sage_attn_v_block_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "cp_size"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                cp_size = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "cp_rank"))
            {
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                cp_rank = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
            }
            else if (!strcmp(attrName, "cp_group"))
            {
                // Array attribute: `length` int32 ranks forming the context-parallel group.
                TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
                auto const* r = static_cast<int const*>(fields[i].data);
                for (int j = 0; j < fields[i].length; ++j)
                {
                    cp_group.insert(*r);
                    ++r;
                }
            }
        }

        auto* obj = new BertAttentionPlugin(num_heads, head_size, q_scaling, context_fmha_type, type,
            do_relative_attention, max_distance, remove_padding, sage_attn, sage_attn_q_block_size,
            sage_attn_k_block_size, sage_attn_v_block_size, cp_size, cp_rank, cp_group);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
IPluginV2* BertAttentionPluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept
{
    // Rebuild a plugin from a serialized buffer (expected to be the bytes written by
    // BertAttentionPlugin::serialize()). The returned object is released later through
    // BertAttentionPlugin::destroy() when the network is destroyed. Any exception raised while
    // reconstructing it is reported through caughtError() and yields nullptr.
    try
    {
        auto* const restored = new BertAttentionPlugin(serialData, serialLength);
        restored->setPluginNamespace(mNamespace.c_str());
        return restored;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}