mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* feat: TRT-LLM Gen FP8 MoE Llama4
  Signed-off-by: Nikita Korobov <nkorobov@nvidia.com>
* feat: TRT-LLM Gen llama4 MoE Top1 routing
  Signed-off-by: Jiqun Tu <jtu@nvidia.com>
* feat: add per tensor FP8 TRT-LLM Gen GEMMs
  Signed-off-by: Nikita Korobov <nkorobov@nvidia.com>
* Update
  Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
* Update
  Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
* Add license for cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/gemmCubins
  Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
* Add guard for routingIndicesClusterKernel
  Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
* Guard sm90+ for routingkernels
  Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
* Guard sm90+ for routingkernels
  Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
---------
Signed-off-by: Nikita Korobov <nkorobov@nvidia.com>
Signed-off-by: Jiqun Tu <jtu@nvidia.com>
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com>
Co-authored-by: Nikita Korobov <nkorobov@nvidia.com>
Co-authored-by: Jiqun Tu <jtu@nvidia.com>
415 lines
19 KiB
Plaintext
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "gemmCommon.h"
|
|
#include "gemmList.h"
|
|
#include "runner.h"
|
|
#include "trtllmGenSrc/DevKernel.h"
|
|
#include "trtllmGenSrc/RoutingKernel.h"
|
|
#include <iostream>
|
|
|
|
namespace tensorrt_llm
{
namespace kernels
{
namespace trtllmGenFp8BlockScaleMoe
{

namespace Routing
{
namespace
{
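// Returns floor(log2(val)) by counting right shifts; the check then requires val to be an exact
// power of two. E.g. computeLog2(32, "tileN") yields 5, while computeLog2(24) trips the error.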
inline int32_t computeLog2(int32_t val, std::string const& name = "")
{
    int32_t n = val;
    int32_t out = 0;
    while (n >>= 1)
    {
        ++out;
    }
    TLLM_CHECK_ERROR((1 << out) == val, "Expected ", name, " to be a power of 2, got ", val);
    return out;
}
} // namespace

Runner::Runner() {}

void Runner::run(void* routingLogits, void* routingBias, int32_t num_tokens, int32_t num_experts, int32_t top_k,
    int32_t n_group, int32_t topk_group, int32_t local_expert_offset, int32_t local_num_experts,
    float routed_scaling_factor, int32_t* routingExpertIndexes, int32_t* expertCountHistogram,
    int32_t* permuted_idx_size, int32_t* expanded_idx_to_permuted_idx, int32_t* permuted_idx_to_expanded_idx,
    int32_t* permuted_idx_to_token_idx, void* expert_weights, int32_t* num_tokens_per_expert,
    int32_t* cta_idx_xy_to_batch_idx, int32_t* cta_idx_xy_to_mn_limit, int32_t* num_non_exiting_ctas,
    tg::Dtype dtypeElt, bool use_routing_scales_on_input, bool use_deep_seek_fp8, cudaStream_t stream)
{
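    // Dispatch on top_k: 8 takes the grouped routing path (group-limited top-k with bias and
    // routed scaling factor), 1 takes the dedicated Llama4 top-1 routing path.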
    if (top_k == 8)
    {
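        // Select the single PermuteGemm1 cubin matching the element type and scaling mode; its
        // tileN determines the padding granularity used when permuting tokens.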
        std::vector<int32_t> selectedIndex;
        for (size_t ii = 0; ii < PermuteGemm1::gemmList.size(); ii++)
        {
            auto gemmInfo = PermuteGemm1::gemmList[ii];
            if (gemmInfo.dtypeElt == dtypeElt && gemmInfo.usePerTokenSfB == use_routing_scales_on_input
                && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
            {
                selectedIndex.push_back(ii);
            }
        }
        TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element type");
        TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element type");
        auto const& kernelInfo = PermuteGemm1::gemmList[*selectedIndex.begin()];
        int32_t tileN = kernelInfo.tileN;
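
        // Tokens routed to each expert are padded up to a multiple of tileN so the grouped GEMM
        // CTAs always see full tiles; mPaddingLog2 passes this granularity to the routing kernel.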
        moe::dev::routing::Data routingData;
        routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input
        routingData.mDtypeExpW = tg::Dtype::Bfloat16;
        routingData.mUsePdl = true;

        // output:
        routingData.mPtrExpertIdx = routingExpertIndexes;
        routingData.mPtrExpertCounts = expertCountHistogram;
        routingData.mPtrPermutedIdxSize = permuted_idx_size;
        routingData.mPtrExpandedIdxToPermutedIdx = expanded_idx_to_permuted_idx;
        routingData.mPtrPermutedIdxToExpandedIdx = permuted_idx_to_expanded_idx;
        routingData.mPtrPermutedIdxToTokenIdx = permuted_idx_to_token_idx;
        routingData.mPtrNumTokensPerExpert = num_tokens_per_expert;
        routingData.mPtrExpertWeights = expert_weights;

        routingData.mPtrCtaIdxXyToBatchIdx = cta_idx_xy_to_batch_idx;
        routingData.mPtrCtaIdxXyToMnLimit = cta_idx_xy_to_mn_limit;
        routingData.mPtrNumNonExitingCtas = num_non_exiting_ctas;
        routingData.mAllToAllRouteAct = false;

        // input:
        // routingData.mPtrRoutingWeights = args.mRoutingWeights; // routing weights (not needed if not using gemm)
        routingData.mPtrRoutingBias = routingBias;
        routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
        // routingData.mPtrIn = args.mInputActs;
        routingData.mNumTokens = num_tokens;
        // routingData.mHiddenDim = args.mHiddenDim;
        routingData.mNumExperts = num_experts;
        routingData.mNumExpertGroups = n_group;
        routingData.mNumLimitedGroups = topk_group;
        routingData.mTopK = top_k;
        routingData.mPaddingLog2 = computeLog2(tileN);
        routingData.mLocalExpertsStartIdx = local_expert_offset;
        routingData.mLocalExpertsStrideLog2 = 0;
        routingData.mNumLocalExperts = local_num_experts;
        routingData.mRouteScale = routed_scaling_factor;
        routingData.mUseRoutingSoftmax = false;
        moe::dev::routing::run(routingData, stream);
    }
    else if (top_k == 1)
    {
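        // Llama4 top-1 routing: the cubin selection is identical, but group, bias, and routed
        // scaling parameters do not apply here and remain commented out below.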
        std::vector<int32_t> selectedIndex;
        for (size_t ii = 0; ii < PermuteGemm1::gemmList.size(); ii++)
        {
            auto gemmInfo = PermuteGemm1::gemmList[ii];
            if (gemmInfo.dtypeElt == dtypeElt && gemmInfo.usePerTokenSfB == use_routing_scales_on_input
                && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
            {
                selectedIndex.push_back(ii);
            }
        }
        TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element type");
        TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element type");
        auto const& kernelInfo = PermuteGemm1::gemmList[*selectedIndex.begin()];
        int32_t tileN = kernelInfo.tileN;

        moe::dev::routingLlama4::Data routingData;
        // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input
        routingData.mDtypeExpW = tg::Dtype::Bfloat16;
        routingData.mUsePdl = true;

        // output:
        routingData.mPtrExpertIdx = routingExpertIndexes;
        routingData.mPtrExpertCounts = expertCountHistogram;
        routingData.mPtrPermutedIdxSize = permuted_idx_size;
        routingData.mPtrExpandedIdxToPermutedIdx = expanded_idx_to_permuted_idx;
        // routingData.mPtrPermutedIdxToExpandedIdx = permuted_idx_to_expanded_idx;
        routingData.mPtrPermutedIdxToTokenIdx = permuted_idx_to_token_idx;
        // routingData.mPtrNumTokensPerExpert = num_tokens_per_expert;
        routingData.mPtrExpertWeights = expert_weights;

        routingData.mPtrCtaIdxXyToBatchIdx = cta_idx_xy_to_batch_idx;
        routingData.mPtrCtaIdxXyToMnLimit = cta_idx_xy_to_mn_limit;
        routingData.mPtrNumNonExitingCtas = num_non_exiting_ctas;
        // routingData.mAllToAllRouteAct = false;

        // input:
        // routingData.mPtrRoutingWeights = args.mRoutingWeights; // routing weights (not needed if not using gemm)
        // routingData.mPtrRoutingBias = routingBias;
        routingData.mPtrScores = routingLogits;
        // routingData.mPtrIn = args.mInputActs;
        routingData.mNumTokens = num_tokens;
        // routingData.mHiddenDim = args.mHiddenDim;
        routingData.mNumExperts = num_experts;
        // routingData.mNumExpertGroups = n_group;
        // routingData.mNumLimitedGroups = topk_group;
        routingData.mTopK = top_k;
        routingData.mPaddingLog2 = computeLog2(tileN);
        routingData.mLocalExpertsStartIdx = local_expert_offset;
        routingData.mLocalExpertsStrideLog2 = 0;
        routingData.mNumLocalExperts = local_num_experts;
        // routingData.mRouteScale = routed_scaling_factor;
        // routingData.mUseRoutingSoftmax = false;
        moe::dev::routingLlama4::run(routingData, stream);
    }
    else
    {
        TLLM_CHECK_ERROR(false, "top_k can only be 1 or 8.");
    }
}
} // namespace Routing

namespace PermuteGemm1
{
Runner::Runner(trtllm::gen::Dtype dtypeElt)
    : mDtypeElt(dtypeElt)
{
}
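
// Permutes tokens into expert-major order and runs the first grouped GEMM (FC1). Its output is
// consumed either directly by Gemm2 (fused-activation cubins) or by the standalone activation.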
void Runner::run(void* hidden_state, void* hidden_state_scale, void* weight, void* weight_scale, void* expert_weights,
    float* output_scales_scalar, float* output_scales_gate_scalar, void* output, void* output_scale, int32_t top_k,
    int32_t hidden_size, int32_t intermediate_size, int32_t num_experts, int32_t num_tokens,
    int32_t* permuted_idx_to_token_idx, int32_t* ptr_num_non_exiting_ctas, int32_t* ptr_total_num_padded_tokens,
    int32_t* ptr_cta_idx_xy_to_batch_idx, int32_t* ptr_cta_idx_xy_to_mn_limit, bool use_routing_scales_on_input,
    bool use_deep_seek_fp8, cudaStream_t stream)
{
    std::vector<int32_t> selectedIndex;
    for (size_t ii = 0; ii < gemmList.size(); ii++)
    {
        auto gemmInfo = gemmList[ii];
        if (gemmInfo.dtypeElt == mDtypeElt && gemmInfo.usePerTokenSfB == use_routing_scales_on_input
            && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
        {
            selectedIndex.push_back(ii);
        }
    }
    TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element type");
    TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element type");
    auto const& kernelInfo = gemmList[*selectedIndex.begin()];

    gemmCommon::MyOptions options;
    options.mTopK = top_k;
    options.mBatchM = false;
    options.mTransposeMmaOutput = true;
    options.mNumTokens = num_tokens;
    options.mNumExperts = num_experts;
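    // mM = 2 * intermediate_size: FC1 produces the gate and linear halves of the FFN in a single
    // GEMM; the activation kernel later consumes them with innerDim = intermediate_size * 2.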
    options.mM = 2 * intermediate_size;
    options.mN = 256; // A default value from GemmOptions.h that is not supposed to be used; matches trtllm-gen behavior.
    options.mK = hidden_size;
    options.mClusterDimX = 1;
    options.mClusterDimY = 1;
    options.mClusterDimZ = 1;
    options.mAllReduceAlgo = gemmCommon::gemm::AllReduceAlgo::None;
    options.mSplitK = gemmCommon::gemm::SplitK::None;
    options.mPtrNumNonExitingCtas = ptr_num_non_exiting_ctas;
    options.mPtrTotalNumPaddedTokens = ptr_total_num_padded_tokens;
    options.mPtrCtaIdxXyToBatchIdx = ptr_cta_idx_xy_to_batch_idx;
    options.mPtrCtaIdxXyToMnLimit = ptr_cta_idx_xy_to_mn_limit;
    options.mSfLayoutB = tg::SfLayout::Linear;
    options.mSfLayoutC = tg::SfLayout::Linear;
    options.mUseCustomLowLatencyImpl = false;
    options.mAllToAllRouteAct = false;
    options.mIsStaticBatch = false;
    options.mBatchedN = std::vector(num_experts, -1);
    gemmCommon::copyKernelInfoToOptions(kernelInfo, options);
    gemmCommon::batchedGemm::checkAndUpdateGemmOptions(options, true, false, false);

    gemmCommon::BatchedGemmData batchedGemmData;
    auto max_num_padded_tokens = Routing::getMaxPermutedPaddedCount(num_tokens, top_k, num_experts, kernelInfo.tileN);
    gemmCommon::setSingleBatchedGemmData(weight, hidden_state, output, output_scales_scalar, output_scales_gate_scalar,
        reinterpret_cast<float*>(weight_scale), reinterpret_cast<float*>(hidden_state_scale),
        reinterpret_cast<float*>(output_scale),
        // FIXME: we pass the same scaling factors twice: one set is used for dsfp8 and the other for fp4.
        // We should pass them only once and decide on the case inside setSingleBatchedGemmData.
        weight_scale, hidden_state_scale, output_scale, permuted_idx_to_token_idx, nullptr, nullptr, expert_weights,
        kernelInfo.numSlicesForSplitK, options, batchedGemmData, max_num_padded_tokens);

    gemmCommon::launchGemmFromData(kernelInfo, options, batchedGemmData, stream, /*usePDL*/ false);
}
} // namespace PermuteGemm1

namespace Gemm2
{
Runner::Runner(tg::Dtype dtypeElt, tg::Dtype outputDtype)
    : mDtypeElt(dtypeElt)
    , mOutputDtype(outputDtype)
{
}
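
// Second grouped GEMM (FC2, the down projection): mM = hidden_size and mK = intermediate_size.
// Its output stays in permuted order; the finalize kernel scatters it back to token order.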
void Runner::run(void* permuted_hidden_state, void* permuted_hidden_state_scale, void* weight, void* weight_scale,
    float* output_scales_scalar, void* output, void* output_scale, int32_t top_k, int32_t hidden_size,
    int32_t intermediate_size, int32_t num_experts, int32_t num_tokens, int32_t* ptr_num_non_exiting_ctas,
    int32_t* ptr_total_num_padded_tokens, int32_t* ptr_cta_idx_xy_to_batch_idx, int32_t* ptr_cta_idx_xy_to_mn_limit,
    bool use_deep_seek_fp8, cudaStream_t stream)
{
    std::vector<int32_t> selectedIndex;
    for (size_t ii = 0; ii < gemmList.size(); ii++)
    {
        auto gemmInfo = gemmList[ii];
        if (gemmInfo.dtypeElt == mDtypeElt && gemmInfo.dtypeC == mOutputDtype
            && gemmInfo.useDeepSeekFp8 == use_deep_seek_fp8)
        {
            selectedIndex.push_back(ii);
        }
    }
    TLLM_CHECK_WITH_INFO(selectedIndex.size() != 0, "No kernel found for the given element and output types");
    TLLM_CHECK_WITH_INFO(selectedIndex.size() == 1, "Multiple kernels found for the given element and output types");
    auto const& kernelInfo = gemmList[*selectedIndex.begin()];

    gemmCommon::MyOptions options;
    options.mTopK = top_k;
    options.mBatchM = false;
    options.mTransposeMmaOutput = true;
    options.mNumExperts = num_experts;
    options.mNumTokens = num_tokens;
    options.mM = hidden_size;
    options.mN = 256; // A default value from GemmOptions.h that is not supposed to be used; matches trtllm-gen behavior.
    options.mK = intermediate_size;
    options.mClusterDimX = 1;
    options.mClusterDimY = 1;
    options.mClusterDimZ = 1;
    options.mAllReduceAlgo = gemmCommon::gemm::AllReduceAlgo::None;
    options.mSplitK = gemmCommon::gemm::SplitK::None;
    options.mPtrNumNonExitingCtas = ptr_num_non_exiting_ctas;
    options.mPtrTotalNumPaddedTokens = ptr_total_num_padded_tokens;
    options.mPtrCtaIdxXyToBatchIdx = ptr_cta_idx_xy_to_batch_idx;
    options.mPtrCtaIdxXyToMnLimit = ptr_cta_idx_xy_to_mn_limit;
    options.mSfLayoutB = tg::SfLayout::Linear;
    options.mSfLayoutC = tg::SfLayout::Linear;
    options.mUseCustomLowLatencyImpl = false;
    options.mAllToAllRouteAct = false;
    options.mIsStaticBatch = false;
    options.mBatchedN = std::vector(num_experts, -1);
    gemmCommon::copyKernelInfoToOptions(kernelInfo, options);
    gemmCommon::batchedGemm::checkAndUpdateGemmOptions(options, true, false, false);

    gemmCommon::BatchedGemmData batchedGemmData;
    auto max_num_padded_tokens = Routing::getMaxPermutedPaddedCount(num_tokens, top_k, num_experts, kernelInfo.tileN);
    gemmCommon::setSingleBatchedGemmData(weight, permuted_hidden_state, output, output_scales_scalar, nullptr,
        reinterpret_cast<float*>(weight_scale), reinterpret_cast<float*>(permuted_hidden_state_scale),
        reinterpret_cast<float*>(output_scale), weight_scale, permuted_hidden_state_scale, output_scale, nullptr,
        nullptr, nullptr, nullptr, kernelInfo.numSlicesForSplitK, options, batchedGemmData, max_num_padded_tokens);

    gemmCommon::launchGemmFromData(kernelInfo, options, batchedGemmData, stream);
}
} // namespace Gemm2

namespace MoE
{
Runner::Runner() {}

void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
    moe::dev::convertsf::Data& convertSfData, moe::dev::activation::Data& activationData,
    moe::dev::finalize::Data& finalizeData)
{
    // Setup sf conversion data if needed
    convertSfData.inSfPtr = args.hidden_states_scale;
    convertSfData.outSfPtr = workspace.hidden_states_scale_linear;
    convertSfData.hiddenDimSf = args.hidden_size / 16;
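    // One scale factor per 16-element block of the hidden dimension; the conversion rewrites the
    // scales from the tiled R128c4 layout into the linear layout expected by the GEMMs.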
    convertSfData.numTokens = args.num_tokens;
    convertSfData.sfLayoutSrc = tg::SfLayout::R128c4;
    convertSfData.sfLayoutDst = tg::SfLayout::Linear;
    convertSfData.mUsePdl = true;

    // Setup activation data
    activationData.mDtypeElt = args.mDtypeElt;
    activationData.mUsePdl = true;
    activationData.mUseDeepSeekFp8 = true;
    activationData.inPtr = workspace.gemm1_output;
    activationData.outPtr = workspace.activation_output;
    activationData.inDqSfsPtr = workspace.gemm1_output_scale;
    activationData.outDqSfsPtr = workspace.activation_output_scale;
    activationData.innerDim = args.intermediate_size * 2;
    activationData.topK = args.top_k;
    activationData.numTokens = args.num_tokens;
    activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;

    activationData.totalNumPaddedTokens = workspace.total_num_padded_tokens;

    // Setup finalize data
    finalizeData.mDtypeElt = args.mDtypeOut;
    finalizeData.mDtypeExpW = args.mDtypeExpW;
    finalizeData.mUsePdl = true;
    finalizeData.mUseDeepSeekFp8 = false;
    finalizeData.inPtr = workspace.gemm2_output;
    finalizeData.outPtr = args.output;
    finalizeData.inDqSfsPtr = workspace.gemm2_output_scale;
    finalizeData.outDqSfsPtr = args.output_scale;
    if (args.mUseRoutingScalesOnInput)
    {
        finalizeData.expertWeightsPtr = nullptr;
    }
    else
    {
        finalizeData.expertWeightsPtr = workspace.expert_weights;
    }
    finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx;
    finalizeData.numTokens = args.num_tokens;
    finalizeData.numExperts = args.num_experts;
    finalizeData.topK = args.top_k;
    finalizeData.hiddenDim = args.hidden_size;
    finalizeData.totalNumPaddedTokens = workspace.total_num_padded_tokens;
}
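
// End-to-end MoE flow: PermuteGemm1 (token permutation fused with FC1), an optional standalone
// activation (DeepSeek FP8 only, since those cubins lack a fused activation), Gemm2 (FC2), and
// finalize, which unpermutes the results and applies the expert weights.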
void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, cudaStream_t stream)
{
    // Setup all operation data
    moe::dev::activation::Data activationData;
    moe::dev::finalize::Data finalizeData;
    moe::dev::convertsf::Data convertSfData;

    setOpsData(args, workspace, convertSfData, activationData, finalizeData);

    void* hidden_states_scale_linear{args.hidden_states_scale};
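    // Note: convertSfData is prepared in setOpsData, but no convertsf launch is issued here; the
    // "linear" scale pointer currently aliases the input scales unchanged.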

    PermuteGemm1::Runner permuteGemm1(args.mDtypeElt);
    permuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale,
        workspace.expert_weights, args.output1_scales_scalar, args.output1_scales_gate_scalar, workspace.gemm1_output,
        workspace.gemm1_output_scale, args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts,
        args.num_tokens, workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas,
        workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit,
        args.mUseRoutingScalesOnInput, args.mUseDeepSeekFp8, stream);

    // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint.
    void* gemm2_input = workspace.gemm1_output;
    void* gemm2_input_scale = workspace.gemm1_output_scale;
    // We run the standalone activation only for DeepSeek FP8, as those cubins do not have a fused activation.
    if (args.mDtypeElt == tg::Dtype::E4m3 && args.mUseDeepSeekFp8)
    {
        // Run activation
        moe::dev::activation::run(activationData, stream);
        gemm2_input = workspace.activation_output;
        gemm2_input_scale = workspace.activation_output_scale;
    }

    // Run gemm2
    Gemm2::Runner gemm2(args.mDtypeElt, tg::Dtype::Bfloat16);
    gemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, args.output2_scales_scalar,
        workspace.gemm2_output, workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size,
        args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens,
        workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, args.mUseDeepSeekFp8, stream);

    // Run finalize
    moe::dev::finalize::run(finalizeData, stream);
}
} // namespace MoE

} // namespace trtllmGenFp8BlockScaleMoe
} // namespace kernels
} // namespace tensorrt_llm