mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
1113 lines
55 KiB
C++
1113 lines
55 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
|
|
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include "gptAttentionCommon.h"
|
|
#include "tensorrt_llm/common/assert.h"
|
|
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
|
|
#include "tensorrt_llm/kernels/gptKernels.h"
|
|
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
|
#include "tensorrt_llm/plugins/common/checkMacrosPlugin.h"
|
|
#include "tensorrt_llm/runtime/iBuffer.h"
|
|
#include <NvInferRuntimePlugin.h>
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <type_traits>
|
|
|
|
using namespace nvinfer1;
|
|
using namespace tensorrt_llm::kernels;
|
|
namespace tc = tensorrt_llm::common;
|
|
using tensorrt_llm::plugins::GPTAttentionPluginCreatorCommon;
|
|
using tensorrt_llm::plugins::GPTAttentionPluginCommon;
|
|
|
|
template <typename KVCacheBuffer>
|
|
struct KVCacheBufferDataType
|
|
{
|
|
};
|
|
|
|
template <>
|
|
struct KVCacheBufferDataType<KVLinearBuffer>
|
|
{
|
|
using Type = int8_t;
|
|
};
|
|
|
|
template <>
|
|
struct KVCacheBufferDataType<KVBlockArray>
|
|
{
|
|
using Type = int64_t;
|
|
};
|
|
|
|
template <typename T>
|
|
struct SATypeConverter
|
|
{
|
|
using Type = T;
|
|
};
|
|
|
|
template <>
|
|
struct SATypeConverter<half>
|
|
{
|
|
using Type = uint16_t;
|
|
};
|
|
|
|
template <typename T, typename KVCacheBuffer>
|
|
struct FusedQKVMaskedAttentionDispatchParams
|
|
{
|
|
const T* qkv_buf;
|
|
const T* qkv_bias;
|
|
const T* relative_attention_bias;
|
|
const int* cache_indir;
|
|
T* context_buf;
|
|
const bool* finished;
|
|
const int* sequence_lengths;
|
|
int max_batch_size;
|
|
int inference_batch_size;
|
|
int beam_width;
|
|
int head_num;
|
|
int kv_head_num;
|
|
int size_per_head;
|
|
int rotary_embedding_dim;
|
|
float rotary_embedding_base;
|
|
RotaryScalingType rotary_embedding_scale_type;
|
|
float rotary_embedding_scale;
|
|
int rotary_embedding_max_positions;
|
|
PositionEmbeddingType position_embedding_type;
|
|
int max_seq_len;
|
|
const int* input_lengths;
|
|
int step;
|
|
float q_scaling;
|
|
int relative_attention_bias_stride;
|
|
const T* linear_bias_slopes;
|
|
const int* ia3_tasks;
|
|
const T* ia3_key_weights;
|
|
const T* ia3_value_weights;
|
|
const float* qkv_scale_out;
|
|
const float* attention_out_scale;
|
|
tc::QuantMode quant_option;
|
|
bool multi_block_mode;
|
|
int max_seq_len_tile;
|
|
T* partial_out;
|
|
float* partial_sum;
|
|
float* partial_max;
|
|
int* block_counter;
|
|
const float* kv_scale_orig_quant;
|
|
const float* kv_scale_quant_orig;
|
|
tc::QuantMode kv_cache_quant_mode;
|
|
int multi_processor_count;
|
|
KVCacheBuffer kv_block_array;
|
|
bool cross_attention = false;
|
|
const int* memory_length_per_sample = nullptr;
|
|
int max_distance = 0;
|
|
};
|
|
|
|
template <typename T_MMHA, typename T, typename KVCacheBuffer, bool CROSS_ATTENTION>
|
|
void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS_ATTENTION>& params,
|
|
const FusedQKVMaskedAttentionDispatchParams<T, KVCacheBuffer>& input_params, cudaStream_t stream)
|
|
{
|
|
using DataType = typename SATypeConverter<T>::Type;
|
|
|
|
// Prepare the parameters.
|
|
memset(¶ms, 0, sizeof(params));
|
|
|
|
int hidden_units = input_params.head_num * input_params.size_per_head;
|
|
int hidden_units_kv = input_params.kv_head_num * input_params.size_per_head;
|
|
if (input_params.qkv_bias != nullptr)
|
|
{
|
|
params.q_bias = reinterpret_cast<const DataType*>(input_params.qkv_bias);
|
|
params.k_bias = reinterpret_cast<const DataType*>(input_params.qkv_bias) + hidden_units;
|
|
params.v_bias = reinterpret_cast<const DataType*>(input_params.qkv_bias) + hidden_units + hidden_units_kv;
|
|
}
|
|
else
|
|
{
|
|
params.q_bias = nullptr;
|
|
params.k_bias = nullptr;
|
|
params.v_bias = nullptr;
|
|
}
|
|
|
|
// Set the output buffer.
|
|
params.out = reinterpret_cast<DataType*>(input_params.context_buf);
|
|
|
|
// Set the input buffers.
|
|
params.q = reinterpret_cast<const DataType*>(input_params.qkv_buf);
|
|
params.k = reinterpret_cast<const DataType*>(input_params.qkv_buf) + hidden_units;
|
|
params.v = reinterpret_cast<const DataType*>(input_params.qkv_buf) + hidden_units + hidden_units_kv;
|
|
|
|
params.int8_kv_cache = input_params.kv_cache_quant_mode.hasInt8KvCache();
|
|
params.fp8_kv_cache = input_params.kv_cache_quant_mode.hasFp8KvCache();
|
|
if (input_params.kv_cache_quant_mode.hasKvCacheQuant())
|
|
{
|
|
params.kv_scale_orig_quant = input_params.kv_scale_orig_quant;
|
|
params.kv_scale_quant_orig = input_params.kv_scale_quant_orig;
|
|
}
|
|
|
|
params.stride = hidden_units + 2 * hidden_units_kv;
|
|
params.finished = const_cast<bool*>(input_params.finished);
|
|
|
|
params.cache_indir = input_params.cache_indir;
|
|
params.batch_size = input_params.inference_batch_size;
|
|
params.beam_width = input_params.beam_width;
|
|
params.memory_max_len = input_params.max_seq_len;
|
|
params.length_per_sample = input_params.sequence_lengths; // max_input_length + current output length
|
|
// timestep for shared memory size calculation and rotary embedding computation
|
|
params.timestep = input_params.step - 1;
|
|
params.num_heads = input_params.head_num;
|
|
params.num_kv_heads = input_params.kv_head_num;
|
|
params.hidden_size_per_head = input_params.size_per_head;
|
|
params.rotary_embedding_dim = input_params.rotary_embedding_dim;
|
|
params.rotary_embedding_base = input_params.rotary_embedding_base;
|
|
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
|
|
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
|
|
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
|
|
params.position_embedding_type = input_params.position_embedding_type;
|
|
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
|
|
params.inv_sqrt_dh = 1.F / (sqrtf((float) params.hidden_size_per_head) * input_params.q_scaling);
|
|
|
|
params.relative_attention_bias = reinterpret_cast<const DataType*>(input_params.relative_attention_bias);
|
|
params.relative_attention_bias_stride = input_params.relative_attention_bias_stride;
|
|
params.max_distance = input_params.max_distance;
|
|
|
|
// The slope of linear position bias per head, e.g., ALiBi.
|
|
if (input_params.linear_bias_slopes != nullptr)
|
|
{
|
|
params.linear_bias_slopes = reinterpret_cast<const DataType*>(input_params.linear_bias_slopes);
|
|
}
|
|
params.input_lengths = input_params.input_lengths;
|
|
|
|
params.ia3_tasks = input_params.ia3_tasks;
|
|
params.ia3_key_weights = reinterpret_cast<const DataType*>(input_params.ia3_key_weights);
|
|
params.ia3_value_weights = reinterpret_cast<const DataType*>(input_params.ia3_value_weights);
|
|
|
|
if (input_params.quant_option.hasStaticActivationScaling())
|
|
{
|
|
params.qkv_scale_quant_orig = input_params.qkv_scale_out;
|
|
params.attention_out_scale_orig_quant = input_params.attention_out_scale;
|
|
}
|
|
|
|
params.multi_block_mode = input_params.multi_block_mode;
|
|
if (input_params.multi_block_mode)
|
|
{
|
|
params.max_seq_len_tile = input_params.max_seq_len_tile;
|
|
|
|
params.partial_out = reinterpret_cast<DataType*>(input_params.partial_out);
|
|
params.partial_sum = input_params.partial_sum;
|
|
params.partial_max = input_params.partial_max;
|
|
|
|
params.block_counter = input_params.block_counter;
|
|
}
|
|
|
|
params.multi_processor_count = input_params.multi_processor_count;
|
|
|
|
// cross attn
|
|
params.memory_length_per_sample = input_params.memory_length_per_sample;
|
|
sync_check_cuda_error();
|
|
|
|
masked_multihead_attention(params, input_params.kv_block_array, stream);
|
|
}
|
|
|
|
#define INSTANTIATE_MMHA_DISPATH(T_MMHA, T) \
|
|
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, false>&, \
|
|
const FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer>&, cudaStream_t stream); \
|
|
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, true>&, \
|
|
const FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer>&, cudaStream_t stream); \
|
|
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, false>&, \
|
|
const FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray>&, cudaStream_t stream); \
|
|
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, true>&, \
|
|
const FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray>&, cudaStream_t stream);
|
|
INSTANTIATE_MMHA_DISPATH(float, float)
|
|
INSTANTIATE_MMHA_DISPATH(uint16_t, half)
|
|
#ifdef ENABLE_BF16
|
|
INSTANTIATE_MMHA_DISPATH(__nv_bfloat16, __nv_bfloat16)
|
|
#endif
|
|
#undef INSTANTIATE_MMHA_DISPATH
|
|
|
|
// Builds the plugin from explicit configuration values.
//
// Each argument is stored in the matching member. Two members are derived:
//   * mEnableContextFMHA: requested via `context_fmha_type != DISABLED`, then additionally
//     restricted in the body to half / bfloat16 data types.
//   * mFMHAForceFP32Acc: true when FP32 accumulation is requested or the data type is bfloat16.
//
// Checks performed (via TLLM_CHECK*):
//   * isRoPE() must agree with `rotary_embedding_dim != 0`.
//   * bfloat16 requires SM version >= 80.
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_heads, int head_size, int unidirectional,
    float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
    int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
    float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
    float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
    tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, int kv_cache_quant_mode,
    bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type, bool paged_kv_cache,
    int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length, bool qkv_bias_enabled,
    bool cross_attention, int max_distance)
    : mNumHeads(num_heads)
    , mNumKVHeads(num_kv_heads)
    , mHeadSize(head_size)
    , mUnidirectional(unidirectional)
    , mQScaling(q_scaling)
    , mRotaryEmbeddingDim(rotary_embedding_dim)
    , mRotaryEmbeddingBase(rotary_embedding_base)
    , mRotaryEmbeddingScaleType(rotary_embedding_scale_type)
    , mRotaryEmbeddingScale(rotary_embedding_scale)
    , mRotaryEmbeddingMaxPositions(rotary_embedding_max_positions)
    , mPositionEmbeddingType(position_embedding_type)
    , mEnableContextFMHA(context_fmha_type != ContextFMHAType::DISABLED)
    , mFMHAForceFP32Acc(context_fmha_type == ContextFMHAType::ENABLED_WITH_FP32_ACC || type == DataType::kBF16)
    , mMaskType(mask_type)
    , mType(type)
    , mMultiBlockMode(multi_block_mode)
    , mKVCacheQuantMode(kv_cache_quant_mode)
    , mRemovePadding(remove_input_padding)
    , mPagedKVCache(paged_kv_cache)
    , mTokensPerBlock(tokens_per_block)
    , mTpSize(tp_size)
    , mTpRank(tp_rank)
    , mMaxContextLength(max_context_length)
    , mQKVBiasEnabled(qkv_bias_enabled)
    , mCrossAttention(cross_attention)
    , mMaxDistance(max_distance)
{
    // Context FMHA is only kept enabled for half and bfloat16.
    mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF || mType == DataType::kBF16);
    // The position embedding type and the rotary dimension must describe the same thing.
    TLLM_CHECK(isRoPE() == (rotary_embedding_dim != 0));
    TLLM_CHECK_WITH_INFO((tc::getSMVersion() >= 80) || (mType != DataType::kBF16),
        "Unsupported data type, pre SM 80 GPUs do not support bfloat16");
}
|
|
|
|
const int GPTAttentionPluginCommon::getHeadSize(bool checkInit) const
|
|
{
|
|
if (checkInit)
|
|
{
|
|
TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Trying to read mHeadSize before it's been initialized");
|
|
}
|
|
return mHeadSize;
|
|
}
|
|
|
|
// Parameterized constructor
|
|
GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t length)
|
|
{
|
|
const char *d = reinterpret_cast<const char*>(data), *a = d;
|
|
unsigned int kvCacheQuantMode;
|
|
|
|
read(d, mNumHeads);
|
|
read(d, mNumKVHeads);
|
|
read(d, mHeadSize);
|
|
read(d, mUnidirectional);
|
|
read(d, mQScaling);
|
|
read(d, mPositionEmbeddingType);
|
|
read(d, mRotaryEmbeddingDim);
|
|
read(d, mRotaryEmbeddingBase);
|
|
read(d, mRotaryEmbeddingScaleType);
|
|
read(d, mRotaryEmbeddingScale);
|
|
read(d, mRotaryEmbeddingMaxPositions);
|
|
read(d, mTpSize);
|
|
read(d, mTpRank);
|
|
read(d, mEnableContextFMHA);
|
|
read(d, mFMHAForceFP32Acc);
|
|
read(d, mMultiBlockMode);
|
|
read(d, kvCacheQuantMode);
|
|
read(d, mRemovePadding);
|
|
read(d, mMaskType);
|
|
read(d, mPagedKVCache);
|
|
read(d, mTokensPerBlock);
|
|
read(d, mType);
|
|
read(d, mMaxContextLength);
|
|
read(d, mQKVBiasEnabled);
|
|
read(d, mCrossAttention);
|
|
read(d, mMaxDistance);
|
|
|
|
mKVCacheQuantMode = tc::QuantMode(kvCacheQuantMode);
|
|
|
|
TLLM_CHECK(d == a + length);
|
|
TLLM_CHECK_WITH_INFO((tc::getSMVersion() >= 80) || (mType != DataType::kBF16),
|
|
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
|
|
}
|
|
|
|
// Computes the workspace size in bytes needed by the context phase.
//
// type             : element data type; its byte size scales the T-typed buffers.
// nbReq            : number of requests (used as the batch size).
// input_seq_length : query sequence length.
// cross_qkv_length : key/value sequence length, used only when isCrossAttention() is true.
//
// The buffer list (cuBLAS workspace, attention mask, cu_seqlens, Q/K/V staging buffers, QK
// buffers, padding offsets) is in the same order in which enqueueContext() carves pointers out
// of the workspace. With context FMHA enabled, all unfused-path buffers have size 0.
size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(
    DataType type, int32_t nbReq, int32_t input_seq_length, int32_t cross_qkv_length) const noexcept
{
    const int local_hidden_units_qo = mNumHeads * getHeadSize();
    const int local_hidden_units_kv = mNumKVHeads * getHeadSize();

    auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize();

    const int batch_size = nbReq;
    // Key/value length: the encoder length for cross attention, the input length otherwise.
    const int32_t kv_seq_length = isCrossAttention() ? cross_qkv_length : input_seq_length;

    const size_t attention_mask_size
        = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * kv_seq_length;
    const size_t cu_seqlens_size = sizeof(int) * (batch_size + 1);
    const size_t q_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_qo;
    const size_t k_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * kv_seq_length * local_hidden_units_kv;
    const size_t v_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * kv_seq_length * local_hidden_units_kv;
    const size_t qk_buf_size
        = mEnableContextFMHA ? 0 : size * batch_size * mNumHeads * input_seq_length * kv_seq_length;
    const size_t qkv_buf_2_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * local_hidden_units_qo;
    const size_t qk_buf_float_size
        = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
    const size_t padding_offset_size = sizeof(int) * batch_size * input_seq_length;

    const int NUM_BUFFERS = 10;
    size_t workspaces[NUM_BUFFERS];
    workspaces[0] = CUBLAS_WORKSPACE_SIZE;
    workspaces[1] = attention_mask_size;
    workspaces[2] = cu_seqlens_size;
    workspaces[3] = q_buf_2_size;
    workspaces[4] = k_buf_2_size;
    workspaces[5] = v_buf_2_size;
    workspaces[6] = qk_buf_size;
    workspaces[7] = qkv_buf_2_size;
    workspaces[8] = qk_buf_float_size;
    workspaces[9] = padding_offset_size;
    return tensorrt_llm::plugins::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
|
|
|
|
// Computes the workspace size in bytes needed by the generation phase.
//
// type          : element data type; its byte size scales the partial-output buffer.
// total_num_seq : number of sequences (batch * beam).
//
// Scratch memory is only required in multi-block mode (partial outputs, partial sums, partial
// maxima and a per-head block counter); otherwise every buffer has size 0.
size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration(DataType type, int32_t total_num_seq) const noexcept
{
    // Read through the accessor so the "head size initialized" check still runs.
    const int head_size = getHeadSize();

    auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize();

    const int batch_beam = total_num_seq;
    int32_t const maxSeqLenTile = getMaxSeqLenTile(size);

    const size_t partial_out_size = mMultiBlockMode ? size * batch_beam * mNumHeads * head_size * maxSeqLenTile : 0;
    const size_t partial_sum_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
    const size_t partial_max_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
    const size_t block_counter_size = mMultiBlockMode ? sizeof(int) * batch_beam * mNumHeads : 0;

    const int NUM_BUFFERS = 4;
    size_t workspaces[NUM_BUFFERS];
    workspaces[0] = partial_out_size;
    workspaces[1] = partial_sum_size;
    workspaces[2] = partial_max_size;
    workspaces[3] = block_counter_size;
    return tensorrt_llm::plugins::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
|
|
|
|
int GPTAttentionPluginCommon::getMaxSeqLenTile(int elemSize) const
|
|
{
|
|
if (mMultiBlockMode)
|
|
{
|
|
const int threads_per_value = pow2roundup(getHeadSize()) * elemSize / 16;
|
|
|
|
// max_seq_len_tile to make sure: seq_len_tile * threads_per_value <= threads_per_block (for
|
|
// multi_block_mode)
|
|
const int max_seq_len_tile
|
|
= 256 / threads_per_value; // for allocate partial output results memory. Regardless to THDS_PER_BLOCK
|
|
// (which may be smaller than 256 like being 128)
|
|
return max_seq_len_tile;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
template <typename T, typename KVCacheBuffer>
|
|
int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCacheBuffer>& params, cudaStream_t stream)
|
|
{
|
|
const int num_heads = mNumHeads;
|
|
const int num_kv_heads = mNumKVHeads;
|
|
const int head_size = getHeadSize();
|
|
const int local_hidden_units_qo = num_heads * head_size;
|
|
const int local_hidden_units_kv = num_kv_heads * head_size;
|
|
const PositionEmbeddingType position_embedding_type = mPositionEmbeddingType;
|
|
const float q_scaling = mQScaling;
|
|
const bool* finished = nullptr;
|
|
const bool has_ia3 = false;
|
|
|
|
KVCacheBuffer kv_cache_buffer;
|
|
const auto elem_size = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
|
|
if (mPagedKVCache)
|
|
{
|
|
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
|
kv_cache_buffer = KVCacheBuffer(
|
|
params.batch_size, params.max_blocks_per_sequence, mTokensPerBlock, num_kv_heads * head_size * elem_size);
|
|
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.block_pointers);
|
|
}
|
|
else
|
|
{
|
|
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
|
kv_cache_buffer = KVCacheBuffer(params.batch_size, 1,
|
|
isCrossAttention() ? params.cross_qkv_length : params.max_seq_length, num_kv_heads * head_size * elem_size);
|
|
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
|
|
}
|
|
|
|
const auto quant_option = tc::QuantMode::fromDescription();
|
|
const float* qkv_scale_out = nullptr;
|
|
const float* attention_out_scale = nullptr;
|
|
|
|
const int* ia3_tasks = nullptr;
|
|
const T* ia3_key_weights = nullptr;
|
|
const T* ia3_value_weights = nullptr;
|
|
|
|
const bool multi_block_mode = false;
|
|
const int max_seq_len_tile = 0;
|
|
T* partial_out = nullptr;
|
|
float* partial_sum = nullptr;
|
|
float* partial_max = nullptr;
|
|
int* block_counter = nullptr;
|
|
|
|
auto cublasHandle = mCublasWrapper->getCublasHandle();
|
|
TLLM_CUDA_CHECK(cublasSetStream(cublasHandle, stream));
|
|
mCublasWrapper->setStream(stream);
|
|
mCublasWrapper->setWorkspace(params.workspace);
|
|
if constexpr (std::is_same_v<T, half>)
|
|
{
|
|
mCublasWrapper->setFP16GemmConfig();
|
|
}
|
|
else if constexpr (std::is_same_v<T, float>)
|
|
{
|
|
mCublasWrapper->setFP32GemmConfig();
|
|
}
|
|
#ifdef ENABLE_BF16
|
|
else if constexpr (std::is_same_v<T, __nv_bfloat16>)
|
|
{
|
|
mCublasWrapper->setBF16GemmConfig();
|
|
}
|
|
#endif
|
|
|
|
const size_t attention_mask_size = mEnableContextFMHA ? 0
|
|
: sizeof(T) * params.batch_size * params.input_seq_length
|
|
* (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
|
|
const size_t cu_seqlens_size = sizeof(int) * (params.batch_size + 1);
|
|
const size_t q_buf_2_size
|
|
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
|
|
const size_t k_buf_2_size = mEnableContextFMHA ? 0
|
|
: sizeof(T) * params.batch_size
|
|
* (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length) * local_hidden_units_kv;
|
|
const size_t v_buf_2_size = mEnableContextFMHA ? 0
|
|
: sizeof(T) * params.batch_size
|
|
* (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length) * local_hidden_units_kv;
|
|
const size_t qk_buf_size = mEnableContextFMHA ? 0
|
|
: sizeof(T) * params.batch_size * mNumHeads * params.input_seq_length
|
|
* (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
|
|
const size_t qkv_buf_2_size
|
|
= mEnableContextFMHA ? 0 : sizeof(T) * params.batch_size * params.input_seq_length * local_hidden_units_qo;
|
|
const size_t qk_buf_float_size = mEnableContextFMHA ? 0
|
|
: sizeof(float) * params.batch_size * mNumHeads
|
|
* params.input_seq_length * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
|
|
const size_t padding_offset_size = sizeof(int) * params.batch_size * params.input_seq_length;
|
|
|
|
const bool is_qk_buf_float_ = true;
|
|
|
|
// Workspace pointer shift
|
|
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
|
|
size_t offset = CUBLAS_WORKSPACE_SIZE;
|
|
|
|
T* attention_mask = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, attention_mask_size));
|
|
int* cu_seqlens = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, cu_seqlens_size));
|
|
T* q_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, q_buf_2_size));
|
|
T* k_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, k_buf_2_size));
|
|
T* v_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, v_buf_2_size));
|
|
T* qk_buf_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_size));
|
|
T* qkv_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, qkv_buf_2_size));
|
|
float* qk_buf_float_ = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_float_size));
|
|
int* padding_offset = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
|
|
|
|
// build attention_mask, cu_seqlens, and padding_offset tensors
|
|
BuildDecoderInfoParams<T> decoder_params;
|
|
memset(&decoder_params, 0, sizeof(decoder_params));
|
|
decoder_params.seqOffsets = cu_seqlens;
|
|
decoder_params.paddingOffsets = padding_offset;
|
|
decoder_params.attentionMask = attention_mask;
|
|
decoder_params.seqLengths = params.context_lengths;
|
|
decoder_params.batchSize = params.batch_size;
|
|
decoder_params.maxSeqLength = params.input_seq_length;
|
|
decoder_params.numTokens = params.num_tokens;
|
|
decoder_params.attentionMaskType = mMaskType;
|
|
invokeBuildDecoderInfo(decoder_params, stream);
|
|
sync_check_cuda_error();
|
|
|
|
// In cross attention context phase, the attention mask should be a matrix of all ones.
|
|
// We reassign attention_mask to override what previous invokeBuildDecoderInfo() does
|
|
if (isCrossAttention())
|
|
{
|
|
std::vector<T> h_attention_mask(params.batch_size * params.cross_qkv_length * params.input_seq_length, 1.);
|
|
cudaMemcpyAsync(attention_mask, h_attention_mask.data(),
|
|
sizeof(T) * params.batch_size * params.cross_qkv_length * params.input_seq_length, cudaMemcpyHostToDevice,
|
|
stream);
|
|
}
|
|
|
|
// write KV to cache
|
|
const KvCacheDataType cache_type = mKVCacheQuantMode.hasInt8KvCache()
|
|
? KvCacheDataType::INT8
|
|
: (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
|
|
|
|
const cudaDataType_t gemm_data_type = tc::CudaDataType<T>::value;
|
|
const int attention_seq_len_1 = params.input_seq_length; // q length
|
|
const int attention_seq_len_2 = isCrossAttention() ? params.cross_qkv_length : params.input_seq_length; // kv length
|
|
|
|
// If the model has relative attention bias, q scaling should be applied in QK gemm stage and use 1 in
|
|
// softmax stage (because to get softmax[scale(Q*K) + rel pos bias] here, q_scaling can't be applied during
|
|
// softmax phase by qk_scale); otherwise, use 1 in gemm stage and apply scaling in softmax stage
|
|
const float qk_scale
|
|
= 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling); // q_scaling in denominator. by default q_scaling =1.0f
|
|
const float qk_scale_gemm = isRelativePosition() ? qk_scale : 1.0f;
|
|
const T qk_scale_softmax = static_cast<T>(isRelativePosition() ? 1.0f : qk_scale);
|
|
|
|
// in context phase, currently FMHA runner has two restrictions:
|
|
// 1. only apply to self attention. If want fused multi-head cross attention, FMHCA kernels and runner is needed
|
|
// 2. doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
|
|
if (mEnableContextFMHA && !isCrossAttention() && !isRelativePosition())
|
|
{
|
|
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), kv_cache_buffer,
|
|
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
|
params.batch_size, params.input_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
|
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
|
|
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, cache_type,
|
|
params.kv_scale_orig_quant, stream);
|
|
mFMHARunner->setup(params.batch_size, params.input_seq_length, params.num_tokens, isALiBi(), isAliBiWithScale(),
|
|
mTpSize, mTpRank);
|
|
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_seqlens, params.context_buf, stream);
|
|
}
|
|
else
|
|
{
|
|
// FIXME: a temporary solution to make sure the padding part of key/value buffer is 0
|
|
// NOTE: pointer subtraction is used below since there could be some extra gap due to alignment.
|
|
// Otherwise, we could do cudaMemsetAsync(k_buf_2_, 0, k_buf_2_size + v_buf_2_size, stream);
|
|
cudaMemsetAsync(k_buf_2_, 0,
|
|
reinterpret_cast<int8_t*>(v_buf_2_) - reinterpret_cast<int8_t*>(k_buf_2_) + v_buf_2_size, stream);
|
|
|
|
if (!isCrossAttention())
|
|
{
|
|
// self attention, write to Q/K/V
|
|
invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, const_cast<T*>(params.attention_input),
|
|
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
|
params.batch_size, params.input_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
|
mEnableContextFMHA, mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType,
|
|
mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0,
|
|
stream);
|
|
}
|
|
else
|
|
{
|
|
// cross attention, write Q from self QKV, write KV from cross QKV
|
|
// kernel modified accordingly to handle nullptr buffer
|
|
invokeAddFusedQKVBiasTranspose(q_buf_2_, (T*) nullptr, (T*) nullptr, const_cast<T*>(params.attention_input),
|
|
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
|
params.batch_size, params.input_seq_length, params.num_tokens, mNumHeads, mNumKVHeads, getHeadSize(),
|
|
mEnableContextFMHA, mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType,
|
|
mRotaryEmbeddingScale, mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0,
|
|
stream);
|
|
invokeAddFusedQKVBiasTranspose((T*) nullptr, k_buf_2_, v_buf_2_, const_cast<T*>(params.cross_qkv),
|
|
const_cast<T*>(params.qkv_bias), params.encoder_input_lengths,
|
|
mRemovePadding ? padding_offset : nullptr, params.batch_size, params.cross_qkv_length,
|
|
params.num_encoder_tokens, mNumHeads, mNumKVHeads, getHeadSize(), mEnableContextFMHA,
|
|
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
|
|
mRotaryEmbeddingMaxPositions, position_embedding_type, (float*) nullptr, 0, stream);
|
|
}
|
|
sync_check_cuda_error();
|
|
|
|
// write KV to cache
|
|
invokeTranspose4dBatchMajor(k_buf_2_, v_buf_2_, kv_cache_buffer, params.batch_size,
|
|
isCrossAttention() ? params.cross_qkv_length : params.input_seq_length,
|
|
isCrossAttention() ? params.cross_qkv_length : params.max_seq_length, getHeadSize(), mNumKVHeads,
|
|
cache_type, params.kv_scale_orig_quant,
|
|
isCrossAttention() ? params.encoder_input_lengths : params.context_lengths, stream);
|
|
sync_check_cuda_error();
|
|
|
|
const T* linear_bias_slopes = isALiBi() ? params.alibi_slopes : nullptr;
|
|
const T* relative_attention_bias = isRelativePosition() ? params.relative_attention_bias : nullptr;
|
|
const int relative_attention_bias_stride = isRelativePosition() ? params.relative_attention_bias_stride : 0;
|
|
const int max_distance = mMaxDistance;
|
|
cudaDataType_t gemm_out_data_type = is_qk_buf_float_ ? CUDA_R_32F : gemm_data_type;
|
|
void* gemm_out_buf_ = is_qk_buf_float_ ? static_cast<void*>(qk_buf_float_) : static_cast<void*>(qk_buf_);
|
|
if (mNumKVHeads == 1) // MQA
|
|
{
|
|
// Attn_weight[b, h*s_q, s_k] = Q[b, h*s_q, d] * K'[b, d, s_k]
|
|
// Attn_weight'[b, s_k, h*s_q] = K[b, s_k, d] * Q'[b, d, h*s_q]
|
|
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N,
|
|
attention_seq_len_2, // n
|
|
attention_seq_len_1 * mNumHeads, // m
|
|
getHeadSize(), // k
|
|
qk_scale_gemm, k_buf_2_, gemm_data_type,
|
|
getHeadSize(), // k
|
|
attention_seq_len_2 * getHeadSize(), // n * k
|
|
q_buf_2_, gemm_data_type,
|
|
getHeadSize(), // k
|
|
attention_seq_len_1 * mNumHeads * getHeadSize(), // m * k
|
|
0.0f, gemm_out_buf_, gemm_out_data_type,
|
|
attention_seq_len_2, // n
|
|
attention_seq_len_1 * mNumHeads * attention_seq_len_2, // m * n
|
|
params.batch_size, // global batch size
|
|
CUDA_R_32F);
|
|
}
|
|
else if (mNumKVHeads == mNumHeads) // MHA
|
|
{
|
|
// Attn_weight[b*h, s_q, s_k] = Q[b*h, s_q, d] * K'[b*h, d, s_k]
|
|
// Attn_weight'[b*h, s_k, s_q] = K[b*h, s_k, d] * Q'[b*h, d, s_q]
|
|
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N,
|
|
attention_seq_len_2, // n
|
|
attention_seq_len_1, // m
|
|
getHeadSize(), // k
|
|
qk_scale_gemm, k_buf_2_, gemm_data_type,
|
|
getHeadSize(), // k
|
|
attention_seq_len_2 * getHeadSize(), // n * k
|
|
q_buf_2_, gemm_data_type,
|
|
getHeadSize(), // k
|
|
attention_seq_len_1 * getHeadSize(), // m * k
|
|
0.0f, gemm_out_buf_, gemm_out_data_type,
|
|
attention_seq_len_2, // n
|
|
attention_seq_len_2 * attention_seq_len_1,
|
|
params.batch_size * mNumHeads, // global batch size
|
|
CUDA_R_32F);
|
|
}
|
|
else // GQA
|
|
{
|
|
// Some number of contiguous Q heads will share the same K/V head
|
|
// Since the KV stride is NOT fixed for all Q, we have 2 options:
|
|
// 1. Loop over stridedBatchedGemm for each KV head. (multiple API calls/cuda kernels)
|
|
// 2. Calculate the pointers and use batchedGemm() (extra device memory) ::TODO::
|
|
const int num_qheads_per_kv_head = mNumHeads / mNumKVHeads;
|
|
for (int ki = 0; ki < mNumKVHeads; ++ki)
|
|
{
|
|
T* qptr = q_buf_2_ + (ki * num_qheads_per_kv_head * attention_seq_len_1 * getHeadSize());
|
|
T* kptr = k_buf_2_ + (ki * attention_seq_len_2 * getHeadSize());
|
|
const int qk_offset = ki * attention_seq_len_1 * num_qheads_per_kv_head * attention_seq_len_2;
|
|
void* qkptr = is_qk_buf_float_ ? static_cast<void*>(qk_buf_float_ + qk_offset)
|
|
: static_cast<void*>(qk_buf_ + qk_offset);
|
|
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N,
|
|
attention_seq_len_2, // n
|
|
attention_seq_len_1 * num_qheads_per_kv_head, // m
|
|
getHeadSize(), // k
|
|
qk_scale_gemm, kptr, gemm_data_type,
|
|
getHeadSize(), // k
|
|
mNumKVHeads * attention_seq_len_2 * getHeadSize(), // n * k
|
|
qptr, gemm_data_type,
|
|
getHeadSize(), // k
|
|
attention_seq_len_1 * mNumHeads * getHeadSize(), // m * k
|
|
0.0f, qkptr, gemm_out_data_type,
|
|
attention_seq_len_2, // n
|
|
attention_seq_len_1 * mNumHeads * attention_seq_len_2, // m * n
|
|
params.batch_size, // global batch size
|
|
CUDA_R_32F);
|
|
}
|
|
}
|
|
|
|
// add relative position bias
|
|
if (isRelativePosition())
|
|
{
|
|
// Add relative_attention_bias
|
|
// QK is (batch_size, local_head_num, q_length, k_length), relative_attention_bias is (1, local_head_num,
|
|
// max_output_len + 1, max_output_len + 1).
|
|
// broadcast along 1st dim. max_seq_len is already max_output_len + 1.
|
|
// In implicit mode, relative_attention_bias is relative_attention_table [num_heads, num_buckets], with
|
|
// necessary params (max_distance, num_buckets) passed at the end
|
|
invokeAddRelativeAttentionBiasUnaligned(qk_buf_float_, relative_attention_bias, params.batch_size,
|
|
mNumHeads, attention_seq_len_1, isCrossAttention() ? params.cross_qkv_length : params.max_seq_length,
|
|
stream, max_distance > 0, relative_attention_bias_stride, max_distance, true /* bidirectional */);
|
|
}
|
|
|
|
if (is_qk_buf_float_ == true)
|
|
{
|
|
MaskedSoftmaxParam<T, float> param;
|
|
param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length)
|
|
param.qk = qk_buf_float_; // (batch_size, head_num, q_length, k_length)
|
|
param.attention_mask = attention_mask; // (batch_size, q_length, k_length)
|
|
param.batch_size = params.batch_size;
|
|
param.q_length = attention_seq_len_1;
|
|
param.k_length = attention_seq_len_2;
|
|
param.num_heads = mNumHeads;
|
|
param.qk_scale = qk_scale_softmax;
|
|
param.linear_bias_slopes = const_cast<T*>(linear_bias_slopes); // (head_num,), optional
|
|
invokeMaskedSoftmax(param, stream);
|
|
}
|
|
else
|
|
{
|
|
MaskedSoftmaxParam<T, T> param;
|
|
param.attention_score = qk_buf_; // (batch_size, head_num, q_length, k_length)
|
|
param.qk = qk_buf_; // (batch_size, head_num, q_length, k_length)
|
|
param.attention_mask = attention_mask; // (batch_size, q_length, k_length)
|
|
param.batch_size = params.batch_size;
|
|
param.q_length = attention_seq_len_1;
|
|
param.k_length = attention_seq_len_2;
|
|
param.num_heads = mNumHeads;
|
|
param.qk_scale = qk_scale_softmax;
|
|
param.linear_bias_slopes = const_cast<T*>(linear_bias_slopes); // (head_num,), optional
|
|
invokeMaskedSoftmax(param, stream);
|
|
}
|
|
|
|
if (mNumKVHeads == 1)
|
|
{
|
|
// Attn_weight[b, h*s_q, s_k]
|
|
// O[b, h*s_q, d] = Attn_weight[b, h*s_q, s_k] * V[b, s_k, d]
|
|
// O'[b, d, h*s_q] = V'[b, d, s_k] * Attn_weight'[b, s_k, h*s_q]
|
|
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N,
|
|
getHeadSize(), // n
|
|
mNumHeads * attention_seq_len_1, // m
|
|
attention_seq_len_2, // k
|
|
v_buf_2_,
|
|
getHeadSize(), // n
|
|
getHeadSize() * attention_seq_len_2, // n * k
|
|
qk_buf_,
|
|
attention_seq_len_2, // k
|
|
attention_seq_len_2 * mNumHeads * attention_seq_len_1, // m * k
|
|
qkv_buf_2_,
|
|
getHeadSize(), // n
|
|
getHeadSize() * mNumHeads * attention_seq_len_1, // n * m
|
|
params.batch_size // global batch size
|
|
);
|
|
}
|
|
else if (mNumKVHeads == mNumHeads) // MHA
|
|
{
|
|
// O[b*h, s_q, d] = Attn_weight[b*h, s_q, s_k] * V[b*h, s_k, d]
|
|
// O'[b*h, d, s_q] = V'[b*h, d, s_k] * Attn_weight'[b*h, s_k, s_q]
|
|
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N, getHeadSize(), attention_seq_len_1,
|
|
attention_seq_len_2, v_buf_2_, getHeadSize(), attention_seq_len_2 * getHeadSize(), qk_buf_,
|
|
attention_seq_len_2, attention_seq_len_1 * attention_seq_len_2, qkv_buf_2_, getHeadSize(),
|
|
attention_seq_len_1 * getHeadSize(), params.batch_size * mNumHeads);
|
|
}
|
|
else // GQA
|
|
{
|
|
// Attn_weight[b, h*s_q, s_k]
|
|
// O[b, h*s_q, d] = Attn_weight[b, h*s_q, s_k] * V[b, s_k, d]
|
|
// O'[b, d, h*s_q] = V'[b, d, s_k] * Attn_weight'[b, s_k, h*s_q]
|
|
const int num_qheads_per_kv_head = mNumHeads / mNumKVHeads;
|
|
for (int ki = 0; ki < mNumKVHeads; ++ki)
|
|
{
|
|
T* qkptr = qk_buf_ + (ki * num_qheads_per_kv_head * attention_seq_len_1 * attention_seq_len_2);
|
|
T* vptr = v_buf_2_ + (ki * attention_seq_len_2 * getHeadSize());
|
|
T* qkvptr = qkv_buf_2_ + (ki * attention_seq_len_1 * num_qheads_per_kv_head * getHeadSize());
|
|
mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_N, CUBLAS_OP_N,
|
|
getHeadSize(), // n
|
|
num_qheads_per_kv_head * attention_seq_len_1, // m
|
|
attention_seq_len_2, // k
|
|
vptr,
|
|
getHeadSize(), // n
|
|
mNumKVHeads * getHeadSize() * attention_seq_len_2, // n * k
|
|
qkptr,
|
|
attention_seq_len_2, // k
|
|
attention_seq_len_2 * mNumHeads * attention_seq_len_1, // m * k
|
|
qkvptr,
|
|
getHeadSize(), // n
|
|
getHeadSize() * mNumHeads * attention_seq_len_1, // n * m
|
|
params.batch_size // global batch size
|
|
);
|
|
}
|
|
}
|
|
|
|
if (!mRemovePadding)
|
|
{
|
|
invokeTransposeQKV(params.context_buf, qkv_buf_2_, params.batch_size, attention_seq_len_1, mNumHeads,
|
|
getHeadSize(), (float*) nullptr, 0, stream);
|
|
}
|
|
else
|
|
{
|
|
invokeTransposeAttentionOutRemovePadding(qkv_buf_2_, params.context_buf, params.num_tokens,
|
|
params.batch_size, attention_seq_len_1, mNumHeads, getHeadSize(), padding_offset, (float*) nullptr, 0,
|
|
stream);
|
|
}
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Explicit instantiations of enqueueContext for every supported (data type, KV cache layout) pair.
// The __nv_bfloat16 variants are only emitted when ENABLE_BF16 is defined.
template int GPTAttentionPluginCommon::enqueueContext<half, KVLinearBuffer>(
    EnqueueContextParams<half, KVLinearBuffer> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueContext<float, KVLinearBuffer>(
    EnqueueContextParams<float, KVLinearBuffer> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueContext<half, KVBlockArray>(
    EnqueueContextParams<half, KVBlockArray> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueContext<float, KVBlockArray>(
    EnqueueContextParams<float, KVBlockArray> const& params, cudaStream_t stream);

#ifdef ENABLE_BF16
template int GPTAttentionPluginCommon::enqueueContext<__nv_bfloat16, KVLinearBuffer>(
    EnqueueContextParams<__nv_bfloat16, KVLinearBuffer> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueContext<__nv_bfloat16, KVBlockArray>(
    EnqueueContextParams<__nv_bfloat16, KVBlockArray> const& params, cudaStream_t stream);
#endif
|
|
|
|
template <typename T, typename KVCacheBuffer>
|
|
int GPTAttentionPluginCommon::enqueueGeneration(
|
|
const EnqueueGenerationParams<T, KVCacheBuffer>& params, cudaStream_t stream)
|
|
{
|
|
const int step = params.past_kv_length + 1;
|
|
|
|
const int num_heads = mNumHeads;
|
|
const int num_kv_heads = mNumKVHeads;
|
|
const int head_size = getHeadSize();
|
|
const int local_hidden_units_qo = num_heads * head_size;
|
|
const int local_hidden_units_kv = num_kv_heads * head_size;
|
|
const PositionEmbeddingType position_embedding_type = mPositionEmbeddingType;
|
|
const float q_scaling = mQScaling;
|
|
const T* relative_attention_bias = isRelativePosition() ? params.relative_attention_bias : nullptr;
|
|
const int relative_attention_bias_stride = isRelativePosition() ? params.relative_attention_bias_stride : 0;
|
|
const int max_distance = mMaxDistance;
|
|
const bool* finished = nullptr;
|
|
const bool has_ia3 = false;
|
|
|
|
const auto quant_option = tc::QuantMode::fromDescription();
|
|
const float* qkv_scale_out = nullptr;
|
|
const float* attention_out_scale = nullptr;
|
|
|
|
const int* ia3_tasks = nullptr;
|
|
const T* ia3_key_weights = nullptr;
|
|
const T* ia3_value_weights = nullptr;
|
|
|
|
const bool multi_block_mode = false;
|
|
int32_t const batch_beam = params.beam_width * params.num_requests;
|
|
|
|
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(params.workspace);
|
|
size_t offset = 0;
|
|
int32_t const maxSeqLenTile = getMaxSeqLenTile(sizeof(T));
|
|
const size_t partial_out_size
|
|
= mMultiBlockMode ? sizeof(T) * batch_beam * mNumHeads * mHeadSize * maxSeqLenTile : 0;
|
|
const size_t partial_sum_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
|
|
const size_t partial_max_size = mMultiBlockMode ? sizeof(float) * batch_beam * mNumHeads * maxSeqLenTile : 0;
|
|
const size_t block_counter_size = mMultiBlockMode ? sizeof(int) * batch_beam * mNumHeads : 0;
|
|
|
|
// Workspace pointer shift
|
|
T* partial_out = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, partial_out_size));
|
|
float* partial_sum = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, partial_sum_size));
|
|
float* partial_max = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, partial_max_size));
|
|
int* block_counter = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, block_counter_size));
|
|
if (mMultiBlockMode)
|
|
{
|
|
TLLM_CUDA_CHECK(cudaMemsetAsync(block_counter, 0, block_counter_size, stream));
|
|
}
|
|
|
|
KVCacheBuffer kv_cache_buffer;
|
|
const auto elem_size = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
|
|
if (mPagedKVCache)
|
|
{
|
|
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
|
kv_cache_buffer = KVCacheBuffer(
|
|
batch_beam, params.max_blocks_per_sequence, mTokensPerBlock, num_kv_heads * head_size * elem_size);
|
|
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.block_pointers);
|
|
}
|
|
else
|
|
{
|
|
using BufferDataType = typename KVCacheBufferDataType<KVCacheBuffer>::Type;
|
|
kv_cache_buffer = KVCacheBuffer(batch_beam, 1, params.max_seq_length, num_kv_heads * head_size * elem_size);
|
|
kv_cache_buffer.data = reinterpret_cast<BufferDataType*>(params.key_value_cache);
|
|
}
|
|
sync_check_cuda_error();
|
|
|
|
FusedQKVMaskedAttentionDispatchParams<T, KVCacheBuffer> dispatch_params;
|
|
memset(&dispatch_params, 0, sizeof(dispatch_params));
|
|
dispatch_params.qkv_buf = params.attention_input;
|
|
dispatch_params.qkv_bias = params.qkv_bias;
|
|
dispatch_params.relative_attention_bias = relative_attention_bias;
|
|
dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride;
|
|
dispatch_params.max_distance = max_distance;
|
|
dispatch_params.cache_indir = params.cache_indir;
|
|
dispatch_params.context_buf = params.context_buf;
|
|
dispatch_params.finished = finished;
|
|
dispatch_params.sequence_lengths
|
|
= params.sequence_lengths; // NOTE: current seq len including padding (fixed after meeting the finished id)
|
|
dispatch_params.max_batch_size = batch_beam;
|
|
dispatch_params.inference_batch_size = batch_beam;
|
|
dispatch_params.beam_width = params.beam_width;
|
|
dispatch_params.head_num = mNumHeads;
|
|
dispatch_params.kv_head_num = mNumKVHeads;
|
|
dispatch_params.size_per_head = getHeadSize();
|
|
dispatch_params.rotary_embedding_dim = mRotaryEmbeddingDim;
|
|
dispatch_params.position_embedding_type = mPositionEmbeddingType;
|
|
dispatch_params.max_seq_len = params.max_seq_length; // difference between max_seq_lengths and max_seq_length?
|
|
dispatch_params.input_lengths = params.context_lengths;
|
|
dispatch_params.step = step;
|
|
dispatch_params.q_scaling = q_scaling;
|
|
dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride;
|
|
dispatch_params.linear_bias_slopes = isALiBi() ? params.alibi_slopes : nullptr;
|
|
dispatch_params.ia3_tasks = ia3_tasks;
|
|
dispatch_params.ia3_key_weights = ia3_key_weights;
|
|
dispatch_params.ia3_value_weights = ia3_value_weights;
|
|
dispatch_params.qkv_scale_out = qkv_scale_out;
|
|
dispatch_params.attention_out_scale = attention_out_scale;
|
|
dispatch_params.quant_option = quant_option;
|
|
dispatch_params.multi_block_mode = mMultiBlockMode;
|
|
dispatch_params.max_seq_len_tile = getMaxSeqLenTile(sizeof(T));
|
|
dispatch_params.partial_out = partial_out;
|
|
dispatch_params.partial_sum = partial_sum;
|
|
dispatch_params.partial_max = partial_max;
|
|
dispatch_params.block_counter = block_counter;
|
|
dispatch_params.kv_cache_quant_mode = mKVCacheQuantMode;
|
|
dispatch_params.kv_scale_orig_quant = params.kv_scale_orig_quant;
|
|
dispatch_params.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
|
dispatch_params.kv_block_array = kv_cache_buffer;
|
|
dispatch_params.multi_processor_count = mMultiProcessorCount;
|
|
dispatch_params.rotary_embedding_base = mRotaryEmbeddingBase;
|
|
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
|
|
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
|
|
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;
|
|
dispatch_params.cross_attention = mCrossAttention;
|
|
dispatch_params.memory_length_per_sample = params.encoder_input_lengths;
|
|
|
|
using DataType = typename SATypeConverter<T>::Type;
|
|
if (!mCrossAttention)
|
|
{
|
|
// self attn
|
|
Masked_multihead_attention_params<DataType> mmha_params;
|
|
fusedQKV_masked_attention_dispatch(mmha_params, dispatch_params, stream);
|
|
}
|
|
else
|
|
{
|
|
// cross attn
|
|
Cross_multihead_attention_params<DataType> mmhca_params;
|
|
fusedQKV_masked_attention_dispatch(mmhca_params, dispatch_params, stream);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Explicit instantiations of enqueueGeneration for every supported (data type, KV cache layout) pair.
// The __nv_bfloat16 variants are only emitted when ENABLE_BF16 is defined.
template int GPTAttentionPluginCommon::enqueueGeneration<half, KVLinearBuffer>(
    EnqueueGenerationParams<half, KVLinearBuffer> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueGeneration<float, KVLinearBuffer>(
    EnqueueGenerationParams<float, KVLinearBuffer> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueGeneration<half, KVBlockArray>(
    EnqueueGenerationParams<half, KVBlockArray> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueGeneration<float, KVBlockArray>(
    EnqueueGenerationParams<float, KVBlockArray> const& params, cudaStream_t stream);

#ifdef ENABLE_BF16
template int GPTAttentionPluginCommon::enqueueGeneration<__nv_bfloat16, KVLinearBuffer>(
    EnqueueGenerationParams<__nv_bfloat16, KVLinearBuffer> const& params, cudaStream_t stream);

template int GPTAttentionPluginCommon::enqueueGeneration<__nv_bfloat16, KVBlockArray>(
    EnqueueGenerationParams<__nv_bfloat16, KVBlockArray> const& params, cudaStream_t stream);
#endif
|
|
|
|
int GPTAttentionPluginCommon::initialize() noexcept
|
|
{
|
|
auto cublasHandle = getCublasHandle();
|
|
auto cublasLtHandle = getCublasLtHandle();
|
|
|
|
mCublasWrapper.reset(new tc::CublasMMWrapper(cublasHandle, cublasLtHandle, nullptr, nullptr));
|
|
if (mEnableContextFMHA)
|
|
{
|
|
// Pre-checked during constructing.
|
|
Data_type data_type;
|
|
if (mType == DataType::kHALF)
|
|
{
|
|
data_type = DATA_TYPE_FP16;
|
|
}
|
|
else if (mType == DataType::kBF16)
|
|
{
|
|
data_type = DATA_TYPE_BF16;
|
|
}
|
|
else
|
|
{
|
|
TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type.");
|
|
}
|
|
|
|
mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling));
|
|
// set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads.
|
|
mFMHARunner->setup_flags(mFMHAForceFP32Acc, !mRemovePadding, true, mNumKVHeads);
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
void GPTAttentionPluginCommon::destroy() noexcept
|
|
{
|
|
delete this;
|
|
}
|
|
|
|
size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
|
|
{
|
|
return sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
|
|
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
|
|
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMaxPositions)
|
|
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc)
|
|
+ sizeof(mMultiBlockMode) + sizeof(unsigned int) // mKVCacheQuantMode
|
|
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
|
|
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance);
|
|
}
|
|
|
|
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
|
{
|
|
char *d = static_cast<char*>(buffer), *a = d;
|
|
write(d, mNumHeads);
|
|
write(d, mNumKVHeads);
|
|
write(d, mHeadSize);
|
|
write(d, mUnidirectional);
|
|
write(d, mQScaling);
|
|
write(d, mPositionEmbeddingType);
|
|
write(d, mRotaryEmbeddingDim);
|
|
write(d, mRotaryEmbeddingBase);
|
|
write(d, mRotaryEmbeddingScaleType);
|
|
write(d, mRotaryEmbeddingScale);
|
|
write(d, mRotaryEmbeddingMaxPositions);
|
|
write(d, mTpSize);
|
|
write(d, mTpRank);
|
|
write(d, mEnableContextFMHA);
|
|
write(d, mFMHAForceFP32Acc);
|
|
write(d, mMultiBlockMode);
|
|
write(d, mKVCacheQuantMode.value());
|
|
write(d, mRemovePadding);
|
|
write(d, mMaskType);
|
|
write(d, mPagedKVCache);
|
|
write(d, mTokensPerBlock);
|
|
write(d, mType);
|
|
write(d, mMaxContextLength);
|
|
write(d, mQKVBiasEnabled);
|
|
write(d, mCrossAttention);
|
|
write(d, mMaxDistance);
|
|
assert(d == a + getCommonSerializationSize());
|
|
}
|
|
|
|
//! \brief Intentionally a no-op.
void GPTAttentionPluginCommon::terminate() noexcept
{
    // Nothing is released here: destroy will always be called, so resources are freed there instead.
}
|
|
|
|
///////////////
|
|
|
|
//! \brief Register the metadata (name, type, length) of every plugin field accepted by the creator.
//!
//! Only metadata is stored: each field's data pointer is left null. The resulting list is exposed
//! through mFC.
GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
{
    // One row per plugin attribute, in registration order.
    struct FieldSpec
    {
        char const* name;
        PluginFieldType type;
        int32_t length;
    };

    static const FieldSpec kFieldSpecs[] = {
        {"num_heads", PluginFieldType::kINT32, -1},
        {"num_kv_heads", PluginFieldType::kINT32, 0},
        {"head_size", PluginFieldType::kINT32, 0},
        {"unidirectional", PluginFieldType::kINT32, 1},
        {"q_scaling", PluginFieldType::kFLOAT32, 1},
        {"position_embedding_type", PluginFieldType::kINT8, 0},
        {"rotary_embedding_dim", PluginFieldType::kINT32, 0},
        {"rotary_embedding_base", PluginFieldType::kFLOAT32, 0},
        {"rotary_embedding_scale_type", PluginFieldType::kINT8, 0},
        {"rotary_embedding_scale", PluginFieldType::kFLOAT32, 0},
        {"rotary_embedding_max_positions", PluginFieldType::kINT32, 0},
        {"tp_size", PluginFieldType::kINT32, 0},
        {"tp_rank", PluginFieldType::kINT32, 0},
        {"context_fmha_type", PluginFieldType::kINT8, 0},
        {"multi_block_mode", PluginFieldType::kINT8, 0},
        {"kv_cache_quant_mode", PluginFieldType::kINT32, 0},
        {"remove_input_padding", PluginFieldType::kINT8, 0},
        {"mask_type", PluginFieldType::kINT32, 0},
        {"paged_kv_cache", PluginFieldType::kINT32, 0},
        {"tokens_per_block", PluginFieldType::kINT32, 0},
        {"type_id", PluginFieldType::kINT32, 1},
        {"max_context_length", PluginFieldType::kINT32, 1},
        {"qkv_bias_enabled", PluginFieldType::kINT8, 0},
        {"do_cross_attention", PluginFieldType::kINT8, 0},
        {"max_distance", PluginFieldType::kINT32, 0},
    };

    // Fill PluginFieldCollection with PluginField arguments metadata
    mPluginAttributes.clear();
    for (FieldSpec const& spec : kFieldSpecs)
    {
        mPluginAttributes.emplace_back(PluginField(spec.name, nullptr, spec.type, spec.length));
    }

    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
|
|
|
|
//! \brief Access the field collection populated by the constructor.
//! \return Pointer to the creator's PluginFieldCollection (mFC).
PluginFieldCollection const* GPTAttentionPluginCreatorCommon::getFieldNames() noexcept
{
    return &mFC;
}
|