/* * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "gemmAllReducePlugin.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" #include "tensorrt_llm/plugins/common/pluginUtils.h" #include static char const* GEMM_ALLREDUCE_PLUGIN_VERSION = "1"; static char const* GEMM_ALLREDUCE_PLUGIN_NAME = "GemmAllReduce"; template using CutlassType = ::tensorrt_llm::kernels::cutlass_kernels::CutlassType; namespace tensorrt_llm::plugins { template static std::pair makeEntry() { return {std::make_tuple(ElementA, ElementB, ElementD), [&]() { using GemmTraits = cutlass_kernels::GemmTypes::type, typename CutlassType::type, typename CutlassType::type, // C, unused typename CutlassType::type, std::conditional_t, // SFA std::conditional_t, // SFB cutlass::layout::RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor, // C, unused cutlass::layout::RowMajor>; return new cutlass_kernels::GemmAllReduceImplRunner(); }}; } template static std::map getTypedInstantiators() { return std::map({makeEntry(), makeEntry(), makeEntry(), makeEntry(), makeEntry(), makeEntry()}); } //////////////////////////////////////////////////////////// // GemmAllReducePlugin Methods //////////////////////////////////////////////////////////// GemmAllReducePlugin::GemmAllReducePlugin(GemmAllReducePluginOptions const& options) : mOptions(options) , mGemmId(GemmIdCore(options.maxProblemShape.n, options.maxProblemShape.k, options.typeD)) , mProfiler(mGemmPluginProfileManager.createGemmPluginProfiler(/*inference=*/options.deserialize)) { // construct mapping of input/output pos to argument int argIdx = 0; // inputs mArgMap[argIdx++] = TensorArg::IN_ACTIVATION; mArgMap[argIdx++] = TensorArg::IN_WEIGHT; if (mOptions.hasSFA) { mArgMap[argIdx++] = TensorArg::IN_ACTIVATION_SF; } if (mOptions.hasSFB) { mArgMap[argIdx++] = TensorArg::IN_WEIGHT_SF; } if (mOptions.alphaIsPtr) { mArgMap[argIdx++] = TensorArg::IN_ALPHA; } mNbInputs = argIdx; // outputs mArgMap[argIdx++] = TensorArg::OUT_D_UC; mArgMap[argIdx++] = TensorArg::OUT_D_MC; mArgMap[argIdx++] = TensorArg::OUT_D_IPC; mNbOutputs = argIdx - mNbInputs; // Create mapping of argument to tensor pos for (auto const& pair : mArgMap) { mArgInvMap[pair.second] = pair.first; } // Use map instead of huge switch case mTypedInstantiators = getTypedInstantiators(); auto key = std::make_tuple(mOptions.typeA, mOptions.typeB, mOptions.typeD); TLLM_CHECK_WITH_INFO(mTypedInstantiators.count(key) > 0, "No cutlass gemm for impl."); mGemm = std::shared_ptr(mTypedInstantiators[key]()); } void GemmAllReducePlugin::allocatePersistentWorkspace() { TLLM_CHECK(mOptions.maxProblemShape.isInitialized()); mWorkspaceKey = "gemm_allreduce_workspace_m" + std::to_string(mOptions.maxProblemShape.maxM); cutlass_kernels::GemmAllReduceImplInterface::LaunchConfig smallest_tile_config = mGemm->getSupportedLaunchConfigs()[0]; cutlass_kernels::GemmAllReduceImplInterface::ProblemArgs args; args.argProblemShape(mOptions.maxProblemShape.maxM, mOptions.maxProblemShape.n, mOptions.maxProblemShape.k, 1) .argRanks(mRank, mOptions.group) .argLaunchConfig(smallest_tile_config); TLLM_CHECK(mWorkspace == nullptr); // Wrap persistent workspace in IPluginResource type // so that clone() can be called to allocate memory GemmAllReducePersistentWorkspace unallocated_resource(mGemm->getPersistentWorkspace(args)); // Register and allocate workspace mWorkspace = static_cast( getPluginRegistry()->acquirePluginResource(mWorkspaceKey.c_str(), &unallocated_resource)); TLLM_CHECK(mWorkspace != nullptr); } LaunchConfig GemmAllReducePlugin::getStaticHeuristicLaunchConfig(int M) const { using namespace tensorrt_llm::cutlass_extensions; // This is only applicable when we swap and transpose A & B. // When M is small we want to select tile that best fits it to maximize MMA efficiency. auto filterByM = [&](std::vector candidateConfigs) { std::vector result; if (M <= 16) { std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result), [](const LaunchConfig& config) { return config.tile_shape == TileShape::TileShape_128x16x128 and config.transposed; }); } else if (M <= 32) { std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result), [](const LaunchConfig& config) { return config.tile_shape == TileShape::TileShape_128x32x128 and config.transposed; }); } else if (M <= 64) { std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result), [](const LaunchConfig& config) { return config.tile_shape == TileShape::TileShape_128x64x128 and config.transposed; }); } else { std::copy_if(candidateConfigs.begin(), candidateConfigs.end(), std::back_inserter(result), [](const LaunchConfig& config) { return config.tile_shape == TileShape::TileShape_128x128x128 and config.transposed; }); } // If result empty then use any. if (result.empty()) { result = candidateConfigs; } return result; }; auto bestLaunchConfigs = mGemm->getSupportedLaunchConfigs(); bestLaunchConfigs = filterByM(bestLaunchConfigs); TLLM_CHECK(!bestLaunchConfigs.empty()); // Return first one, because who knows which is best. return bestLaunchConfigs.front(); } static GemmAllReducePluginOptions deserializeOptions(void const*& data, size_t length) { char const* begin = reinterpret_cast(data); char const*& end = reinterpret_cast(data); GemmAllReducePluginOptions options; options.deserialize = true; read(end, options.typeA); read(end, options.typeB); read(end, options.typeD); read(end, options.transA); read(end, options.transB); read(end, options.alpha); read(end, options.maxProblemShape); read(end, options.groupSize); for (int i = 0; i < options.groupSize; ++i) { int rank = -1; read(end, rank); options.group.insert(rank); } read(end, options.hasSFA); read(end, options.hasSFB); read(end, options.alphaIsPtr); TLLM_CHECK_WITH_INFO(end == begin + length, "Expected length (%d) != real length (%d). This is often " "caused by using different TensorRT LLM version to build " "engine and run engine.", (int) length, (int) (end - begin)); return options; } GemmAllReducePlugin::GemmAllReducePlugin(void const* data, size_t length) : GemmAllReducePlugin(deserializeOptions(std::ref(data), length)) { if (mProfiler->useProfiler()) { mProfiler->deserializeFromOwnFile(mGemmId, mOptions.maxProblemShape); } } ////////////////////////////////// // IPluginV2DynamicExt Methods ////////////////////////////////// IPluginV2DynamicExt* GemmAllReducePlugin::clone() const noexcept { return new GemmAllReducePlugin(*this); } DimsExprs GemmAllReducePlugin::getOutputDimensions( int outputIndex, DimsExprs const* inputs, int nbInputs, IExprBuilder& exprBuilder) noexcept { try { TLLM_CHECK(nbInputs == mNbInputs); // number of input tensors TLLM_CHECK(inputs[0].nbDims == inputs[1].nbDims); TLLM_CHECK(outputIndex < getNbOutputs()); // List of pointers to D on each rank if ((nbInputs + outputIndex) == TensorArg::OUT_D_IPC) { DimsExprs out_dims; out_dims.nbDims = 1; out_dims.d[0] = exprBuilder.constant(mOptions.groupSize); return out_dims; } TLLM_CHECK(mOptions.transA == false); TLLM_CHECK(mOptions.transB == true); int const nbDimsA = inputs[0].nbDims; // number of dims int const nbDimsB = inputs[1].nbDims; DimsExprs out_dims; // subtract 2 -> K from each input out_dims.nbDims = nbDimsA + nbDimsB - 2; if (mOptions.transA) { for (int i = 1; i < nbDimsA; ++i) { out_dims.d[i - 1] = inputs[0].d[i]; } } else { for (int i = 0; i < nbDimsA - 1; ++i) { out_dims.d[i] = inputs[0].d[i]; } } if (mOptions.transB) { for (int i = 0; i < nbDimsB - 1; ++i) { out_dims.d[nbDimsA - 1 + i] = inputs[1].d[i]; } } else { for (int i = 1; i < nbDimsB; ++i) { out_dims.d[nbDimsA - 2 + i] = inputs[1].d[i]; } } return out_dims; } catch (std::exception const& e) { caughtError(e); } return DimsExprs{}; } bool GemmAllReducePlugin::supportsFormatCombination( int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept { // inOut[0] -> activation // inOut[1] -> weight // inOut[1+hasInputSF] -> activation_sf // inOut[1+hasInputSF*2] -> weight_sf // inOut[2+hasInputSF*2] -> output[0] = D_uc // inOut[3+hasInputSF*2] -> output[1] = D_mc TLLM_CHECK_WITH_INFO(pos < mNbInputs + mNbOutputs, "Unexpected pos: %d", pos); auto const& desc = inOut[pos]; TLLM_CHECK_WITH_INFO(mArgMap.count(pos) > 0, "pos %d not found in mArgMap.", pos); TensorArg arg = mArgMap[pos]; auto typeExists = [&](DataType dtype, auto idx) -> bool { for (const auto& [key, value] : mTypedInstantiators) { // key format: if (std::get(key) == dtype) { return true; } } return false; }; switch (arg) { case TensorArg::IN_ACTIVATION: return typeExists(desc.type, std::integral_constant{}); case TensorArg::IN_WEIGHT: return typeExists(desc.type, std::integral_constant{}); case TensorArg::IN_ACTIVATION_SF: case TensorArg::IN_WEIGHT_SF: // Assumed SF for only FP4 at the moment return desc.type == DataType::kFP8; case TensorArg::IN_ALPHA: return desc.type == DataType::kFLOAT; case TensorArg::OUT_D_UC: case TensorArg::OUT_D_MC: case TensorArg::OUT_D_IPC: return typeExists(desc.type, std::integral_constant{}); default: return false; } } void GemmAllReducePlugin::configurePlugin( DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept { // Get problem shape int const nbDimsA = in[0].max.nbDims; int const minM = utils::computeMDimension(mOptions.transA, in[0].min); int const maxM = utils::computeMDimension(mOptions.transA, in[0].max); int const N = utils::computeNDimension(mOptions.transB, in[1].max); int const K = mOptions.transA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1]; TLLM_CHECK_WITH_INFO(out[0].desc.type == mOptions.typeD, "Output type mismatch."); // Ensure call from execution phase does // not override call from build phase if (!mOptions.maxProblemShape.isInitialized()) { mOptions.maxProblemShape = {minM, maxM, N, K}; mGemmId = {N, K, mOptions.typeD}; } // Build phase doesn't have COMM_SESSION (i.e built on single rank) // so do not allocate persistent workspace if (!isBuilding()) { auto getTPRank = [&]() { int rank = COMM_SESSION.getRank(); auto it = std::find(mOptions.group.begin(), mOptions.group.end(), rank); TLLM_CHECK_WITH_INFO(it != mOptions.group.end(), "Incorrect group specified - rank " + std::to_string(rank) + " not found in group"); return std::distance(mOptions.group.begin(), it); }; mRank = getTPRank(); if (mWorkspace == nullptr) { allocatePersistentWorkspace(); } } } size_t GemmAllReducePlugin::getWorkspaceSize( PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept { return 0; } int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs[0] -> [M(*), K] // inputs[1] -> [K, N] // outputs[0] -> [M(*), N] unicast ptr // outputs[1] -> [M(*), N] multicast ptr auto const nbDimsA = inputDesc[0].dims.nbDims; auto const M = utils::computeMDimension(mOptions.transA, inputDesc[0].dims); auto const N = utils::computeNDimension(mOptions.transB, inputDesc[1].dims); auto const K = mOptions.transA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; TLLM_CHECK_WITH_INFO(M <= mOptions.maxProblemShape.maxM, "GemmAllReducePlugin M > maxM."); TLLM_CHECK_WITH_INFO(M > 0, "GemmAllReducePlugin M is 0."); TLLM_CHECK_WITH_INFO(N > 0, "GemmAllReducePlugin N is 0."); TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0."); TLLM_CHECK_WITH_INFO(mWorkspace != nullptr, "GemmAllReducePlugin workspace is null."); LaunchConfig bestLaunchConfig; if (mProfiler->useProfiler()) { bestLaunchConfig = mProfiler->getBestConfig(M, mGemmId).value(); } else { bestLaunchConfig = getStaticHeuristicLaunchConfig(M); } void const* activation = inputs[mArgInvMap[TensorArg::IN_ACTIVATION]]; void const* weight = inputs[mArgInvMap[TensorArg::IN_WEIGHT]]; void* D_out_uc = outputs[mArgInvMap[TensorArg::OUT_D_UC] - mNbInputs]; void* D_out_mc = outputs[mArgInvMap[TensorArg::OUT_D_MC] - mNbInputs]; void* D_out_ipc = outputs[mArgInvMap[TensorArg::OUT_D_IPC] - mNbInputs]; TLLM_CHECK_WITH_INFO(activation != nullptr, "GemmAllReducePlugin activation is NULL"); TLLM_CHECK_WITH_INFO(weight != nullptr, "GemmAllReducePlugin weight is NULL"); TLLM_CHECK_WITH_INFO(D_out_uc != nullptr, "GemmAllReducePlugin out_uc is NULL"); TLLM_CHECK_WITH_INFO(D_out_mc != nullptr, "GemmAllReducePlugin out_mc is NULL"); TLLM_CHECK_WITH_INFO(D_out_ipc != nullptr, "GemmAllReducePlugin out_ipc is NULL"); cutlass_kernels::GemmAllReduceImplInterface::ProblemArgs args; args.argProblemShape(M, N, K, 1) .argA(activation) .argB(weight) .argC(nullptr) .argD(D_out_uc, D_out_mc, (void**) D_out_ipc) .argRanks(mRank, mOptions.group) .argBeta(0.f) // no bias .argLaunchConfig(bestLaunchConfig) .argWorkspace(mWorkspace->mWorkspace.get()); // tensor for scaling input A if (mOptions.hasSFA) { void const* activation_sf = inputs[mArgInvMap[TensorArg::IN_ACTIVATION_SF]]; TLLM_CHECK_WITH_INFO(activation_sf != nullptr, "GemmAllReducePlugin activation_sf is NULL"); args.argAScale(activation_sf); } // tensor for scaling input B if (mOptions.hasSFB) { void const* weight_sf = inputs[mArgInvMap[TensorArg::IN_WEIGHT_SF]]; TLLM_CHECK_WITH_INFO(weight_sf != nullptr, "GemmAllReducePlugin weight_sf is NULL"); args.argBScale(weight_sf); } // tensor for scaling output D if (mOptions.alphaIsPtr) { void const* alpha_vec = inputs[mArgInvMap[TensorArg::IN_ALPHA]]; TLLM_CHECK_WITH_INFO(alpha_vec != nullptr, "GemmAllReducePlugin alpha_vec is NULL"); args.argAlphaPtr(reinterpret_cast(alpha_vec)); } else { args.argAlpha(mOptions.alpha); } mGemm->run(args, stream); return 0; } ////////////////////////////////// // IPluginV2Ext Methods ////////////////////////////////// DataType GemmAllReducePlugin::getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK_WITH_INFO(index < getNbOutputs(), "Output index out of bounds: %d", index); return mOptions.typeD; } ////////////////////////////////// // IPluginV2 Methods ////////////////////////////////// char const* GemmAllReducePlugin::getPluginType() const noexcept { return GEMM_ALLREDUCE_PLUGIN_NAME; } char const* GemmAllReducePlugin::getPluginVersion() const noexcept { return GEMM_ALLREDUCE_PLUGIN_VERSION; } int GemmAllReducePlugin::getNbOutputs() const noexcept { return mNbOutputs; } int GemmAllReducePlugin::initialize() noexcept { if (isBuilding() && mProfiler->useProfiler()) { // TODO (xsimmons): interfaces between GemmPluginProfiler and Plugin // needs to be relooked at - current interface implicitly assigns runner to profiler // object in profileTactics() assert(mOptions.maxProblemShape.isInitialized()); mProfiler->profileTactics(mGemm, mOptions.typeD, mOptions.maxProblemShape, mGemmId); } return 0; } void GemmAllReducePlugin::terminate() noexcept { if (isBuilding()) // need this otherwise getComm will crash during build phase { return; } // free mWorkspace if (mWorkspace) { getPluginRegistry()->releasePluginResource(mWorkspaceKey.c_str()); mWorkspace = nullptr; } } size_t GemmAllReducePlugin::getSerializationSize() const noexcept { // cannot use sizeof(GemmAllReducePluginOptions) // becaused need packed attribute which doesn't work on enum // without making the enum also packed size_t size = 0; size += sizeof(mOptions.typeA); size += sizeof(mOptions.typeB); size += sizeof(mOptions.typeD); size += sizeof(mOptions.transA); size += sizeof(mOptions.transB); size += sizeof(mOptions.alpha); size += sizeof(mOptions.maxProblemShape); size += sizeof(mOptions.groupSize); size += mOptions.group.size() * sizeof(int); size += sizeof(mOptions.hasSFA); size += sizeof(mOptions.hasSFB); size += sizeof(mOptions.alphaIsPtr); return size; } void GemmAllReducePlugin::serialize(void* buffer) const noexcept { char* begin = reinterpret_cast(buffer); char* end = reinterpret_cast(buffer); write(end, mOptions.typeA); write(end, mOptions.typeB); write(end, mOptions.typeD); write(end, mOptions.transA); write(end, mOptions.transB); write(end, mOptions.alpha); write(end, mOptions.maxProblemShape); write(end, mOptions.groupSize); for (auto const& rank : mOptions.group) { write(end, rank); } write(end, mOptions.hasSFA); write(end, mOptions.hasSFB); write(end, mOptions.alphaIsPtr); TLLM_CHECK(end == begin + getSerializationSize()); // Profiler MNK->kernel mappings need to be deterministic and consistent across ranks // to ensure correct functionality (unlike standalone GEMMs). // Since by default each rank will generate and serialize its own profiler mapping // this can lead to different mappings between ranks which will result in fatal // error. Therefore only generate and use profiler mapping for single rank. if (mProfiler->useProfiler() && COMM_SESSION.getRank() == 0) { mProfiler->serializeToOwnFile(mGemmId); } } void GemmAllReducePlugin::destroy() noexcept { delete this; } //////////////////////////////////////////////////////////// // GemmAllReducePluginCreator Methods //////////////////////////////////////////////////////////// PluginFieldCollection GemmAllReducePluginCreator::mFC; std::vector GemmAllReducePluginCreator::mPluginAttributes; GemmAllReducePluginCreator::GemmAllReducePluginCreator() { // Fill PluginFieldCollection with PluginField arguments metadata mPluginAttributes.clear(); mPluginAttributes.emplace_back("type_a", nullptr, PluginFieldType::kINT32, 1); mPluginAttributes.emplace_back("type_b", nullptr, PluginFieldType::kINT32, 1); mPluginAttributes.emplace_back("type_d", nullptr, PluginFieldType::kINT32, 1); mPluginAttributes.emplace_back("transa", nullptr, PluginFieldType::kINT32, 1); mPluginAttributes.emplace_back("transb", nullptr, PluginFieldType::kINT32, 1); mPluginAttributes.emplace_back("alpha", nullptr, PluginFieldType::kFLOAT32, 1); mPluginAttributes.emplace_back("group", nullptr, PluginFieldType::kINT32, 1); mPluginAttributes.emplace_back("has_sfa", nullptr, PluginFieldType::kINT8, 1); mPluginAttributes.emplace_back("has_sfb", nullptr, PluginFieldType::kINT8, 1); mPluginAttributes.emplace_back("alpha_is_ptr", nullptr, PluginFieldType::kINT8, 1); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); } char const* GemmAllReducePluginCreator::getPluginName() const noexcept { return GEMM_ALLREDUCE_PLUGIN_NAME; } char const* GemmAllReducePluginCreator::getPluginVersion() const noexcept { return GEMM_ALLREDUCE_PLUGIN_VERSION; } PluginFieldCollection const* GemmAllReducePluginCreator::getFieldNames() noexcept { return &mFC; } IPluginV2* GemmAllReducePluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { PluginField const* fields = fc->fields; GemmAllReducePluginOptions options; options.deserialize = false; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { char const* attrName = fields[i].name; if (!strcmp(attrName, "type_a")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); options.typeA = *static_cast(fields[i].data); } else if (!strcmp(attrName, "type_b")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); options.typeB = *static_cast(fields[i].data); } else if (!strcmp(attrName, "type_d")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); options.typeD = *static_cast(fields[i].data); } else if (!strcmp(attrName, "transa")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); options.transA = *static_cast(fields[i].data); } else if (!strcmp(attrName, "transb")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); options.transB = *static_cast(fields[i].data); } else if (!strcmp(attrName, "alpha")) { TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32); options.alpha = *static_cast(fields[i].data); } else if (!strcmp(attrName, "group")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); auto const* ranks = static_cast(fields[i].data); for (int j = 0; j < fields[i].length; ++j) { options.group.insert(ranks[j]); } options.groupSize = options.group.size(); } else if (!strcmp(attrName, "has_sfa")) // passed in as input tensor { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); options.hasSFA = *static_cast(fields[i].data); } else if (!strcmp(attrName, "has_sfb")) // passed in as input tensor { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); options.hasSFB = *static_cast(fields[i].data); } else if (!strcmp(attrName, "alpha_is_ptr")) // passed in as input tensor { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); options.alphaIsPtr = *static_cast(fields[i].data); } } try { // GemmAllReducePluginCreator is unique and shared for an engine generation auto* obj = new GemmAllReducePlugin(options); obj->setPluginNamespace(mNamespace.c_str()); return obj; } catch (std::exception const& e) { caughtError(e); return nullptr; } } IPluginV2* GemmAllReducePluginCreator::deserializePlugin( char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call GemmAllReducePlugin::destroy() try { auto* obj = new GemmAllReducePlugin(serialData, serialLength); obj->setPluginNamespace(mNamespace.c_str()); return obj; } catch (std::exception const& e) { caughtError(e); } return nullptr; } } // namespace tensorrt_llm::plugins