/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): the text this revision was prepared from had every angle-bracket argument list
// removed (template parameter lists, explicit template arguments, cast target types and the header
// name of the last #include below). Those spots are kept exactly as found and must be restored
// from version control; they are intentionally not guessed here.

#include "gemmAllReducePlugin.h"
#include "tensorrt_llm/kernels/cutlass_kernels/allreduce_gemm/allreduce_gemm_runner.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/plugins/common/pluginUtils.h"

// NOTE(review): header name missing - restore from version control.
#include

static char const* GEMM_ALLREDUCE_PLUGIN_VERSION = "1";
static char const* GEMM_ALLREDUCE_PLUGIN_NAME = "GemmAllReduce";

namespace tensorrt_llm::plugins
{
// Builds one entry of the instantiator map: the key is the (ElementA, ElementB, ElementD) tuple,
// the value is a factory that heap-allocates a GemmAllReduceImplRunner for those element types.
// A is row-major, B is column-major, C (unused) and D are row-major.
// The factory captures nothing, so it stays valid after this function returns.
// NOTE(review): template parameter list and template arguments are missing - restore.
template static std::pair makeEntry()
{
    return {std::make_tuple(ElementA, ElementB, ElementD),
        []()
        {
            using GemmTraits = tensorrt_llm::kernels::cutlass_kernels::GemmTypes::type,
                typename CutlassType::type,
                typename CutlassType::type, // C, unused
                typename CutlassType::type, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor,
                cutlass::layout::RowMajor, // C, unused
                cutlass::layout::RowMajor>;

            return new GemmAllReduceImplRunner();
        }};
}

// Returns the map of all supported type combinations to their runner factories.
// NOTE(review): template parameter list and the per-entry template arguments are missing - restore.
template static std::map getTypedInstantiators()
{
    return std::map({makeEntry(), makeEntry(), makeEntry(), makeEntry()});
}

////////////////////////////////////////////////////////////
// GemmAllReducePlugin Methods
////////////////////////////////////////////////////////////

// Constructs the plugin from an options struct: stores the options, derives the GEMM id from
// (n, k, typeD), creates the profiler and instantiates the runner matching (typeA, typeB, typeD).
// Fails the check below when no runner is registered for that type combination.
GemmAllReducePlugin::GemmAllReducePlugin(GemmAllReducePluginOptions const& options)
    : mOptions(options)
    , mGemmId(GemmIdCore(options.maxProblemShape.n, options.maxProblemShape.k, options.typeD))
    , mProfiler(mGemmPluginProfileManager.createGemmPluginProfiler(/*inference=*/options.deserialize))
{
    // Use map instead of huge switch case
    mTypedInstantiators = getTypedInstantiators();

    auto key = std::make_tuple(mOptions.typeA, mOptions.typeB, mOptions.typeD);

    // Single lookup: the iterator is used both for the existence check and to invoke the factory.
    auto const instantiator = mTypedInstantiators.find(key);
    TLLM_CHECK_WITH_INFO(instantiator != mTypedInstantiators.end(), "No cutlass gemm for impl.");

    mGemm = std::shared_ptr(instantiator->second());
}

// Acquires the persistent workspace through the plugin registry under mWorkspaceKey.
// Requires an initialized max problem shape and that no workspace is held yet.
// The workspace is sized for (maxM, n, k) using the first supported launch config.
void GemmAllReducePlugin::allocatePersistentWorkspace()
{
    TLLM_CHECK(mOptions.maxProblemShape.isInitialized());

    GemmAllReduceImplInterface::LaunchConfig smallest_tile_config = mGemm->getSupportedLaunchConfigs()[0];
    GemmAllReduceImplInterface::ProblemArgs args;
    args.argProblemShape(mOptions.maxProblemShape.maxM, mOptions.maxProblemShape.n, mOptions.maxProblemShape.k, 1)
        .argRanks(mRank, mOptions.group)
        .argLaunchConfig(smallest_tile_config);

    TLLM_CHECK(mWorkspace == nullptr);

    // Wrap persistent workspace in IPluginResource type
    // so that clone() can be called to allocate memory
    GemmAllReducePersistentWorkspace unallocated_resource(mGemm->getPersistentWorkspace(args));

    // Register and allocate workspace
    mWorkspace = static_cast(
        getPluginRegistry()->acquirePluginResource(mWorkspaceKey, &unallocated_resource));
    TLLM_CHECK(mWorkspace != nullptr);
}

// Reads a GemmAllReducePluginOptions from a serialized buffer, in the same field order that
// serialize() writes. `data` is advanced past the consumed bytes (`end` aliases it).
// NOTE(review): `length` and `begin` are never used, so the number of bytes consumed is not
// validated against the buffer size - confirm whether that is intentional.
static GemmAllReducePluginOptions deserializeOptions(void const*& data, size_t length)
{
    char const* begin = reinterpret_cast(data);
    char const*& end = reinterpret_cast(data);

    GemmAllReducePluginOptions options;
    options.deserialize = true;

    read(end, options.typeA);
    read(end, options.typeB);
    read(end, options.typeD);
    read(end, options.transA);
    read(end, options.transB);
    read(end, options.alpha);
    read(end, options.maxProblemShape);
    read(end, options.groupSize);

    // The group is stored as groupSize consecutive ints.
    for (int i = 0; i < options.groupSize; ++i)
    {
        int rank = -1;
        read(end, rank);
        options.group.insert(rank);
    }

    return options;
}

// Deserializing constructor: rebuilds the options from the buffer, delegates to the options
// constructor, then loads the profiler mapping from the profiler's own file.
GemmAllReducePlugin::GemmAllReducePlugin(void const* data, size_t length)
    : GemmAllReducePlugin(deserializeOptions(std::ref(data), length))
{
    // char const* end = reinterpret_cast(data);
    mProfiler->deserializeFromOwnFile(mGemmId, mOptions.maxProblemShape);
}

//////////////////////////////////
// IPluginV2DynamicExt Methods
//////////////////////////////////

// Returns a heap-allocated copy of this plugin (copy constructor).
IPluginV2DynamicExt* GemmAllReducePlugin::clone() const noexcept
{
    return new GemmAllReducePlugin(*this);
}

// Computes the output shape expression: all dims of A except its K dim, followed by all dims of B
// except its K dim. Requires exactly two inputs of equal rank, transA == false and transB == true.
// Returns an empty DimsExprs after reporting the error if any check throws.
DimsExprs GemmAllReducePlugin::getOutputDimensions(
    int outputIndex, DimsExprs const* inputs, int nbInputs, IExprBuilder& exprBuilder) noexcept
{
    try
    {
        TLLM_CHECK(nbInputs == 2); // number of input tensors
        TLLM_CHECK(inputs[0].nbDims == inputs[1].nbDims);
        TLLM_CHECK(outputIndex < getNbOutputs());
        TLLM_CHECK(mOptions.transA == false);
        TLLM_CHECK(mOptions.transB == true);

        int const nbDimsA = inputs[0].nbDims; // number of dims
        int const nbDimsB = inputs[1].nbDims;

        DimsExprs out_dims;
        // subtract 2 -> K from each input
        out_dims.nbDims = nbDimsA + nbDimsB - 2;

        if (mOptions.transA)
        {
            // K is the leading dim of A: keep dims 1..end
            for (int i = 1; i < nbDimsA; ++i)
            {
                out_dims.d[i - 1] = inputs[0].d[i];
            }
        }
        else
        {
            // K is the trailing dim of A: keep dims 0..end-1
            for (int i = 0; i < nbDimsA - 1; ++i)
            {
                out_dims.d[i] = inputs[0].d[i];
            }
        }

        if (mOptions.transB)
        {
            // K is the trailing dim of B: keep dims 0..end-1
            for (int i = 0; i < nbDimsB - 1; ++i)
            {
                out_dims.d[nbDimsA - 1 + i] = inputs[1].d[i];
            }
        }
        else
        {
            // K is the leading dim of B: keep dims 1..end
            for (int i = 1; i < nbDimsB; ++i)
            {
                out_dims.d[nbDimsA - 2 + i] = inputs[1].d[i];
            }
        }

        return out_dims;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return DimsExprs{};
}

// Reports whether the tensor at `pos` has a data type that appears, in the matching slot, in at
// least one registered (typeA, typeB, typeD) combination. Only the type at `pos` is inspected.
// NOTE(review): this function is noexcept but the check below is outside any try/catch -
// confirm that a failing check is meant to terminate here.
bool GemmAllReducePlugin::supportsFormatCombination(
    int32_t pos, PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept
{
    // inOut[0] -> activation
    // inOut[1] -> weight
    // inOut[2] -> output[0]
    // inOut[3] -> output[1]
    TLLM_CHECK_WITH_INFO(pos < 2 + getNbOutputs(), "Unexpected pos: %d", pos);
    auto const& desc = inOut[pos];

    // True when some registered key holds `dtype` in the tuple slot selected by the call site.
    // NOTE(review): the lambda's template parameter list and the std::get index are missing - restore.
    auto typeExists = [&](DataType dtype) -> bool
    {
        for (const auto& [key, value] : mTypedInstantiators)
        {
            // key format: (typeA, typeB, typeD)
            if (std::get(key) == dtype)
            {
                return true;
            }
        }
        return false;
    };

    switch (pos)
    {
    case 0: // activation
        return typeExists.template operator()<0>(desc.type);
    case 1: // weight
        return typeExists.template operator()<1>(desc.type);
    case 2: // output[0]
    case 3: // output[1]
        return typeExists.template operator()<2>(desc.type);
    default: return false;
    }
}

// Records the max problem shape (first call only) and, outside the build phase, resolves this
// process's index within the rank group and allocates the persistent workspace if not yet held.
void GemmAllReducePlugin::configurePlugin(
    DynamicPluginTensorDesc const* in, int32_t nbInputs, DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept
{
    // Get problem shape
    int const nbDimsA = in[0].max.nbDims;

    int const minM = utils::computeMDimension(mOptions.transA, in[0].min);
    int const maxM = utils::computeMDimension(mOptions.transA, in[0].max);
    int const N = utils::computeNDimension(mOptions.transB, in[1].max);
    int const K = mOptions.transA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1];

    TLLM_CHECK_WITH_INFO(out[0].desc.type == mOptions.typeD, "Output type mismatch.");

    // Ensure call from execution phase does
    // not override call from build phase
    if (!mOptions.maxProblemShape.isInitialized())
    {
        mOptions.maxProblemShape = {minM, maxM, N, K};
        mGemmId = {N, K, mOptions.typeD};
    }

    // Build phase doesn't have COMM_SESSION (i.e built on single rank)
    // so do not allocate persistent workspace
    if (!isBuilding())
    {
        // Position of the session rank inside the group, in the group's iteration order.
        auto getTPRank = [&]()
        {
            int rank = COMM_SESSION.getRank();
            // Use the container's own lookup rather than a linear std::find over it.
            auto it = mOptions.group.find(rank);
            TLLM_CHECK_WITH_INFO(it != mOptions.group.end(),
                "Incorrect group specified - rank " + std::to_string(rank) + " not found in group");
            return std::distance(mOptions.group.begin(), it);
        };

        mRank = getTPRank();

        if (mWorkspace == nullptr)
        {
            allocatePersistentWorkspace();
        }
    }
}

// No per-enqueue scratch workspace is requested; the plugin uses its persistent workspace instead.
size_t GemmAllReducePlugin::getWorkspaceSize(
    PluginTensorDesc const* inputs, int32_t nbInputs, PluginTensorDesc const* outputs, int32_t nbOutputs) const noexcept
{
    return 0;
}

// Launches the GEMM + all-reduce on `stream` with the launch config the profiler selected for
// this M. Validates the problem shape, the input/output pointers and the persistent workspace
// first. Returns 0 after the runner has been invoked.
int GemmAllReducePlugin::enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc,
    void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    // inputs[0] -> [M(*), K]
    // inputs[1] -> [K, N]
    // outputs[0] -> [M(*), N] unicast ptr
    // outputs[1] -> [M(*), N] multicast ptr
    auto const nbDimsA = inputDesc[0].dims.nbDims;
    auto const M = utils::computeMDimension(mOptions.transA, inputDesc[0].dims);
    auto const N = utils::computeNDimension(mOptions.transB, inputDesc[1].dims);
    auto const K = mOptions.transA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];

    TLLM_CHECK_WITH_INFO(M > 0, "GemmAllReducePlugin M is 0.");
    TLLM_CHECK_WITH_INFO(N > 0, "GemmAllReducePlugin N is 0.");
    TLLM_CHECK_WITH_INFO(K > 0, "GemmAllReducePlugin K is 0.");
    TLLM_CHECK_WITH_INFO(inputs[0] != nullptr, "GemmAllReducePlugin inputs[0] is null.");
    TLLM_CHECK_WITH_INFO(inputs[1] != nullptr, "GemmAllReducePlugin inputs[1] is null.");
    TLLM_CHECK_WITH_INFO(outputs[0] != nullptr, "GemmAllReducePlugin outputs[0] is null.");
    TLLM_CHECK_WITH_INFO(outputs[1] != nullptr, "GemmAllReducePlugin outputs[1] is null.");
    TLLM_CHECK_WITH_INFO(mWorkspace != nullptr, "GemmAllReducePlugin workspace is null.");

    auto bestLaunchConfig = mProfiler->getBestConfig(M, mGemmId);
    // Fail with a descriptive message instead of calling value() on an empty result.
    TLLM_CHECK_WITH_INFO(
        bestLaunchConfig.has_value(), "GemmAllReducePlugin has no profiled launch config for this problem.");

    GemmAllReduceImplInterface::ProblemArgs args;
    args.argProblemShape(M, N, K, 1)
        .argA(inputs[0])
        .argB(inputs[1])
        .argC(nullptr)
        .argD(outputs[0], outputs[1])
        .argRanks(mRank, mOptions.group)
        .argAlpha(mOptions.alpha)
        .argBeta(0.f) // no bias
        .argLaunchConfig(bestLaunchConfig.value())
        .argWorkspace(mWorkspace->mWorkspace.get());

    mGemm->run(args, stream);

    return 0;
}

//////////////////////////////////
// IPluginV2Ext Methods
//////////////////////////////////

// Every output has type typeD; `index` must be below getNbOutputs().
DataType GemmAllReducePlugin::getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept
{
    TLLM_CHECK_WITH_INFO(index < getNbOutputs(), "Output index out of bounds: %d", index);
    return mOptions.typeD;
}

//////////////////////////////////
// IPluginV2 Methods
//////////////////////////////////

// Returns the plugin name string ("GemmAllReduce").
char const* GemmAllReducePlugin::getPluginType() const noexcept
{
    return GEMM_ALLREDUCE_PLUGIN_NAME;
}

// Returns the plugin version string ("1").
char const* GemmAllReducePlugin::getPluginVersion() const noexcept
{
    return GEMM_ALLREDUCE_PLUGIN_VERSION;
}

// The plugin always exposes two outputs.
int GemmAllReducePlugin::getNbOutputs() const noexcept
{
    // outputs[0] -> unicast address
    // outputs[1] -> multicast address
    return 2;
}

// During the build phase, profiles the runner's tactics for the recorded max problem shape.
// Does nothing outside the build phase. Always returns 0.
int GemmAllReducePlugin::initialize() noexcept
{
    if (isBuilding())
    {
        // TODO (xsimmons): interfaces between GemmPluginProfiler and Plugin
        // needs to be relooked at - current interface implicitly assigns runner to profiler
        // object in profileTactics()
        assert(mOptions.maxProblemShape.isInitialized());
        mProfiler->profileTactics(mGemm, mOptions.typeD, mOptions.maxProblemShape, mGemmId);
    }
    return 0;
}

// Releases the persistent workspace back to the plugin registry, if one is held.
// Skipped entirely during the build phase.
void GemmAllReducePlugin::terminate() noexcept
{
    if (isBuilding()) // need this otherwise getComm will crash during build phase
    {
        return;
    }

    // free mWorkspace
    if (mWorkspace)
    {
        getPluginRegistry()->releasePluginResource(mWorkspaceKey);
        mWorkspace = nullptr;
    }
}

// Returns the number of bytes serialize() writes: each serialized option field plus one int per
// group member.
size_t GemmAllReducePlugin::getSerializationSize() const noexcept
{
    // Cannot use sizeof(GemmAllReducePluginOptions) because that would need the
    // packed attribute, which doesn't work on an enum without also making the enum packed.
    size_t size = 0;
    size += sizeof(mOptions.typeA);
    size += sizeof(mOptions.typeB);
    size += sizeof(mOptions.typeD);
    size += sizeof(mOptions.transA);
    size += sizeof(mOptions.transB);
    size += sizeof(mOptions.alpha);
    size += sizeof(mOptions.maxProblemShape);
    size += sizeof(mOptions.groupSize);
    size += mOptions.group.size() * sizeof(int);
    return size;
}

// Writes the options into `buffer` in the field order deserializeOptions() reads, then, on
// session rank 0 only, serializes the profiler mapping to the profiler's own file.
void GemmAllReducePlugin::serialize(void* buffer) const noexcept
{
    char* begin = reinterpret_cast(buffer);
    char* end = reinterpret_cast(buffer);

    write(end, mOptions.typeA);
    write(end, mOptions.typeB);
    write(end, mOptions.typeD);
    write(end, mOptions.transA);
    write(end, mOptions.transB);
    write(end, mOptions.alpha);
    write(end, mOptions.maxProblemShape);
    write(end, mOptions.groupSize);
    for (auto const& rank : mOptions.group)
    {
        write(end, rank);
    }
    assert(end == begin + getSerializationSize());

    // Profiler MNK->kernel mappings need to be deterministic and consistent across ranks
    // to ensure correct functionality (unlike standalone GEMMs).
    // Since by default each rank will generate and serialize its own profiler mapping
    // this can lead to different mappings between ranks which will result in fatal
    // error. Therefore only generate and use profiler mapping for single rank.
    if (COMM_SESSION.getRank() == 0)
    {
        mProfiler->serializeToOwnFile(mGemmId);
    }
}

// Destroys this plugin object; it must have been allocated with `new`.
void GemmAllReducePlugin::destroy() noexcept
{
    delete this;
}

////////////////////////////////////////////////////////////
// GemmAllReducePluginCreator Methods
////////////////////////////////////////////////////////////
PluginFieldCollection GemmAllReducePluginCreator::mFC;
// NOTE(review): the vector's element type is missing - restore.
std::vector GemmAllReducePluginCreator::mPluginAttributes;

// Registers the metadata of every field createPlugin() understands and points mFC at it.
GemmAllReducePluginCreator::GemmAllReducePluginCreator()
{
    // Fill PluginFieldCollection with PluginField arguments metadata
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back("type_a", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("type_b", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("type_d", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("transa", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("transb", nullptr, PluginFieldType::kINT32, 1);
    mPluginAttributes.emplace_back("alpha", nullptr, PluginFieldType::kFLOAT32, 1);
    mPluginAttributes.emplace_back("group", nullptr, PluginFieldType::kINT32, 1);
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

// Returns the plugin name string ("GemmAllReduce").
char const* GemmAllReducePluginCreator::getPluginName() const noexcept
{
    return GEMM_ALLREDUCE_PLUGIN_NAME;
}

// Returns the plugin version string ("1").
char const* GemmAllReducePluginCreator::getPluginVersion() const noexcept
{
    return GEMM_ALLREDUCE_PLUGIN_VERSION;
}

// Returns the field collection filled in by the constructor.
PluginFieldCollection const* GemmAllReducePluginCreator::getFieldNames() noexcept
{
    return &mFC;
}

// Builds a plugin from a field collection. Recognised field names: type_a, type_b, type_d,
// transa, transb, alpha, group; other names are ignored. Returns nullptr after reporting the
// error if plugin construction throws.
// NOTE(review): the per-field type checks run outside the try block of this noexcept function -
// confirm that a failing check is meant to terminate here.
// NOTE(review): the static_cast target types below are missing - restore.
IPluginV2* GemmAllReducePluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    PluginField const* fields = fc->fields;

    GemmAllReducePluginOptions options;
    options.deserialize = false;

    // Read configurations from each fields
    for (int i = 0; i < fc->nbFields; ++i)
    {
        char const* attrName = fields[i].name;
        if (!strcmp(attrName, "type_a"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            options.typeA = *static_cast(fields[i].data);
        }
        else if (!strcmp(attrName, "type_b"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            options.typeB = *static_cast(fields[i].data);
        }
        else if (!strcmp(attrName, "type_d"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            options.typeD = *static_cast(fields[i].data);
        }
        else if (!strcmp(attrName, "transa"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            options.transA = *static_cast(fields[i].data);
        }
        else if (!strcmp(attrName, "transb"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            options.transB = *static_cast(fields[i].data);
        }
        else if (!strcmp(attrName, "alpha"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32);
            options.alpha = *static_cast(fields[i].data);
        }
        else if (!strcmp(attrName, "group"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            auto const* ranks = static_cast(fields[i].data);
            for (int j = 0; j < fields[i].length; ++j)
            {
                options.group.insert(ranks[j]);
            }
            // Size is taken from the container after insertion, not from the field length.
            options.groupSize = options.group.size();
        }
    }

    try
    {
        // GemmAllReducePluginCreator is unique and shared for an engine generation
        auto* obj = new GemmAllReducePlugin(options);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
        return nullptr;
    }
}

// Rebuilds a plugin from its serialized form. Returns nullptr after reporting the error if
// construction throws.
IPluginV2* GemmAllReducePluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept
{
    // This object will be deleted when the network is destroyed, which will
    // call GemmAllReducePlugin::destroy()
    try
    {
        auto* obj = new GemmAllReducePlugin(serialData, serialLength);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

} // namespace tensorrt_llm::plugins