// TensorRT-LLM: cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp

/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include <numeric>
using namespace nvinfer1;
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::plugins;
using tensorrt_llm::common::QuantMode;
using tensorrt_llm::common::nextWorkspacePtr;
using tensorrt_llm::common::calculateTotalWorkspaceSize;
using tensorrt_llm::plugins::MixtureOfExpertsPluginCreator;
using tensorrt_llm::plugins::MixtureOfExpertsPlugin;
using tensorrt_llm::plugins::read;
using tensorrt_llm::plugins::write;
using LoraImpl = tensorrt_llm::kernels::LoraImpl;
using LoraParams = tensorrt_llm::kernels::LoraParams;
static char const* MIXTURE_OF_EXPERTS_PLUGIN_VERSION{"1"};
static char const* MIXTURE_OF_EXPERTS_PLUGIN_NAME{"MixtureOfExperts"};
nvinfer1::PluginFieldCollection MixtureOfExpertsPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> MixtureOfExpertsPluginCreator::mPluginAttributes;
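// Primary constructor used by the plugin creator. The shape, type and parallelism parameters
// are captured as members; init() then validates the dtype combination and constructs the
// matching CUTLASS MoE runner.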
MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int number_of_experts, int experts_per_token,
int expert_hidden_size, int expert_inter_size, int groupwise_quant_algo, int group_size,
ActivationType activation_type, nvinfer1::DataType type, nvinfer1::DataType weight_type,
nvinfer1::DataType output_type, QuantMode quant_mode, bool use_final_scales, bool use_bias, int tp_size,
int tp_rank, int ep_size, int ep_rank, bool force_determinism, int side_stream_id,
MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
LoraPluginProfilerPtr lora_profiler, int max_low_rank)
: mNumExperts(number_of_experts)
, mExpertsPerToken(experts_per_token)
, mExpertHiddenSize(expert_hidden_size)
, mExpertInterSize(expert_inter_size)
, mGroupwiseQuantAlgo(groupwise_quant_algo)
, mGroupSize(group_size)
, mActivationType(activation_type)
, mType(type)
, mWeightType(weight_type)
, mOutputType(output_type)
, mQuantMode(quant_mode)
, mUseFinalScales(use_final_scales)
, mUseBias(use_bias)
, mParallelismConfig(MOEParallelismConfig{tp_size, tp_rank, ep_size, ep_rank})
, mUseDeterministicKernels(force_determinism)
, mSideStreamId(side_stream_id)
, mGemmProfiler(std::move(gemm_profiler_ptr))
, mUseLora(use_lora)
, mLoraType(lora_type)
, mMaxLowRank(max_low_rank)
, mRemoveInputPadding(remove_input_padding)
, mLoraProfiler(std::move(lora_profiler))
{
init();
}
tensorrt_llm::plugins::MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(MixtureOfExpertsPlugin const& other)
: mMOERunner()
, mNumExperts(other.mNumExperts)
, mExpertsPerToken(other.mExpertsPerToken)
, mExpertHiddenSize(other.mExpertHiddenSize)
, mExpertInterSize(other.mExpertInterSize)
, mGroupwiseQuantAlgo(other.mGroupwiseQuantAlgo)
, mGroupSize(other.mGroupSize)
, mActivationType(other.mActivationType)
, mType(other.mType)
, mWeightType(other.mWeightType)
, mOutputType(other.mOutputType)
, mQuantMode(other.mQuantMode)
, mUseFinalScales(other.mUseFinalScales)
, mUseBias(other.mUseBias)
, mParallelismConfig(other.mParallelismConfig)
, mDims(other.mDims)
, mUseDeterministicKernels(other.mUseDeterministicKernels)
, mSideStreamId(other.mSideStreamId)
, mGemmId1(other.mGemmId1)
, mGemmId2(other.mGemmId2)
, mGemmProfiler(other.mGemmProfiler)
, mUseLora(other.mUseLora)
, mLoraType(other.mLoraType)
, mMaxLowRank(other.mMaxLowRank)
, mRemoveInputPadding(other.mRemoveInputPadding)
, mLoraImpl1(other.mLoraImpl1)
, mLoraImpl2(other.mLoraImpl2)
, mLoraGemmId1(other.mLoraGemmId1)
, mLoraGemmId2(other.mLoraGemmId2)
, mLoraProfiler(other.mLoraProfiler)
, mLayerName(other.mLayerName)
, mNamespace(other.mNamespace)
{
init();
}
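// The reported size must cover every field written by serialize(), including the per-GEMM
// profiler state and, when LoRA is enabled, the two LoRA profiler entries.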
size_t MixtureOfExpertsPlugin::getSerializationSize() const noexcept
{
size_t size = sizeof(mRemoveInputPadding) + sizeof(mNumExperts) + sizeof(mExpertsPerToken)
+ sizeof(mExpertHiddenSize) + sizeof(mExpertInterSize) + sizeof(mGroupwiseQuantAlgo) + sizeof(mGroupSize)
+ sizeof(mActivationType) + sizeof(mType) + sizeof(mWeightType) + sizeof(mOutputType)
+ sizeof(QuantMode::BaseType) + sizeof(mUseFinalScales) + sizeof(mUseBias) + sizeof(mParallelismConfig)
+ sizeof(mDims) + sizeof(mUseDeterministicKernels) + sizeof(mSideStreamId)
+ mGemmProfiler->getSerializationSize(mGemmId1) + mGemmProfiler->getSerializationSize(mGemmId2)
+ sizeof(mUseLora) + sizeof(mLoraType) + sizeof(mMaxLowRank);
if (hasLora())
{
size += mLoraProfiler->getSerializationSize(mLoraGemmId1);
size += mLoraProfiler->getSerializationSize(mLoraGemmId2);
}
return size;
}
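// Deserializing constructor: fields are read back in the exact order they were written by
// serialize().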
MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(void const* data, size_t length,
MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, LoraPluginProfilerPtr lora_profiler)
: mGemmProfiler(gemm_profiler_ptr)
, mLoraProfiler(lora_profiler)
{
char const* d = reinterpret_cast<char const*>(data);
char const* a = d;
read(d, mRemoveInputPadding);
read(d, mNumExperts);
read(d, mExpertsPerToken);
read(d, mExpertHiddenSize);
read(d, mExpertInterSize);
read(d, mGroupwiseQuantAlgo);
read(d, mGroupSize);
read(d, mActivationType);
read(d, mType);
read(d, mWeightType);
read(d, mOutputType);
QuantMode::BaseType quant_mode;
read(d, quant_mode);
mQuantMode = QuantMode{quant_mode};
read(d, mUseFinalScales);
read(d, mUseBias);
read(d, mParallelismConfig);
read(d, mDims);
read(d, mUseDeterministicKernels);
read(d, mSideStreamId);
read(d, mUseLora);
read(d, mLoraType);
read(d, mMaxLowRank);
// Call init before deserialising the profiler to initialize mGemmId
init();
mGemmProfiler->deserialize(d, mDims, mGemmId1);
mGemmProfiler->deserialize(d, mDims, mGemmId2);
if (hasLora())
{
mLoraProfiler->deserialize(d, mDims, mLoraGemmId1);
mLoraProfiler->deserialize(d, mDims, mLoraGemmId2);
}
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT LLM version to build "
"engine and run engine.",
(int) length, (int) (d - a));
}
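// serialize() must mirror the read order of the deserializing constructor above and the byte
// count reported by getSerializationSize().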
void MixtureOfExpertsPlugin::serialize(void* buffer) const noexcept
{
char* d = static_cast<char*>(buffer);
char* a = d;
write(d, mRemoveInputPadding);
write(d, mNumExperts);
write(d, mExpertsPerToken);
write(d, mExpertHiddenSize);
write(d, mExpertInterSize);
write(d, mGroupwiseQuantAlgo);
write(d, mGroupSize);
write(d, mActivationType);
write(d, mType);
write(d, mWeightType);
write(d, mOutputType);
write(d, mQuantMode.value());
write(d, mUseFinalScales);
write(d, mUseBias);
write(d, mParallelismConfig);
write(d, mDims);
write(d, mUseDeterministicKernels);
write(d, mSideStreamId);
write(d, mUseLora);
write(d, mLoraType);
write(d, mMaxLowRank);
mGemmProfiler->serialize(d, mGemmId1);
mGemmProfiler->serialize(d, mGemmId2);
if (hasLora())
{
mLoraProfiler->serialize(d, mLoraGemmId1);
mLoraProfiler->serialize(d, mLoraGemmId2);
}
TLLM_CHECK(d == a + getSerializationSize());
}
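// Helper for the FP8/FP4 activation paths: selects the CutlassMoeFCRunner specialization that
// matches the requested output type. When NeedQuant is set (the FP4 path), the variant whose
// unquantized input type equals the output type is chosen, since the plugin input tensor is
// provided in the output dtype in that mode.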
template <typename Type, bool NeedQuant = false>
std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(nvinfer1::DataType output_type)
{
switch (output_type)
{
case nvinfer1::DataType::kFP4:
case nvinfer1::DataType::kFP8:
// TODO We need an atomic FP8 reduction for the finalize fusions
TLLM_THROW("Outputting %d directly is not currently supported", static_cast<int>(output_type));
// return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type>>();
case nvinfer1::DataType::kHALF:
if constexpr (NeedQuant)
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, half>>();
}
else
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, Type>>();
}
#ifdef ENABLE_BF16
case nvinfer1::DataType::kBF16:
if constexpr (NeedQuant)
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, __nv_bfloat16>>();
}
else
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, Type>>();
}
#endif
default: TLLM_THROW("Invalid output type %d", static_cast<int>(output_type));
}
}
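// Validates the activation/weight/output dtype combination against the detected GPU
// architecture, constructs the matching CUTLASS MoE runner, and initializes the GEMM ids,
// the GEMM profiler limits and, when LoRA is enabled, the two LoRA GEMM implementations.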
void MixtureOfExpertsPlugin::init()
{
TLLM_CHECK_WITH_INFO(mType == DataType::kFP8 || mType == DataType::kFP4 || mOutputType == mType,
"MOE plugin only supports a different output type for FP4/FP8");
TLLM_CHECK_WITH_INFO(mType != DataType::kFP8 || tensorrt_llm::common::getSMVersion() >= 89,
"MoE FP8 is not supported for architectures less than SM89");
TLLM_CHECK_WITH_INFO(mType != DataType::kFP4 || (tensorrt_llm::common::getSMVersion() >= 100),
"MoE FP4 is only supported on architecture SM100 or later");
    TLLM_CHECK_WITH_INFO(!hasLora() || mLoraType == mOutputType, "The LoRA type must match the MoE output type.");
if (mWeightType == nvinfer1::DataType::kINT8 && mQuantMode.hasInt4Weights())
{
mWeightType = DataType::kINT4;
}
if (mType == DataType::kHALF && mWeightType == DataType::kHALF)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<half, half>>();
}
else if (mType == DataType::kFLOAT && mWeightType == DataType::kFLOAT)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<float, float>>();
}
else if (mType == DataType::kHALF && mWeightType == DataType::kINT8)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<half, uint8_t>>();
}
else if (mType == DataType::kHALF && mWeightType == DataType::kINT4)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<half, cutlass::uint4b_t>>();
}
#ifdef ENABLE_FP8
else if (mType == DataType::kFP8 && mWeightType == DataType::kINT4 && mOutputType == DataType::kHALF)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>>();
}
#endif
#ifdef ENABLE_BF16
else if (mType == DataType::kBF16 && mWeightType == DataType::kBF16)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>>();
}
else if (mType == DataType::kBF16 && mWeightType == DataType::kINT8)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<__nv_bfloat16, uint8_t>>();
}
else if (mType == DataType::kBF16 && mWeightType == DataType::kINT4)
{
mMOERunner = std::make_unique<kernels::CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>>();
}
#ifdef ENABLE_FP8
else if (mType == DataType::kFP8 && mWeightType == DataType::kINT4 && mOutputType == DataType::kBF16)
{
mMOERunner = std::make_unique<
kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>>();
}
#endif
#endif
#ifdef ENABLE_FP8
if (mType == DataType::kFP8 && mWeightType == DataType::kFP8)
{
mMOERunner = switch_output_type<__nv_fp8_e4m3>(mOutputType);
}
#endif
#ifdef ENABLE_FP4
if (mType == DataType::kFP4 && mWeightType == DataType::kFP4)
{
mMOERunner = switch_output_type<__nv_fp4_e2m1, true>(mOutputType);
}
#endif
if (!mMOERunner)
{
TLLM_THROW(
"Could not construct the mixture of experts plugin with the requested input combination Activation: %d "
"Weight: %d Output: %d",
static_cast<int>(mType), static_cast<int>(mWeightType), static_cast<int>(mOutputType));
}
mMOERunner->use_deterministic_hopper_reduce_ = mExpertsPerToken > 2 && mUseDeterministicKernels;
mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_};
mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode, mMOERunner->use_deterministic_hopper_reduce_};
mGemmProfiler->setMaxProfileM(16384 * mNumExperts / mExpertsPerToken);
if (hasLora())
{
auto cublasHandle = getCublasHandle();
auto cublasLtHandle = getCublasLtHandle();
auto cublasWrapper = std::make_shared<CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
mLoraGemmId1 = GemmIdCublas(mExpertInterSize, mExpertHiddenSize, mLoraType, false, true, mLoraType);
mLoraGemmId2 = GemmIdCublas(mExpertHiddenSize, mExpertInterSize, mLoraType, false, true, mLoraType);
std::vector<int> loraOutSizes1 = {static_cast<int>(mExpertInterSize)};
mLoraImpl1 = std::make_shared<LoraImpl>(
mExpertHiddenSize, loraOutSizes1, false, true, 1, mLoraType, mMaxLowRank, cublasWrapper);
std::vector<int> loraOutSizes2 = {static_cast<int>(mExpertHiddenSize)};
mLoraImpl2 = std::make_shared<LoraImpl>(
mExpertInterSize, loraOutSizes2, false, true, 1, mLoraType, mMaxLowRank, cublasWrapper);
TLLM_CUDA_CHECK(cudaEventCreate(&mMemcpyEvent));
}
mSideStreamPtr = nullptr;
mDebugStallMain = tensorrt_llm::runtime::utils::stallStream("TLLM_DEBUG_MOE_STALL_MAIN");
mDebugStallSide = tensorrt_llm::runtime::utils::stallStream("TLLM_DEBUG_MOE_STALL_SIDE");
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* MixtureOfExpertsPlugin::clone() const noexcept
{
auto* plugin = new MixtureOfExpertsPlugin(*this);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
nvinfer1::DimsExprs MixtureOfExpertsPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
assert(outputIndex == getOutputTensorIndex() || outputIndex == getOutputDummyTensorIndex());
return inputs[getInputTensorIndex()];
}
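// Accepts only linear-format tensors and checks the dtype expected at each input/output
// position. Which positions exist depends on the quantization scheme, LoRA and side-stream
// configuration, so the checks are gated on the corresponding has*/use* helpers.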
bool MixtureOfExpertsPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
TLLM_CHECK(0 <= pos && pos < getNbInputs() + getNbOutputs());
TLLM_CHECK_WITH_INFO(
nbInputs == getNbInputs(), "Required input to plugin is missing. Expected %d Got %d", getNbInputs(), nbInputs);
TLLM_CHECK_WITH_INFO(nbOutputs == getNbOutputs(), "Required output to plugin is missing. Expected %d Got %d",
getNbOutputs(), nbOutputs);
if (inOut[pos].format != TensorFormat::kLINEAR)
{
return false;
}
if (pos == getExpertWeights1Index() || pos == getExpertWeights2Index())
{
if (mGroupwiseQuantAlgo == 0)
{
auto normalized_weight_type
= mWeightType == nvinfer1::DataType::kINT4 ? nvinfer1::DataType::kINT8 : mWeightType;
return inOut[pos].type == normalized_weight_type;
}
else
{
return inOut[pos].type == mOutputType;
}
}
else if (pos == getTokenSelectedExpertsIndex())
{
return inOut[pos].type == DataType::kINT32;
}
else if (pos == getTokenFinalScalesIndex())
{
return inOut[pos].type == DataType::kFLOAT;
}
else if (pos == getExpertBias1Index() || pos == getExpertBias2Index())
{
return inOut[pos].type == mOutputType;
}
else if (pos == nbInputs + getOutputTensorIndex())
{
return inOut[pos].type == mOutputType;
}
else if (useSideStream() && pos == nbInputs + getOutputDummyTensorIndex())
{
return inOut[pos].type == inOut[getInputDummyTensorIndex()].type;
}
else if (useSideStream() && pos == getInputDummyTensorIndex())
{
return true;
}
else if (hasExpertFp8QuantScales() && getExpertFP8Dequant1Index() <= pos && pos <= getExpertFP8QuantFinalIndex())
{
return inOut[pos].type == DataType::kFLOAT;
}
else if (hasExpertIntQuantScales() && getExpertIntQuantScale1Index() <= pos
&& pos <= getExpertIntQuantScale2Index())
{
return inOut[pos].type == mOutputType;
}
else if (hasFP4QuantScales() && getFP4GlobalActSF1Index() <= pos && pos <= getFP4GlobalSF2Index())
{
if (pos == getFP4WeightSF1Index() || pos == getFP4WeightSF2Index())
return inOut[pos].type == nvinfer1::DataType::kFP8;
else
return inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (hasLora() && hasExpertFp8QuantScales() && pos == getInputFP8DequantIndex())
{
return inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (hasExpertWeightQuantZeros() && getExpertIntQuantZeros1Index() <= pos
&& pos <= getExpertIntQuantZeros2Index())
{
return inOut[pos].type == mOutputType;
}
else if (hasExpertPrequantScales() && getExpertPrequantScales1Index() <= pos
&& pos <= getExpertPrequantScales2Index())
{
return inOut[pos].type == mOutputType;
}
else if (hasGroupwiseFp8Alpha() && getExpertFp8Alpha1Index() <= pos && pos <= getExpertFp8Alpha2Index())
{
return inOut[pos].type == DataType::kFLOAT;
}
else if (hasLora() && pos == getHostRequestTypeIndex())
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (hasLora() && (pos == getLoraFC1RanksIndex() || pos == getLoraFC2RanksIndex()))
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (hasGatedLoraWeightsAndRanks() && pos == getLoraGatedRanksIndex())
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (hasLora() && (pos == getLoraFC1WeightPtrsIndex() || pos == getLoraFC2WeightPtrsIndex()))
{
return inOut[pos].type == nvinfer1::DataType::kINT64;
}
else if (hasGatedLoraWeightsAndRanks() && pos == getLoraGatedWeightPtrsIndex())
{
return inOut[pos].type == nvinfer1::DataType::kINT64;
}
else if (hasLora() && mRemoveInputPadding && pos == getHostContextLengthIndex())
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if ((hasFP4QuantScales() || hasGroupwiseFp8Alpha()) && pos == getInputTensorIndex())
{
return inOut[pos].type == mOutputType;
}
else
{
return inOut[pos].type == mType;
}
return false;
}
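// Records the token-count (M) range seen at build time for the GEMM profilers, validates that
// the weight shapes match the hidden/intermediate sizes passed to the constructor, and
// refreshes the GEMM (and, with LoRA, cuBLAS GEMM) ids.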
void MixtureOfExpertsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
TLLM_CHECK_WITH_INFO(
nbInputs == getNbInputs(), "Required input to plugin is missing. Expected %d Got %d", getNbInputs(), nbInputs);
TLLM_CHECK_WITH_INFO(nbOutputs == getNbOutputs(), "Required output to plugin is missing. Expected %d Got %d",
getNbOutputs(), nbOutputs);
auto in_tensor = in[getInputTensorIndex()];
auto const minM
= std::accumulate(in_tensor.min.d, in_tensor.min.d + in_tensor.min.nbDims - 1, 1, std::multiplies<int>());
auto const maxM
= std::accumulate(in_tensor.max.d, in_tensor.max.d + in_tensor.max.nbDims - 1, 1, std::multiplies<int>());
auto weights_1 = in[getExpertWeights1Index()];
auto weights_2 = in[getExpertWeights2Index()];
int inner_dim_idx = getGemmShapeInnerDimIndex();
int const maxK = weights_1.max.d[inner_dim_idx];
int const maxN = weights_2.max.d[inner_dim_idx];
int const minK = weights_1.min.d[inner_dim_idx];
int const minN = weights_2.min.d[inner_dim_idx];
TLLM_CHECK_WITH_INFO(minN == maxN, "Variable out channels is not allowed");
TLLM_CHECK_WITH_INFO(minK == maxK, "Variable in channels is not allowed");
TLLM_CHECK_WITH_INFO(maxK == mExpertHiddenSize && maxN == mExpertInterSize,
"Configured tensor sizes %dx%d does not match constructor param size %ldx%ld", maxK, maxN, mExpertHiddenSize,
mExpertInterSize);
if (!mDims.isInitialized())
{
mDims = {minM, maxM, maxN, maxK};
}
mGemmId1 = GemmIDMoe{1, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode};
mGemmId2 = GemmIDMoe{2, mNumExperts, mExpertsPerToken, mParallelismConfig, mExpertHiddenSize, mExpertInterSize,
mGroupSize, mActivationType, mType, mWeightType, mQuantMode};
if (hasLora())
{
auto const N = utils::computeNDimension(true, in[getHostRequestTypeIndex()].max);
mLoraGemmId1 = GemmIdCublas(N, mExpertHiddenSize, mLoraType, false, true, mLoraType);
mLoraGemmId2 = GemmIdCublas(N, mExpertInterSize, mLoraType, false, true, mLoraType);
}
}
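// Computes and, when base_ptr is provided, partitions the workspace. The MoE runner workspace,
// the expanded source-to-destination permutation map and the LoRA workspace are laid out
// contiguously in that order.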
auto MixtureOfExpertsPlugin::setupWorkspace(void* base_ptr, int64_t num_tokens, int num_reqs) const -> WorkspaceInfo
{
size_t moe_workspace_size
= mMOERunner->getWorkspaceSize(num_tokens, mExpertHiddenSize, mExpertInterSize, mNumExperts, mExpertsPerToken,
mActivationType, mParallelismConfig, hasLora(), /*use_deepseek_fp8_block_scale=*/false,
/*min_latency_mode=*/false, hasExpertPrequantScales());
// Permutation map
size_t src_to_dest_map_size = mExpertsPerToken * num_tokens * sizeof(int);
size_t lora_workspace_size = 0;
if (hasLora())
{
int64_t num_reqs_lora = std::min(num_tokens * mExpertsPerToken, static_cast<int64_t>(num_reqs * mNumExperts));
lora_workspace_size
= std::max(mLoraImpl1->getWorkspaceSize(num_tokens * mExpertsPerToken, num_reqs_lora, mLoraType),
mLoraImpl2->getWorkspaceSize(num_tokens * mExpertsPerToken, num_reqs_lora, mLoraType));
}
std::vector<size_t> workspaces{
moe_workspace_size,
src_to_dest_map_size,
lora_workspace_size,
};
WorkspaceInfo info{};
info.size = calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
if (base_ptr)
{
info.workspace = base_ptr;
info.src_to_dest_map = nextWorkspacePtr((int8_t*) info.workspace, moe_workspace_size);
info.lora_workspace = nextWorkspacePtr((int8_t*) info.src_to_dest_map, src_to_dest_map_size);
}
return info;
}
int64_t MixtureOfExpertsPlugin::getNumTokens(nvinfer1::PluginTensorDesc const* input_tensors) const
{
int ndim = input_tensors[getInputTensorIndex()].dims.nbDims;
TLLM_CHECK_WITH_INFO(
3 == ndim || 2 == ndim, "hidden_state dimension should be either 2 [b*s, hidden], or 3 [b, s, hidden]");
int64_t num_tokens = input_tensors[getInputTensorIndex()].dims.d[0];
if (ndim == 3)
{
num_tokens *= input_tensors[getInputTensorIndex()].dims.d[1];
}
return num_tokens;
}
size_t MixtureOfExpertsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
TLLM_CHECK_WITH_INFO(
nbInputs == getNbInputs(), "Required input to plugin is missing. Expected %d Got %d", getNbInputs(), nbInputs);
TLLM_CHECK_WITH_INFO(nbOutputs == getNbOutputs(), "Required output to plugin is missing. Expected %d Got %d",
getNbOutputs(), nbOutputs);
if (useSideStream())
{
return 0;
}
int const num_tokens = getNumTokens(inputs);
int const num_lora_reqs = getNumLoraRequests(inputs);
return setupWorkspace(nullptr, num_tokens, num_lora_reqs).size;
}
MOEParallelismConfig MixtureOfExpertsPlugin::getParallelismConfig() const
{
return mParallelismConfig;
}
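// Builds the QuantParams for the active quantization scheme. The meaning of each scale_N_idx
// depends on the scheme (per-channel int, groupwise int, FP8 or FP4); an index of -1 marks a
// scale tensor that is not present. enqueue() supplies the index mapping for each scheme.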
QuantParams tensorrt_llm::plugins::MixtureOfExpertsPlugin::getQuantParams(nvinfer1::PluginTensorDesc const* inputDesc,
void const* const* inputs, int scale_1_idx, int scale_2_idx, int scale_3_idx, int scale_4_idx, int scale_5_idx,
int scale_6_idx, int scale_7_idx, int scale_8_idx) const
{
void const* scale_1 = scale_1_idx >= 0 ? inputs[scale_1_idx] : nullptr;
void const* scale_2 = scale_2_idx >= 0 ? inputs[scale_2_idx] : nullptr;
void const* scale_3 = scale_3_idx >= 0 ? inputs[scale_3_idx] : nullptr;
void const* scale_4 = scale_4_idx >= 0 ? inputs[scale_4_idx] : nullptr;
void const* scale_5 = scale_5_idx >= 0 ? inputs[scale_5_idx] : nullptr;
void const* scale_6 = scale_6_idx >= 0 ? inputs[scale_6_idx] : nullptr;
void const* scale_7 = scale_7_idx >= 0 ? inputs[scale_7_idx] : nullptr;
void const* scale_8 = scale_8_idx >= 0 ? inputs[scale_8_idx] : nullptr;
nvinfer1::PluginTensorDesc const* desc_1 = scale_1_idx >= 0 ? &inputDesc[scale_1_idx] : nullptr;
nvinfer1::PluginTensorDesc const* desc_2 = scale_2_idx >= 0 ? &inputDesc[scale_2_idx] : nullptr;
nvinfer1::PluginTensorDesc const* desc_3 = scale_3_idx >= 0 ? &inputDesc[scale_3_idx] : nullptr;
nvinfer1::PluginTensorDesc const* desc_4 = scale_4_idx >= 0 ? &inputDesc[scale_4_idx] : nullptr;
nvinfer1::PluginTensorDesc const* desc_5 = scale_5_idx >= 0 ? &inputDesc[scale_5_idx] : nullptr;
nvinfer1::PluginTensorDesc const* desc_6 = scale_6_idx >= 0 ? &inputDesc[scale_6_idx] : nullptr;
auto const gated_inter_size = isGatedActivation(mActivationType) ? mExpertInterSize * 2 : mExpertInterSize;
auto const experts_per_node = mNumExperts / mParallelismConfig.ep_size;
if (hasExpertIntQuantScales())
{
TLLM_CHECK(scale_1 && scale_2);
if (!hasGroupwiseIntQuantScales())
{
TLLM_CHECK(!scale_3 && !scale_4 && !scale_5 && !scale_6);
TLLM_CHECK(desc_1->dims.nbDims == 2);
TLLM_CHECK(desc_2->dims.nbDims == 2);
TLLM_CHECK_WITH_INFO(
desc_1->dims.d[0] == experts_per_node, "Incorrect number of experts in int quant scale");
TLLM_CHECK(desc_1->dims.d[1] == gated_inter_size);
TLLM_CHECK_WITH_INFO(
desc_2->dims.d[0] == experts_per_node, "Incorrect number of experts in int quant scale");
TLLM_CHECK(desc_2->dims.d[1] == mExpertHiddenSize);
return QuantParams::Int(scale_1, scale_2);
}
else
{
TLLM_CHECK(desc_1->dims.nbDims == 3);
TLLM_CHECK(desc_2->dims.nbDims == 3);
TLLM_CHECK((scale_3 && scale_4) || !hasExpertPrequantScales());
TLLM_CHECK((scale_5 && scale_6) || !hasExpertWeightQuantZeros());
TLLM_CHECK((scale_7 && scale_8) || !hasGroupwiseFp8Alpha());
return QuantParams::GroupWise(mGroupSize, scale_1, scale_2, scale_3, scale_4, scale_5, scale_6,
static_cast<float const*>(scale_7), static_cast<float const*>(scale_8));
}
}
else if (hasExpertFp8QuantScales())
{
TLLM_CHECK(scale_1 && scale_2 && scale_3);
TLLM_CHECK(scale_4 || !hasExpertFp8FinalQuantScales());
TLLM_CHECK((scale_5 != nullptr) == hasLora());
TLLM_CHECK(!scale_6);
TLLM_CHECK(desc_1->dims.nbDims == 2);
TLLM_CHECK(desc_2->dims.nbDims == 1);
TLLM_CHECK(desc_3->dims.nbDims == 2);
TLLM_CHECK_WITH_INFO(
desc_1->dims.d[0] == experts_per_node && desc_1->dims.d[1] == 1, "Incorrect shape for weight FP8 scale");
TLLM_CHECK(desc_2->dims.d[0] == 1);
TLLM_CHECK_WITH_INFO(
desc_3->dims.d[0] == experts_per_node && desc_3->dims.d[1] == 1, "Incorrect shape for weight FP8 scale");
return QuantParams::FP8(static_cast<float const*>(scale_1), static_cast<float const*>(scale_2),
static_cast<float const*>(scale_3), static_cast<float const*>(scale_4), static_cast<float const*>(scale_5));
}
else if (hasFP4QuantScales())
{
TLLM_CHECK(scale_1 && scale_2 && scale_3 && scale_4 && scale_5 && scale_6);
TLLM_CHECK(desc_1->dims.nbDims == 1);
TLLM_CHECK(desc_2->dims.nbDims == 3);
TLLM_CHECK(desc_3->dims.nbDims == 1);
TLLM_CHECK(desc_4->dims.nbDims == 1);
TLLM_CHECK(desc_5->dims.nbDims == 3);
TLLM_CHECK(desc_6->dims.nbDims == 1);
TLLM_CHECK(desc_1->dims.d[0] == 1);
TLLM_CHECK_WITH_INFO(desc_2->dims.d[0] == experts_per_node && desc_2->dims.d[1] == gated_inter_size
&& desc_2->dims.d[2]
== mExpertHiddenSize / TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize,
"Incorrect shape for FP4 scale");
TLLM_CHECK_WITH_INFO(desc_3->dims.d[0] == experts_per_node, "Incorrect shape for FP4 scale");
TLLM_CHECK(desc_4->dims.d[0] == 1);
TLLM_CHECK_WITH_INFO(desc_5->dims.d[0] == experts_per_node && desc_5->dims.d[1] == mExpertHiddenSize
&& desc_5->dims.d[2]
== mExpertInterSize / TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize,
"Incorrect shape for FP4 scale");
TLLM_CHECK_WITH_INFO(desc_6->dims.d[0] == experts_per_node, "Incorrect shape for FP4 scale");
return QuantParams::FP4(static_cast<float const*>(scale_1),
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF const*>(scale_2),
static_cast<float const*>(scale_3), static_cast<float const*>(scale_4),
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF const*>(scale_5),
static_cast<float const*>(scale_6));
}
return {};
}
int MixtureOfExpertsPlugin::getNumLoraRequests(nvinfer1::PluginTensorDesc const* input_tensors) const
{
if (!hasLora())
return 0;
int num_reqs = input_tensors[getLoraFC1RanksIndex()].dims.d[0];
return num_reqs;
}
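// Expands the per-request LoRA weight pointers and ranks into the per-token arrays consumed by
// the MoE runner: a generation request contributes one token, a context request contributes one
// entry per token of its context length. Each module provides three pointers (A, B and an
// optional DoRA magnitude); only A and B are used here.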
LoraParams MixtureOfExpertsPlugin::getLoraParams(
nvinfer1::PluginTensorDesc const* inputDesc, void const* const* inputs, void* workspace)
{
TLLM_CHECK(hasLora());
int const num_reqs = getNumLoraRequests(inputDesc);
int64_t const num_tokens = getNumTokens(inputDesc);
    bool is_gated_activation = isGatedActivation(mActivationType);
mLoraExpandFC1WeightPtrs.clear();
mLoraExpandFC2WeightPtrs.clear();
mLoraExpandFC1Ranks.clear();
mLoraExpandFC2Ranks.clear();
mLoraExpandFC1WeightPtrs.reserve(num_tokens * 2);
mLoraExpandFC2WeightPtrs.reserve(num_tokens * 2);
mLoraExpandFC1Ranks.reserve(num_tokens);
mLoraExpandFC2Ranks.reserve(num_tokens);
    if (is_gated_activation)
{
mLoraExpandGatedWeightPtrs.clear();
mLoraExpandGatedRanks.clear();
mLoraExpandGatedWeightPtrs.reserve(num_tokens * 2);
mLoraExpandGatedRanks.reserve(num_tokens);
}
int const seq_len = mRemoveInputPadding ? 0 : inputDesc[getInputTensorIndex()].dims.d[1];
int32_t const* req_types = static_cast<int32_t const*>(inputs[getHostRequestTypeIndex()]);
int32_t const* host_context_lens
= mRemoveInputPadding ? static_cast<int32_t const*>(inputs[getHostContextLengthIndex()]) : nullptr;
auto const fc1_lora_weight_ptrs = static_cast<void const* const*>(inputs[getLoraFC1WeightPtrsIndex()]);
auto const fc1_lora_ranks = static_cast<int32_t const*>(inputs[getLoraFC1RanksIndex()]);
auto const fc2_lora_weight_ptrs = static_cast<void const* const*>(inputs[getLoraFC2WeightPtrsIndex()]);
auto const fc2_lora_ranks = static_cast<int32_t const*>(inputs[getLoraFC2RanksIndex()]);
    auto const gated_lora_weight_ptrs
        = is_gated_activation ? static_cast<void const* const*>(inputs[getLoraGatedWeightPtrsIndex()]) : nullptr;
    auto const gated_lora_ranks
        = is_gated_activation ? static_cast<int32_t const*>(inputs[getLoraGatedRanksIndex()]) : nullptr;
int idx = 0;
for (int req_id = 0; req_id < num_reqs; req_id++)
{
RequestType const reqType = static_cast<RequestType const>(req_types[req_id]);
if (reqType == RequestType::kGENERATION)
{
// lora_weight_ptrs has 3 pointers for each module: A,B, and an optional DoRA magnitude
// the current LoRA implementation does not apply DoRA scaling, so the magnitude is ignored
mLoraExpandFC1WeightPtrs.push_back(fc1_lora_weight_ptrs[req_id * 3]);
mLoraExpandFC1WeightPtrs.push_back(fc1_lora_weight_ptrs[req_id * 3 + 1]);
mLoraExpandFC1Ranks.push_back(fc1_lora_ranks[req_id]);
mLoraExpandFC2WeightPtrs.push_back(fc2_lora_weight_ptrs[req_id * 3]);
mLoraExpandFC2WeightPtrs.push_back(fc2_lora_weight_ptrs[req_id * 3 + 1]);
mLoraExpandFC2Ranks.push_back(fc2_lora_ranks[req_id]);
            if (is_gated_activation)
{
mLoraExpandGatedWeightPtrs.push_back(gated_lora_weight_ptrs[req_id * 3]);
mLoraExpandGatedWeightPtrs.push_back(gated_lora_weight_ptrs[req_id * 3 + 1]);
mLoraExpandGatedRanks.push_back(gated_lora_ranks[req_id]);
}
idx += 1;
}
else
{
int context_len = (mRemoveInputPadding ? host_context_lens[req_id] : seq_len);
for (int context_id = 0; context_id < context_len; context_id++)
{
mLoraExpandFC1WeightPtrs.push_back(fc1_lora_weight_ptrs[req_id * 3]);
mLoraExpandFC1WeightPtrs.push_back(fc1_lora_weight_ptrs[req_id * 3 + 1]);
mLoraExpandFC1Ranks.push_back(fc1_lora_ranks[req_id]);
mLoraExpandFC2WeightPtrs.push_back(fc2_lora_weight_ptrs[req_id * 3]);
mLoraExpandFC2WeightPtrs.push_back(fc2_lora_weight_ptrs[req_id * 3 + 1]);
mLoraExpandFC2Ranks.push_back(fc2_lora_ranks[req_id]);
                if (is_gated_activation)
{
mLoraExpandGatedWeightPtrs.push_back(gated_lora_weight_ptrs[req_id * 3]);
mLoraExpandGatedWeightPtrs.push_back(gated_lora_weight_ptrs[req_id * 3 + 1]);
mLoraExpandGatedRanks.push_back(gated_lora_ranks[req_id]);
}
}
idx += context_len;
}
}
TLLM_CHECK_WITH_INFO(idx == num_tokens, fmtstr("idx %d num_tokens %ld", idx, num_tokens));
return LoraParams(num_reqs, mLoraExpandFC1Ranks.data(), mLoraExpandFC1WeightPtrs.data(), mLoraExpandFC2Ranks.data(),
mLoraExpandFC2WeightPtrs.data(), mLoraImpl1, mLoraImpl2, workspace, &mMemcpyEvent, mLoraExpandGatedRanks.data(),
mLoraExpandGatedWeightPtrs.data());
}
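// Runs the MoE layer: optionally moves execution to the managed side stream, partitions the
// workspace, validates the weight shapes, gathers the quantization and LoRA parameters, selects
// the profiled GEMM tactics (or the first tactic when determinism is forced) and dispatches
// runMoe() on the chosen stream.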
int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace_ptr,
cudaStream_t stream) noexcept
{
if (isBuilding())
{
return 0;
}
int64_t const num_tokens = getNumTokens(inputDesc);
int64_t const num_reqs = getNumLoraRequests(inputDesc);
if (useSideStream())
{
// Prepare the side stream
if (!mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
nvinfer1::pluginInternal::SideStream side_stream{};
mSideStreamPtr = reinterpret_cast<nvinfer1::pluginInternal::SideStream*>(
getPluginRegistry()->acquirePluginResource(resource_name.c_str(), &side_stream));
}
// Debug the code with the main stream stalled (only executed when the environment variable
// TLLM_DEBUG_MOE_STALL_MAIN is set and has a positive value)
mSideStreamPtr->stallMainStream("TLLM_DEBUG_MOE_STALL_MAIN", stream, mDebugStallMain);
// The side stream waits for the inputs managed by the main stream to be ready
mSideStreamPtr->waitMainStreamOnSideStream(stream);
// Provide data dependency for the shared experts running after this plugin by copying inputs on the main stream
size_t count = 1;
for (int i = 0; i < inputDesc[getInputDummyTensorIndex()].dims.nbDims; ++i)
{
count *= inputDesc[getInputDummyTensorIndex()].dims.d[i];
}
count *= tensorrt_llm::runtime::BufferDataType(inputDesc[getInputDummyTensorIndex()].type).getSize();
TLLM_CUDA_CHECK(cudaMemcpyAsync(outputs[getOutputDummyTensorIndex()], inputs[getInputDummyTensorIndex()], count,
cudaMemcpyDeviceToDevice, stream));
// Switch from the main stream to the side stream
stream = mSideStreamPtr->getStream();
// The workspace is managed by the side stream (otherwise, the lifetime of workspace may be incorrect)
auto const workspace_size = setupWorkspace(nullptr, num_tokens, num_reqs).size;
workspace_ptr = mSideStreamPtr->getWorkspacePtr(workspace_size);
}
auto workspace = setupWorkspace(workspace_ptr, num_tokens, num_reqs);
auto w1_desc = inputDesc[getExpertWeights1Index()];
auto w2_desc = inputDesc[getExpertWeights2Index()];
TLLM_CHECK(w1_desc.dims.nbDims == 3);
auto const experts_per_node = mNumExperts / mParallelismConfig.ep_size;
TLLM_CHECK(w1_desc.dims.d[0] == experts_per_node);
TLLM_CHECK(w2_desc.dims.nbDims == 3);
TLLM_CHECK(w2_desc.dims.d[0] == experts_per_node);
auto [inner_packed_elements, outer_packed_elements] = getWeightPackedElements();
int inner_dim_idx = getGemmShapeInnerDimIndex();
int outer_dim_idx = getGemmShapeOuterDimIndex();
TLLM_CHECK(w1_desc.dims.d[inner_dim_idx] * inner_packed_elements == mExpertHiddenSize);
if (isGatedActivation(mActivationType))
{
TLLM_CHECK(w1_desc.dims.d[outer_dim_idx] * outer_packed_elements == mExpertInterSize * 2);
}
else
{
TLLM_CHECK(w1_desc.dims.d[outer_dim_idx] * outer_packed_elements == mExpertInterSize);
}
TLLM_CHECK(w2_desc.dims.d[inner_dim_idx] * inner_packed_elements == mExpertInterSize);
TLLM_CHECK(w2_desc.dims.d[outer_dim_idx] * outer_packed_elements == mExpertHiddenSize);
QuantParams quant_params{};
if (hasExpertIntQuantScales())
{
if (mGroupSize > 0)
{
quant_params = getQuantParams(inputDesc, inputs, getExpertIntQuantScale1Index(),
getExpertIntQuantScale2Index(), hasExpertPrequantScales() ? getExpertPrequantScales1Index() : -1,
hasExpertPrequantScales() ? getExpertPrequantScales2Index() : -1,
hasExpertWeightQuantZeros() ? getExpertIntQuantZeros1Index() : -1,
hasExpertWeightQuantZeros() ? getExpertIntQuantZeros2Index() : -1,
hasGroupwiseFp8Alpha() ? getExpertFp8Alpha1Index() : -1,
hasGroupwiseFp8Alpha() ? getExpertFp8Alpha2Index() : -1);
}
else
{
quant_params
= getQuantParams(inputDesc, inputs, getExpertIntQuantScale1Index(), getExpertIntQuantScale2Index());
}
}
else if (hasExpertFp8QuantScales())
{
quant_params = getQuantParams(inputDesc, inputs, //
getExpertFP8Dequant1Index(), //
getExpertFP8Quant2Index(), //
getExpertFP8Dequant2Index(), //
hasExpertFp8FinalQuantScales() ? getExpertFP8QuantFinalIndex() : -1,
hasLora() ? getInputFP8DequantIndex() : -1);
}
else if (hasFP4QuantScales())
{
quant_params = getQuantParams(inputDesc, inputs, //
getFP4GlobalActSF1Index(), //
getFP4WeightSF1Index(), //
getFP4GlobalSF1Index(), //
getFP4GlobalActSF2Index(), //
getFP4WeightSF2Index(), //
getFP4GlobalSF2Index() //
);
}
LoraParams lora_params{};
if (hasLora())
{
lora_params = getLoraParams(inputDesc, inputs, workspace.lora_workspace);
auto lora_gemm1 = mLoraProfiler->getBestConfig(num_tokens, mLoraGemmId1);
auto lora_gemm2 = mLoraProfiler->getBestConfig(num_tokens, mLoraGemmId2);
mLoraImpl1->setBestTactic(lora_gemm1);
mLoraImpl2->setBestTactic(lora_gemm2);
}
std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> gemm1;
std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> gemm2;
if (common::getEnvForceDeterministicMOE())
{
gemm1 = mMOERunner->getTactics()[0];
gemm2 = mMOERunner->getTactics()[0];
}
else
{
gemm1 = mGemmProfiler->getBestConfig(num_tokens, mGemmId1);
gemm2 = mGemmProfiler->getBestConfig(num_tokens, mGemmId2);
}
MoeMinLatencyParams min_latency_params{};
mMOERunner->setTactic(gemm1, gemm2);
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr,
static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]),
hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType,
inputs[getExpertWeights2Index()], hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens,
mExpertHiddenSize, mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace),
// Outputs
outputs[getOutputTensorIndex()], static_cast<int*>(workspace.src_to_dest_map), mParallelismConfig,
/*enable_alltoall=*/false, hasLora(), lora_params, /*use_deepseek_fp8_block_scale=*/false,
/*min_latency_mode=*/false, min_latency_params, stream);
#else
mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr,
static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]),
hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType,
inputs[getExpertWeights2Index()], hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens,
mExpertHiddenSize, mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace),
// Outputs
outputs[getOutputTensorIndex()], static_cast<int*>(workspace.src_to_dest_map), mParallelismConfig, hasLora(),
lora_params, /*use_deepseek_fp8_block_scale=*/false,
/*min_latency_mode=*/false, min_latency_params, stream);
#endif
if (useSideStream())
{
// Debug the code with the side stream stalled (only executed when the environment variable
// TLLM_DEBUG_MOE_STALL_SIDE is set and has a positive value)
mSideStreamPtr->stallSideStream("TLLM_DEBUG_MOE_STALL_SIDE", mDebugStallSide);
}
return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType MixtureOfExpertsPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index == getOutputTensorIndex() || index == getOutputDummyTensorIndex());
if (useSideStream() && index == getOutputDummyTensorIndex())
{
return inputTypes[getInputDummyTensorIndex()];
}
return mOutputType;
}
// IPluginV2 Methods
char const* MixtureOfExpertsPlugin::getPluginType() const noexcept
{
return MIXTURE_OF_EXPERTS_PLUGIN_NAME;
}
char const* MixtureOfExpertsPlugin::getPluginVersion() const noexcept
{
return MIXTURE_OF_EXPERTS_PLUGIN_VERSION;
}
int MixtureOfExpertsPlugin::initialize() noexcept
{
mGemmProfiler->setGemmToProfile(kernels::GemmProfilerBackend::GemmToProfile::GEMM_1);
mGemmProfiler->profileTactics(this, mType, mDims, mGemmId1);
mGemmProfiler->setGemmToProfile(kernels::GemmProfilerBackend::GemmToProfile::GEMM_2);
mGemmProfiler->profileTactics(this, mType, mDims, mGemmId2);
if (hasLora())
{
mLoraImpl1->setGemmConfig();
mLoraImpl2->setGemmConfig();
mLoraProfiler->profileTactics(mLoraImpl1->getCublasWrapper(), mType, mDims, mLoraGemmId1);
mLoraProfiler->profileTactics(mLoraImpl2->getCublasWrapper(), mType, mDims, mLoraGemmId2);
}
return 0;
}
void MixtureOfExpertsPlugin::terminate() noexcept
{
if (mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
getPluginRegistry()->releasePluginResource(resource_name.c_str());
mSideStreamPtr = nullptr;
}
}
void MixtureOfExpertsPlugin::destroy() noexcept
{
if (hasLora())
{
TLLM_CUDA_CHECK(cudaEventDestroy(mMemcpyEvent));
}
// This gets called when the network containing plugin is destroyed
delete this;
}
void MixtureOfExpertsPlugin::setPluginNamespace(char const* libNamespace) noexcept
{
mNamespace = libNamespace;
}
char const* MixtureOfExpertsPlugin::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
///////////////
char const* MixtureOfExpertsPluginCreator::getPluginName() const noexcept
{
return MIXTURE_OF_EXPERTS_PLUGIN_NAME;
}
char const* MixtureOfExpertsPluginCreator::getPluginVersion() const noexcept
{
return MIXTURE_OF_EXPERTS_PLUGIN_VERSION;
}
nvinfer1::PluginFieldCollection const* MixtureOfExpertsPluginCreator::getFieldNames() noexcept
{
return &mFC;
}
MixtureOfExpertsPluginCreator::MixtureOfExpertsPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("remove_input_padding", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("number_of_experts", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("experts_per_token", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("expert_hidden_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("expert_inter_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("groupwise_quant_algo", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("group_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("activation_type", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("type_id", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("weight_type_id", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("quant_mode", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("use_final_scales", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("use_bias", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("tp_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("tp_rank", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("ep_size", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("ep_rank", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("side_stream_id", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("use_lora", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("lora_type_id", nullptr, PluginFieldType::kINT32));
mPluginAttributes.emplace_back(nvinfer1::PluginField("max_low_rank", nullptr, PluginFieldType::kINT32));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
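// Parses the PluginFieldCollection into constructor arguments. Required fields must be present
// exactly once; optional fields (use_bias, output_type_id, force_determinism, side_stream_id,
// lora_type_id, max_low_rank) fall back to the defaults declared below.
//
// A minimal sketch of how a caller might populate the collection (field names as registered in
// MixtureOfExpertsPluginCreator(); values are illustrative only):
//
//   int32_t numExperts = 8, topK = 2;
//   std::vector<nvinfer1::PluginField> fields{
//       {"number_of_experts", &numExperts, nvinfer1::PluginFieldType::kINT32, 1},
//       {"experts_per_token", &topK, nvinfer1::PluginFieldType::kINT32, 1},
//       // ... remaining required fields ...
//   };
//   nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};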
IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept
{
nvinfer1::PluginField const* fields = fc->fields;
int mRemoveInputPadding{};
int mNumExperts{};
int mExpertsPerToken{};
int mExpertHiddenSize{};
int mExpertInterSize{};
int mGroupwiseQuantAlgo{};
int mGroupSize{};
int mActivationType{};
int mType{};
int mWeightType{};
int mOutputType{INT_MAX};
int mQuantMode{};
int mUseFinalScales{1}; // Default to true
int mUseBias{0};
int mTPSize{};
int mTPRank{};
int mEPSize{};
int mEPRank{};
int mRequiresDeterminism{0};
int mSideStreamId{0};
int mUseLora{};
int mLoraType{INT_MAX};
int mMaxLowRank{0};
    // Read the configuration from each field
struct MapPair
{
char const* key;
int& field;
bool optional = false;
bool set = false;
};
std::array input_map{
MapPair{"remove_input_padding", std::ref(mRemoveInputPadding)},
MapPair{"number_of_experts", std::ref(mNumExperts)},
MapPair{"experts_per_token", std::ref(mExpertsPerToken)},
MapPair{"expert_hidden_size", std::ref(mExpertHiddenSize)},
MapPair{"expert_inter_size", std::ref(mExpertInterSize)},
MapPair{"groupwise_quant_algo", std::ref(mGroupwiseQuantAlgo)},
MapPair{"group_size", std::ref(mGroupSize)},
MapPair{"activation_type", std::ref(mActivationType)},
MapPair{"type_id", std::ref(mType)},
MapPair{"weight_type_id", std::ref(mWeightType)},
MapPair{"quant_mode", std::ref(mQuantMode)},
MapPair{"tp_size", std::ref(mTPSize)},
MapPair{"tp_rank", std::ref(mTPRank)},
MapPair{"ep_size", std::ref(mEPSize)},
MapPair{"ep_rank", std::ref(mEPRank)},
MapPair{"use_lora", std::ref(mUseLora)},
MapPair{"use_final_scales", std::ref(mUseFinalScales)},
// Optional
MapPair{"use_bias", std::ref(mUseBias), true},
MapPair{"output_type_id", std::ref(mOutputType), true},
MapPair{"force_determinism", std::ref(mRequiresDeterminism), true},
MapPair{"side_stream_id", std::ref(mSideStreamId), true},
MapPair{"lora_type_id", std::ref(mLoraType), true},
MapPair{"max_low_rank", std::ref(mMaxLowRank), true},
};
for (int i = 0; i < fc->nbFields; ++i)
{
char const* attrName = fields[i].name;
for (auto& item : input_map)
{
if (!strcmp(item.key, attrName))
{
TLLM_CHECK(fields[i].type == nvinfer1::PluginFieldType::kINT32);
TLLM_CHECK_WITH_INFO(!item.set, "Parameter %s was set twice", item.key);
item.field = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
item.set = true;
}
}
}
for (auto& item : input_map)
{
TLLM_CHECK_WITH_INFO(item.set || item.optional, "Parameter %s is required but not set", item.key);
}
    // Output type is optional; if it was not set, default it to mType
if (mOutputType == INT_MAX)
{
mOutputType = mType;
}
if (mUseLora)
{
        TLLM_CHECK_WITH_INFO(mLoraType != INT_MAX && mMaxLowRank != 0,
            "MoE LoRA fusion is enabled, so lora_type_id and max_low_rank are required but were not set");
}
try
{
auto gemmProfiler = moePluginProfiler.createGemmPluginProfiler(/* inference */ false);
auto loraProfiler = loraPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
auto* obj = new MixtureOfExpertsPlugin(
// Constructor parameters
mRemoveInputPadding, mNumExperts, mExpertsPerToken, mExpertHiddenSize, mExpertInterSize,
mGroupwiseQuantAlgo, mGroupSize, static_cast<ActivationType>(mActivationType),
static_cast<nvinfer1::DataType>(mType), static_cast<nvinfer1::DataType>(mWeightType),
static_cast<nvinfer1::DataType>(mOutputType), QuantMode(mQuantMode), mUseFinalScales != 0, mUseBias != 0,
mTPSize, mTPRank, mEPSize, mEPRank, mRequiresDeterminism != 0, mSideStreamId, gemmProfiler, mUseLora != 0,
static_cast<nvinfer1::DataType>(mLoraType), loraProfiler, mMaxLowRank);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
IPluginV2* MixtureOfExpertsPluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call MixtureOfExpertsPlugin::destroy()
try
{
auto gemmProfiler = moePluginProfiler.createGemmPluginProfiler(/* inference */ true);
auto loraProfiler = loraPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
auto* obj = new MixtureOfExpertsPlugin(
// Constructor parameters
serialData, serialLength, gemmProfiler, loraProfiler);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
void MixtureOfExpertsPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{
mNamespace = libNamespace;
}
char const* MixtureOfExpertsPluginCreator::getPluginNamespace() const noexcept
{
return mNamespace.c_str();
}
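// GEMM profiler hooks: workspace sizing, tactic enumeration and tactic execution are delegated
// to the GemmProfilerBackend, which checkInit() lazily initializes from the owning plugin's
// configuration.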
void MixtureOfExpertsGemmProfiler::computeTmpSize(size_t maxM, size_t n, size_t k)
{
checkInit();
size_t bytes = backend.getWorkspaceSize(maxM);
this->setTmpWorkspaceSizeInBytes(bytes);
}
void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, MixtureOfExpertsGemmProfiler::Config const& tactic,
char* workspace_ptr_char, cudaStream_t const& stream)
{
checkInit();
backend.runProfiler(m, tactic, workspace_ptr_char, /*expert_weights*/ nullptr, stream);
}
auto MixtureOfExpertsGemmProfiler::getTactics(int m, int n, int k) const -> std::vector<Config>
{
assert(mRunner);
return mRunner->mMOERunner->getTactics();
}
void MixtureOfExpertsGemmProfiler::initTmpData(
int m, int n, int k, char* workspace, size_t ws_size, cudaStream_t stream)
{
checkInit();
backend.prepare(m, workspace, /*expert_weights*/ nullptr, stream);
}
void MixtureOfExpertsGemmProfiler::checkInit()
{
assert(mRunner);
if (init_backend)
{
return;
}
init_backend = true;
auto& plugin = *mRunner;
#ifdef USING_OSS_CUTLASS_MOE_GEMM
backend.init(*plugin.mMOERunner, backend.mGemmToProfile, plugin.mType, plugin.mWeightType, plugin.mOutputType,
plugin.mNumExperts, plugin.mExpertsPerToken, plugin.mExpertHiddenSize, plugin.mExpertInterSize,
plugin.mGroupSize, plugin.mActivationType, plugin.hasBias(), plugin.hasLora(), /*min_latency_mode=*/false,
/*need_weights=*/true, plugin.getParallelismConfig(), /*enable_alltoall=*/false);
#else
backend.init(*plugin.mMOERunner, backend.mGemmToProfile, plugin.mType, plugin.mWeightType, plugin.mOutputType,
plugin.mNumExperts, plugin.mExpertsPerToken, plugin.mExpertHiddenSize, plugin.mExpertInterSize,
plugin.mGroupSize, plugin.mActivationType, plugin.hasBias(), plugin.hasLora(), /*min_latency_mode=*/false,
/*need_weights=*/true, plugin.getParallelismConfig());
#endif
}