// cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "gemmCommon.h"
#include "gemmList.h"
#include "runner.h"
#include "trtllmGenSrc/DevKernel.h"
#include "trtllmGenSrc/RoutingKernel.h"
#include <iostream>
namespace tensorrt_llm
{
namespace kernels
{
namespace trtllmGenFp8BlockScaleMoe
{
namespace Routing
{
namespace
{
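// Computes log2(val); the check below enforces that val is a power of two.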
inline int32_t computeLog2(int32_t val, std::string const& name = "")
{
    int32_t n = val;
    int32_t out = 0;
    while (n >>= 1)
    {
        ++out;
    }
    TLLM_CHECK_ERROR((1 << out) == val, "Expected ", name, " to be a power of 2, got ", val);
    return out;
}
} // namespace
Runner::Runner() {}
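// Routing entry point. Dispatches on top_k:
//  - top_k == 8: grouped top-k routing with a routing bias and routed scaling factor
//    (moe::dev::routing, the DeepSeek-style path).
//  - top_k == 1: Llama4 Top-1 routing (moe::dev::routingLlama4), which uses no expert
//    groups, routing bias, or routed scaling factor.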
void Runner::run(void* routingLogits, void* routingBias, int32_t num_tokens, int32_t num_experts, int32_t top_k,
    int32_t n_group, int32_t topk_group, int32_t local_expert_offset, int32_t local_num_experts,
    float routed_scaling_factor, int32_t* routingExpertIndexes, int32_t* expertCountHistogram,
    int32_t* permuted_idx_size, int32_t* expanded_idx_to_permuted_idx, int32_t* permuted_idx_to_expanded_idx,
    int32_t* permuted_idx_to_token_idx, void* expert_weights, int32_t* num_tokens_per_expert,
    int32_t* cta_idx_xy_to_batch_idx, int32_t* cta_idx_xy_to_mn_limit, int32_t* num_non_exiting_ctas,
    tg::Dtype dtypeElt, bool use_routing_scales_on_input, bool use_deep_seek_fp8, cudaStream_t stream)
{
    if (top_k == 8)
    {
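        // Select the single PermuteGemm1 cubin matching the activation dtype and scaling mode;
        // routing only needs its tileN to size the padded, permuted token space.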
        std::vector<int32_t> selectedIndex;
        for (size_t ii = 0; ii < PermuteGemm1::gemmList.size(); ii++)
        {
            auto gemmInfo = PermuteGemm1::gemmList[ii];
            if (gemmInfo.dtypeElt == dtypeElt && gemmInfo.usePerTokenSfB == use_routing_scales_on_input
                && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
            {
                selectedIndex.push_back(ii);
            }
        }
        TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element type");
        TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element type");
        auto const& kernelInfo = PermuteGemm1::gemmList[*selectedIndex.begin()];
        int32_t tileN = kernelInfo.tileN;
        moe::dev::routing::Data routingData;
        routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input
        routingData.mDtypeExpW = tg::Dtype::Bfloat16;
        routingData.mUsePdl = true;
        // output:
        routingData.mPtrExpertIdx = routingExpertIndexes;
        routingData.mPtrExpertCounts = expertCountHistogram;
        routingData.mPtrPermutedIdxSize = permuted_idx_size;
        routingData.mPtrExpandedIdxToPermutedIdx = expanded_idx_to_permuted_idx;
        routingData.mPtrPermutedIdxToExpandedIdx = permuted_idx_to_expanded_idx;
        routingData.mPtrPermutedIdxToTokenIdx = permuted_idx_to_token_idx;
        routingData.mPtrNumTokensPerExpert = num_tokens_per_expert;
        routingData.mPtrExpertWeights = expert_weights;
        routingData.mPtrCtaIdxXyToBatchIdx = cta_idx_xy_to_batch_idx;
        routingData.mPtrCtaIdxXyToMnLimit = cta_idx_xy_to_mn_limit;
        routingData.mPtrNumNonExitingCtas = num_non_exiting_ctas;
        routingData.mAllToAllRouteAct = false;
        // input:
        // routingData.mPtrRoutingWeights = args.mRoutingWeights; // routing weights (don't need if not using gemm)
        routingData.mPtrRoutingBias = routingBias;
        routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
        // routingData.mPtrIn = args.mInputActs;
        routingData.mNumTokens = num_tokens;
        // routingData.mHiddenDim = args.mHiddenDim;
        routingData.mNumExperts = num_experts;
        routingData.mNumExpertGroups = n_group;
        routingData.mNumLimitedGroups = topk_group;
        routingData.mTopK = top_k;
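        // Each expert's token count is padded up to a multiple of the GEMM tile (tileN),
        // so CTA tiles never straddle expert boundaries.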
        routingData.mPaddingLog2 = computeLog2(tileN);
        routingData.mLocalExpertsStartIdx = local_expert_offset;
        routingData.mLocalExpertsStrideLog2 = 0;
        routingData.mNumLocalExperts = local_num_experts;
        routingData.mRouteScale = routed_scaling_factor;
        routingData.mUseRoutingSoftmax = false;
        moe::dev::routing::run(routingData, stream);
    }
    else if (top_k == 1)
    {
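        // Same cubin selection as the top_k == 8 branch; only tileN is consumed here.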
        std::vector<int32_t> selectedIndex;
        for (size_t ii = 0; ii < PermuteGemm1::gemmList.size(); ii++)
        {
            auto gemmInfo = PermuteGemm1::gemmList[ii];
            if (gemmInfo.dtypeElt == dtypeElt && gemmInfo.usePerTokenSfB == use_routing_scales_on_input
                && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
            {
                selectedIndex.push_back(ii);
            }
        }
        TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element type");
        TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element type");
        auto const& kernelInfo = PermuteGemm1::gemmList[*selectedIndex.begin()];
        int32_t tileN = kernelInfo.tileN;
        moe::dev::routingLlama4::Data routingData;
        // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input
        routingData.mDtypeExpW = tg::Dtype::Bfloat16;
        routingData.mUsePdl = true;
        // output:
        routingData.mPtrExpertIdx = routingExpertIndexes;
        routingData.mPtrExpertCounts = expertCountHistogram;
        routingData.mPtrPermutedIdxSize = permuted_idx_size;
        routingData.mPtrExpandedIdxToPermutedIdx = expanded_idx_to_permuted_idx;
        // routingData.mPtrPermutedIdxToExpandedIdx = permuted_idx_to_expanded_idx;
        routingData.mPtrPermutedIdxToTokenIdx = permuted_idx_to_token_idx;
        // routingData.mPtrNumTokensPerExpert = num_tokens_per_expert;
        routingData.mPtrExpertWeights = expert_weights;
        routingData.mPtrCtaIdxXyToBatchIdx = cta_idx_xy_to_batch_idx;
        routingData.mPtrCtaIdxXyToMnLimit = cta_idx_xy_to_mn_limit;
        routingData.mPtrNumNonExitingCtas = num_non_exiting_ctas;
        // routingData.mAllToAllRouteAct = false;
        // input:
        // routingData.mPtrRoutingWeights = args.mRoutingWeights; // routing weights (don't need if not using gemm)
        // routingData.mPtrRoutingBias = routingBias;
        routingData.mPtrScores = routingLogits;
        // routingData.mPtrIn = args.mInputActs;
        routingData.mNumTokens = num_tokens;
        // routingData.mHiddenDim = args.mHiddenDim;
        routingData.mNumExperts = num_experts;
        // routingData.mNumExpertGroups = n_group;
        // routingData.mNumLimitedGroups = topk_group;
        routingData.mTopK = top_k;
        routingData.mPaddingLog2 = computeLog2(tileN);
        routingData.mLocalExpertsStartIdx = local_expert_offset;
        routingData.mLocalExpertsStrideLog2 = 0;
        routingData.mNumLocalExperts = local_num_experts;
        // routingData.mRouteScale = routed_scaling_factor;
        // routingData.mUseRoutingSoftmax = false;
        moe::dev::routingLlama4::run(routingData, stream);
    }
    else
    {
        TLLM_CHECK_ERROR(false, "top_k can only be 1 or 8.");
    }
}
} // namespace Routing
namespace PermuteGemm1
{
Runner::Runner(trtllm::gen::Dtype dtypeElt)
    : mDtypeElt(dtypeElt)
{
}
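// FC1: batched GEMM over per-expert token groups. Tokens are gathered through
// permuted_idx_to_token_idx as part of the GEMM; mM = 2 * intermediate_size because
// the gate and linear projections are fused into a single weight matrix.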
void Runner::run(void* hidden_state, void* hidden_state_scale, void* weight, void* weight_scale, void* expert_weights,
    float* output_scales_scalar, float* output_scales_gate_scalar, void* output, void* output_scale, int32_t top_k,
    int32_t hidden_size, int32_t intermediate_size, int32_t num_experts, int32_t num_tokens,
    int32_t* permuted_idx_to_token_idx, int32_t* ptr_num_non_exiting_ctas, int32_t* ptr_total_num_padded_tokens,
    int32_t* ptr_cta_idx_xy_to_batch_idx, int32_t* ptr_cta_idx_xy_to_mn_limit, bool use_routing_scales_on_input,
    bool use_deep_seek_fp8, cudaStream_t stream)
{
    std::vector<int32_t> selectedIndex;
    for (size_t ii = 0; ii < gemmList.size(); ii++)
    {
        auto gemmInfo = gemmList[ii];
        if (gemmInfo.dtypeElt == mDtypeElt && gemmInfo.usePerTokenSfB == use_routing_scales_on_input
            && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
        {
            selectedIndex.push_back(ii);
        }
    }
    TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element type");
    TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element type");
    auto const& kernelInfo = gemmList[*selectedIndex.begin()];
    gemmCommon::MyOptions options;
    options.mTopK = top_k;
    options.mBatchM = false;
    options.mTransposeMmaOutput = true;
    options.mNumTokens = num_tokens;
    options.mNumExperts = num_experts;
    options.mM = 2 * intermediate_size;
    options.mN = 256; // A default value in GemmOptions.h that is not supposed to be used. Same as trtllm-gen behavior.
    options.mK = hidden_size;
    options.mClusterDimX = 1;
    options.mClusterDimY = 1;
    options.mClusterDimZ = 1;
    options.mAllReduceAlgo = gemmCommon::gemm::AllReduceAlgo::None;
    options.mSplitK = gemmCommon::gemm::SplitK::None;
    options.mPtrNumNonExitingCtas = ptr_num_non_exiting_ctas;
    options.mPtrTotalNumPaddedTokens = ptr_total_num_padded_tokens;
    options.mPtrCtaIdxXyToBatchIdx = ptr_cta_idx_xy_to_batch_idx;
    options.mPtrCtaIdxXyToMnLimit = ptr_cta_idx_xy_to_mn_limit;
    options.mSfLayoutB = tg::SfLayout::Linear;
    options.mSfLayoutC = tg::SfLayout::Linear;
    options.mUseCustomLowLatencyImpl = false;
    options.mAllToAllRouteAct = false;
    options.mIsStaticBatch = false;
    options.mBatchedN = std::vector(num_experts, -1);
    gemmCommon::copyKernelInfoToOptions(kernelInfo, options);
    gemmCommon::batchedGemm::checkAndUpdateGemmOptions(options, true, false, false);
    gemmCommon::BatchedGemmData batchedGemmData;
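    // Worst-case number of permuted tokens after per-expert padding to tileN.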
    auto max_num_padded_tokens = Routing::getMaxPermutedPaddedCount(num_tokens, top_k, num_experts, kernelInfo.tileN);
    gemmCommon::setSingleBatchedGemmData(weight, hidden_state, output, output_scales_scalar, output_scales_gate_scalar,
        reinterpret_cast<float*>(weight_scale), reinterpret_cast<float*>(hidden_state_scale),
        reinterpret_cast<float*>(output_scale),
        // FIXME: we pass the same scaling factors twice, used in one case for dsfp8 and in the other for fp4.
        // We should pass them only once and select the case inside setSingleBatchedGemmData.
        weight_scale, hidden_state_scale, output_scale, permuted_idx_to_token_idx, nullptr, nullptr, expert_weights,
        kernelInfo.numSlicesForSplitK, options, batchedGemmData, max_num_padded_tokens);
    gemmCommon::launchGemmFromData(kernelInfo, options, batchedGemmData, stream, /*usePDL*/ false);
}
} // namespace PermuteGemm1
namespace Gemm2
{
Runner::Runner(tg::Dtype dtypeElt, tg::Dtype outputDtype)
    : mDtypeElt(dtypeElt)
    , mOutputDtype(outputDtype)
{
}
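// FC2: batched GEMM projecting the activated intermediate values back to hidden_size.
// The output remains in permuted per-expert order; the finalize step unpermutes it.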
void Runner::run(void* permuted_hidden_state, void* permuted_hidden_state_scale, void* weight, void* weight_scale,
    float* output_scales_scalar, void* output, void* output_scale, int32_t top_k, int32_t hidden_size,
    int32_t intermediate_size, int32_t num_experts, int32_t num_tokens, int32_t* ptr_num_non_exiting_ctas,
    int32_t* ptr_total_num_padded_tokens, int32_t* ptr_cta_idx_xy_to_batch_idx, int32_t* ptr_cta_idx_xy_to_mn_limit,
    bool use_deep_seek_fp8, cudaStream_t stream)
{
    std::vector<int32_t> selectedIndex;
    for (size_t ii = 0; ii < gemmList.size(); ii++)
    {
        auto gemmInfo = gemmList[ii];
        if (gemmInfo.dtypeElt == mDtypeElt && gemmInfo.dtypeC == mOutputDtype
            && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
        {
            selectedIndex.push_back(ii);
        }
    }
    TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element and output types");
    TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element and output types");
    auto const& kernelInfo = gemmList[*selectedIndex.begin()];
    gemmCommon::MyOptions options;
    options.mTopK = top_k;
    options.mBatchM = false;
    options.mTransposeMmaOutput = true;
    options.mNumExperts = num_experts;
    options.mNumTokens = num_tokens;
    options.mM = hidden_size;
    options.mN = 256; // A default value in GemmOptions.h that is not supposed to be used. Same as trtllm-gen behavior.
    options.mK = intermediate_size;
    options.mClusterDimX = 1;
    options.mClusterDimY = 1;
    options.mClusterDimZ = 1;
    options.mAllReduceAlgo = gemmCommon::gemm::AllReduceAlgo::None;
    options.mSplitK = gemmCommon::gemm::SplitK::None;
    options.mPtrNumNonExitingCtas = ptr_num_non_exiting_ctas;
    options.mPtrTotalNumPaddedTokens = ptr_total_num_padded_tokens;
    options.mPtrCtaIdxXyToBatchIdx = ptr_cta_idx_xy_to_batch_idx;
    options.mPtrCtaIdxXyToMnLimit = ptr_cta_idx_xy_to_mn_limit;
    options.mSfLayoutB = tg::SfLayout::Linear;
    options.mSfLayoutC = tg::SfLayout::Linear;
    options.mUseCustomLowLatencyImpl = false;
    options.mAllToAllRouteAct = false;
    options.mIsStaticBatch = false;
    options.mBatchedN = std::vector(num_experts, -1);
    gemmCommon::copyKernelInfoToOptions(kernelInfo, options);
    gemmCommon::batchedGemm::checkAndUpdateGemmOptions(options, true, false, false);
    gemmCommon::BatchedGemmData batchedGemmData;
    auto max_num_padded_tokens = Routing::getMaxPermutedPaddedCount(num_tokens, top_k, num_experts, kernelInfo.tileN);
    gemmCommon::setSingleBatchedGemmData(weight, permuted_hidden_state, output, output_scales_scalar, nullptr,
        reinterpret_cast<float*>(weight_scale), reinterpret_cast<float*>(permuted_hidden_state_scale),
        reinterpret_cast<float*>(output_scale), weight_scale, permuted_hidden_state_scale, output_scale, nullptr,
        nullptr, nullptr, nullptr, kernelInfo.numSlicesForSplitK, options, batchedGemmData, max_num_padded_tokens);
    gemmCommon::launchGemmFromData(kernelInfo, options, batchedGemmData, stream);
}
} // namespace Gemm2
namespace MoE
{
Runner::Runner() {}
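// Populates the device-op descriptors (scale-factor conversion, activation, finalize)
// from the runner arguments and the pre-allocated workspace.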
void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
    moe::dev::convertsf::Data& convertSfData, moe::dev::activation::Data& activationData,
    moe::dev::finalize::Data& finalizeData)
{
    // Setup sf conversion data if needed
    convertSfData.inSfPtr = args.hidden_states_scale;
    convertSfData.outSfPtr = workspace.hidden_states_scale_linear;
    convertSfData.hiddenDimSf = args.hidden_size / 16;
    convertSfData.numTokens = args.num_tokens;
    convertSfData.sfLayoutSrc = tg::SfLayout::R128c4;
    convertSfData.sfLayoutDst = tg::SfLayout::Linear;
    convertSfData.mUsePdl = true;
    // Setup activation data
    activationData.mDtypeElt = args.mDtypeElt;
    activationData.mUsePdl = true;
    activationData.mUseDeepSeekFp8 = true;
    activationData.inPtr = workspace.gemm1_output;
    activationData.outPtr = workspace.activation_output;
    activationData.inDqSfsPtr = workspace.gemm1_output_scale;
    activationData.outDqSfsPtr = workspace.activation_output_scale;
    activationData.innerDim = args.intermediate_size * 2;
    activationData.topK = args.top_k;
    activationData.numTokens = args.num_tokens;
    activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
    activationData.totalNumPaddedTokens = workspace.total_num_padded_tokens;
    // Setup finalize data
    finalizeData.mDtypeElt = args.mDtypeOut;
    finalizeData.mDtypeExpW = args.mDtypeExpW;
    finalizeData.mUsePdl = true;
    finalizeData.mUseDeepSeekFp8 = false;
    finalizeData.inPtr = workspace.gemm2_output;
    finalizeData.outPtr = args.output;
    finalizeData.inDqSfsPtr = workspace.gemm2_output_scale;
    finalizeData.outDqSfsPtr = args.output_scale;
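    // When the routing scales are applied on the input (per-token scale factors in GEMM1),
    // finalize must not re-apply the expert weights, so the pointer is left null.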
    if (args.mUseRoutingScalesOnInput)
    {
        finalizeData.expertWeightsPtr = nullptr;
    }
    else
    {
        finalizeData.expertWeightsPtr = workspace.expert_weights;
    }
    finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
    finalizeData.numTokens = args.num_tokens;
    finalizeData.numExperts = args.num_experts;
    finalizeData.topK = args.top_k;
    finalizeData.hiddenDim = args.hidden_size;
    finalizeData.totalNumPaddedTokens = workspace.total_num_padded_tokens;
}
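// End-to-end MoE pipeline: permute + FC1, optional standalone activation (DeepSeek FP8
// only), FC2, then finalize (unpermute and reduce the top-k partial results).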
void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, cudaStream_t stream)
{
    // Setup all operation data
    moe::dev::activation::Data activationData;
    moe::dev::finalize::Data finalizeData;
    moe::dev::convertsf::Data convertSfData;
    setOpsData(args, workspace, convertSfData, activationData, finalizeData);
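    // Note: the scale-factor layout conversion (convertSfData) is prepared but not launched
    // here; the hidden-state scales are passed through unchanged.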
    void* hidden_states_scale_linear{args.hidden_states_scale};
    PermuteGemm1::Runner permuteGemm1(args.mDtypeElt);
    permuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale,
        workspace.expert_weights, args.output1_scales_scalar, args.output1_scales_gate_scalar, workspace.gemm1_output,
        workspace.gemm1_output_scale, args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts,
        args.num_tokens, workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
        workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit,
        args.mUseRoutingScalesOnInput, args.mUseDeepSeekFp8, stream);
    // We do not fuse the activation with FC1 for DeepSeek FP8 due to the weight-shuffling constraint.
    void* gemm2_input = workspace.gemm1_output;
    void* gemm2_input_scale = workspace.gemm1_output_scale;
    // We run the standalone activation only for DeepSeek FP8, as those cubins do not fuse the activation.
    if (args.mDtypeElt == tg::Dtype::E4m3 && args.mUseDeepSeekFp8)
    {
        // Run activation
        moe::dev::activation::run(activationData, stream);
        gemm2_input = workspace.activation_output;
        gemm2_input_scale = workspace.activation_output_scale;
    }
    // Run gemm2
    Gemm2::Runner gemm2(args.mDtypeElt, tg::Dtype::Bfloat16);
    gemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, args.output2_scales_scalar,
        workspace.gemm2_output, workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
        args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens,
        workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, args.mUseDeepSeekFp8, stream);
    // Run finalize
    moe::dev::finalize::run(finalizeData, stream);
}
} // namespace MoE
} // namespace trtllmGenFp8BlockScaleMoe
} // namespace kernels
} // namespace tensorrt_llm