TensorRT-LLM/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h"
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
#include "tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h"
#include "tensorrt_llm/plugins/identityPlugin/identityPlugin.h"
#include "tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h"
#include "tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h"
#include "tensorrt_llm/plugins/loraPlugin/loraPlugin.h"
#include "tensorrt_llm/plugins/lruPlugin/lruPlugin.h"
#include "tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.h"
#include "tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h"
#if ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/recvPlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
#endif // ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
#include "tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h"
#include "tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h"
#include "tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.h"
#include "tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h"
#include "tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h"
#include "tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h"
#include "tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h"
#include <array>
#include <cstdlib>
#include <NvInferRuntime.h>
namespace tc = tensorrt_llm::common;
namespace
{
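// Helper used to build the creator table below from references to the static creator objects.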
nvinfer1::IPluginCreator* creatorPtr(nvinfer1::IPluginCreator& creator)
{
    return &creator;
}

auto tllmLogger = tensorrt_llm::runtime::TllmLogger();
nvinfer1::ILogger* gLogger{&tllmLogger};
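
// ILoggerFinder implementation that always resolves to the process-wide gLogger above.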
class GlobalLoggerFinder : public nvinfer1::ILoggerFinder
{
public:
    nvinfer1::ILogger* findLogger() override
    {
        return gLogger;
    }
};
GlobalLoggerFinder gGlobalLoggerFinder{};
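
// On non-MSVC builds this constructor function runs when the library is loaded;
// plugin creators are registered automatically only if the environment variable
// TRT_LLM_LOAD_PLUGINS is set to '1'.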
#if !defined(_MSC_VER)
[[maybe_unused]] __attribute__((constructor))
#endif
void initOnLoad()
{
    auto constexpr kLoadPlugins = "TRT_LLM_LOAD_PLUGINS";
    auto const loadPlugins = std::getenv(kLoadPlugins);
    if (loadPlugins && loadPlugins[0] == '1')
    {
        initTrtLlmPlugins(gLogger);
    }
}
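
// Guards initTrtLlmPlugins() against registering the creators more than once.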
bool pluginsInitialized = false;
} // namespace
namespace tensorrt_llm::plugins::api
{
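
// LoggerManager keeps the first non-null ILoggerFinder passed to setLoggerFinder();
// later calls are ignored. logger() returns whatever that finder resolves, or nullptr
// if no finder has been installed yet.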
LoggerManager& tensorrt_llm::plugins::api::LoggerManager::getInstance() noexcept
{
    static LoggerManager instance;
    return instance;
}

void LoggerManager::setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
    std::lock_guard<std::mutex> lk(mMutex);
    if (mLoggerFinder == nullptr && finder != nullptr)
    {
        mLoggerFinder = finder;
    }
}

[[maybe_unused]] nvinfer1::ILogger* LoggerManager::logger()
{
    std::lock_guard<std::mutex> lk(mMutex);
    if (mLoggerFinder != nullptr)
    {
        return mLoggerFinder->findLogger();
    }
    return nullptr;
}

nvinfer1::ILogger* LoggerManager::defaultLogger() noexcept
{
    return gLogger;
}
} // namespace tensorrt_llm::plugins::api
// New Plugin APIs
extern "C"
{
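
// Registers every TensorRT-LLM plugin creator with the global TensorRT plugin registry
// under the given namespace and routes plugin logging through `logger`.
// A minimal usage sketch (hypothetical caller code, not part of this file):
//
//     nvinfer1::ILogger* myLogger = ...; // any ILogger implementation
//     initTrtLlmPlugins(myLogger, "tensorrt_llm");
//
// Passing nullptr for `logger` keeps the built-in TllmLogger defined above.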
bool initTrtLlmPlugins(void* logger, char const* libNamespace)
{
    if (pluginsInitialized)
        return true;

    if (logger)
    {
        gLogger = static_cast<nvinfer1::ILogger*>(logger);
    }
    setLoggerFinder(&gGlobalLoggerFinder);

    auto registry = getPluginRegistry();
    std::int32_t nbCreators;
    auto creators = getPluginCreators(nbCreators);

    for (std::int32_t i = 0; i < nbCreators; ++i)
    {
        auto const creator = creators[i];
        creator->setPluginNamespace(libNamespace);
        registry->registerCreator(*creator, libNamespace);
        if (gLogger)
        {
            auto const msg = tc::fmtstr("Registered plugin creator %s version %s in namespace %s",
                creator->getPluginName(), creator->getPluginVersion(), libNamespace);
            gLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, msg.c_str());
        }
    }

    pluginsInitialized = true;
    return true;
}
[[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder)
{
    tensorrt_llm::plugins::api::LoggerManager::getInstance().setLoggerFinder(finder);
}
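
// The creator objects below are function-local statics: they are constructed on first
// call and the returned pointers remain valid for the lifetime of the process.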
[[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(std::int32_t& nbCreators)
{
    static tensorrt_llm::plugins::IdentityPluginCreator identityPluginCreator;
    static tensorrt_llm::plugins::BertAttentionPluginCreator bertAttentionPluginCreator;
    static tensorrt_llm::plugins::GPTAttentionPluginCreator gptAttentionPluginCreator;
    static tensorrt_llm::plugins::GemmPluginCreator gemmPluginCreator;
    static tensorrt_llm::plugins::MixtureOfExpertsPluginCreator moePluginCreator;
#if ENABLE_MULTI_DEVICE
    static tensorrt_llm::plugins::SendPluginCreator sendPluginCreator;
    static tensorrt_llm::plugins::RecvPluginCreator recvPluginCreator;
    static tensorrt_llm::plugins::AllreducePluginCreator allreducePluginCreator;
    static tensorrt_llm::plugins::AllgatherPluginCreator allgatherPluginCreator;
    static tensorrt_llm::plugins::ReduceScatterPluginCreator reduceScatterPluginCreator;
#endif // ENABLE_MULTI_DEVICE
    static tensorrt_llm::plugins::SmoothQuantGemmPluginCreator smoothQuantGemmPluginCreator;
    static tensorrt_llm::plugins::LayernormQuantizationPluginCreator layernormQuantizationPluginCreator;
    static tensorrt_llm::plugins::QuantizePerTokenPluginCreator quantizePerTokenPluginCreator;
    static tensorrt_llm::plugins::QuantizeTensorPluginCreator quantizeTensorPluginCreator;
    static tensorrt_llm::plugins::RmsnormQuantizationPluginCreator rmsnormQuantizationPluginCreator;
    static tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPluginCreator
        weightOnlyGroupwiseQuantMatmulPluginCreator;
    static tensorrt_llm::plugins::WeightOnlyQuantMatmulPluginCreator weightOnlyQuantMatmulPluginCreator;
    static tensorrt_llm::plugins::LookupPluginCreator lookupPluginCreator;
    static tensorrt_llm::plugins::LoraPluginCreator loraPluginCreator;
    static tensorrt_llm::plugins::SelectiveScanPluginCreator selectiveScanPluginCreator;
    static tensorrt_llm::plugins::MambaConv1dPluginCreator mambaConv1DPluginCreator;
    static tensorrt_llm::plugins::lruPluginCreator lruPluginCreator;
    static tensorrt_llm::plugins::CumsumLastDimPluginCreator cumsumLastDimPluginCreator;

    static std::array pluginCreators = {
        creatorPtr(identityPluginCreator),
        creatorPtr(bertAttentionPluginCreator),
        creatorPtr(gptAttentionPluginCreator),
        creatorPtr(gemmPluginCreator),
        creatorPtr(moePluginCreator),
#if ENABLE_MULTI_DEVICE
        creatorPtr(sendPluginCreator),
        creatorPtr(recvPluginCreator),
        creatorPtr(allreducePluginCreator),
        creatorPtr(allgatherPluginCreator),
        creatorPtr(reduceScatterPluginCreator),
#endif // ENABLE_MULTI_DEVICE
        creatorPtr(smoothQuantGemmPluginCreator),
        creatorPtr(layernormQuantizationPluginCreator),
        creatorPtr(quantizePerTokenPluginCreator),
        creatorPtr(quantizeTensorPluginCreator),
        creatorPtr(rmsnormQuantizationPluginCreator),
        creatorPtr(weightOnlyGroupwiseQuantMatmulPluginCreator),
        creatorPtr(weightOnlyQuantMatmulPluginCreator),
        creatorPtr(lookupPluginCreator),
        creatorPtr(loraPluginCreator),
        creatorPtr(selectiveScanPluginCreator),
        creatorPtr(mambaConv1DPluginCreator),
        creatorPtr(lruPluginCreator),
        creatorPtr(cumsumLastDimPluginCreator),
    };
    nbCreators = pluginCreators.size();
    return pluginCreators.data();
}
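
// Entry point for TensorRT 10 and newer: re-exposes the same creator table through the
// IPluginCreatorInterface-based interface.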
#if NV_TENSORRT_MAJOR >= 10
[[maybe_unused]] tensorrt_llm::plugins::api::IPluginCreatorInterface* const* getCreators(std::int32_t& nbCreators)
{
    return reinterpret_cast<tensorrt_llm::plugins::api::IPluginCreatorInterface* const*>(
        getPluginCreators(nbCreators));
}
#endif // NV_TENSORRT_MAJOR >= 10
} // extern "C"