mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: Tltin <TltinDeng01@gmail.com> Co-authored-by: zhaohb <zhaohbcloud@126.com> Co-authored-by: Bradley Heilbrun <brad@repl.it> Co-authored-by: nqbao11 <nqbao11.01@gmail.com> Co-authored-by: Nikhil Varghese <nikhil@bot-it.ai>
122 lines
4.7 KiB
C++
122 lines
4.7 KiB
C++
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#pragma once
|
|
|
|
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
|
#include "tensorrt_llm/common/quantization.h"
|
|
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h"
|
|
#include "tensorrt_llm/kernels/gptKernels.h"
|
|
#include "tensorrt_llm/plugins/common/plugin.h"
|
|
#include <cassert>
|
|
#include <set>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace tensorrt_llm::plugins
|
|
{
|
|
|
|
class BertAttentionPlugin : public BasePlugin
|
|
{
|
|
public:
|
|
BertAttentionPlugin() = delete;
|
|
|
|
BertAttentionPlugin(int num_heads, int head_size, float q_scaling, bool qk_half_accum,
|
|
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, nvinfer1::DataType type,
|
|
bool do_relative_attention = false, int max_distance = 0, bool remove_padding = false);
|
|
|
|
BertAttentionPlugin(const void* data, size_t length);
|
|
|
|
~BertAttentionPlugin() override = default;
|
|
|
|
// IPluginV2DynamicExt Methods
|
|
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
|
|
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
|
|
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
|
|
bool supportsFormatCombination(
|
|
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;
|
|
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
|
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;
|
|
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
|
|
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override;
|
|
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
|
|
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
|
|
|
|
template <typename T>
|
|
int enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
|
|
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream);
|
|
|
|
// IPluginV2Ext Methods
|
|
nvinfer1::DataType getOutputDataType(
|
|
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override;
|
|
|
|
// IPluginV2 Methods
|
|
const char* getPluginType() const noexcept override;
|
|
const char* getPluginVersion() const noexcept override;
|
|
int getNbOutputs() const noexcept override;
|
|
int initialize() noexcept override;
|
|
void terminate() noexcept override;
|
|
size_t getSerializationSize() const noexcept override;
|
|
void serialize(void* buffer) const noexcept override;
|
|
void destroy() noexcept override;
|
|
|
|
private:
|
|
const std::string mLayerName;
|
|
|
|
int mNumHeads;
|
|
int mHeadSize;
|
|
float mQScaling;
|
|
nvinfer1::DataType mType;
|
|
bool mRelativeAttention = false;
|
|
int mMaxDistance = 0;
|
|
bool mRemovePadding = false;
|
|
|
|
// unfused mha
|
|
bool mQKHalfAccum = false;
|
|
|
|
// fmha runner (disable by default)
|
|
bool mEnableContextFMHA = false;
|
|
bool mFMHAForceFP32Acc = false;
|
|
bool mSM = tensorrt_llm::common::getSMVersion();
|
|
|
|
// The default copy constructor will leave them as nullptr. clone() shall initialize it.
|
|
UniqPtrWNullCopy<tensorrt_llm::kernels::FusedMHARunnerV2> mFMHARunner;
|
|
UniqPtrWNullCopy<tensorrt_llm::common::CublasMMWrapper> mCublasWrapper;
|
|
};
|
|
|
|
class BertAttentionPluginCreator : public BaseCreator
|
|
{
|
|
public:
|
|
BertAttentionPluginCreator();
|
|
|
|
const char* getPluginName() const noexcept override;
|
|
|
|
const char* getPluginVersion() const noexcept override;
|
|
|
|
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
|
|
|
|
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override;
|
|
|
|
nvinfer1::IPluginV2* deserializePlugin(
|
|
const char* name, const void* serialData, size_t serialLength) noexcept override;
|
|
|
|
private:
|
|
static nvinfer1::PluginFieldCollection mFC;
|
|
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
|
};
|
|
|
|
} // namespace tensorrt_llm::plugins
|