[https://nvbugs/5726962][feat] Apply fusion for W4AFP8_AWQ MoE (#9838)

Signed-off-by: Min Yu <171526537+yumin066@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Min Yu 2026-01-06 10:16:41 +08:00 committed by GitHub
parent 6b8ae6fa81
commit 9cae7277ea
9 changed files with 791 additions and 335 deletions


@ -315,7 +315,8 @@ struct QuantParams
{
struct GroupwiseGemmInputs
{
void const* act_scales = nullptr;
bool use_per_expert_act_scale = false;
void const* act_scales = nullptr; // (1 or num_experts_per_node, hidden_size or intermediate_size)
void const* weight_scales = nullptr;
void const* weight_zeros = nullptr;
float const* alpha = nullptr;
@ -401,12 +402,15 @@ struct QuantParams
static QuantParams GroupWise(int group_size, void const* fc1_weight_scales, void const* fc2_weight_scales,
void const* fc1_activation_scales = nullptr, void const* fc2_activation_scales = nullptr,
void const* fc1_weight_zeros = nullptr, void const* fc2_weight_zeros = nullptr,
float const* fc1_alpha = nullptr, float const* fc2_alpha = nullptr)
float const* fc1_alpha = nullptr, float const* fc2_alpha = nullptr, bool fc1_use_per_expert_act_scale = false,
bool fc2_use_per_expert_act_scale = false)
{
QuantParams qp;
qp.groupwise.group_size = group_size;
qp.groupwise.fc1 = {fc1_activation_scales, fc1_weight_scales, fc1_weight_zeros, fc1_alpha};
qp.groupwise.fc2 = {fc2_activation_scales, fc2_weight_scales, fc2_weight_zeros, fc2_alpha};
qp.groupwise.fc1
= {fc1_use_per_expert_act_scale, fc1_activation_scales, fc1_weight_scales, fc1_weight_zeros, fc1_alpha};
qp.groupwise.fc2
= {fc2_use_per_expert_act_scale, fc2_activation_scales, fc2_weight_scales, fc2_weight_zeros, fc2_alpha};
return qp;
}
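// Hypothetical call-site sketch (editorial, not part of this diff; every pointer name is a
// placeholder): with the two new trailing flags, a W4AFP8_AWQ setup that keeps one AWQ
// activation scale per expert could build its QuantParams like this:
//   auto qp = QuantParams::GroupWise(/*group_size=*/128, fc1_weight_scales, fc2_weight_scales,
//       fc1_act_scales, fc2_act_scales, /*fc1_weight_zeros=*/nullptr, /*fc2_weight_zeros=*/nullptr,
//       fc1_alpha, fc2_alpha, /*fc1_use_per_expert_act_scale=*/true,
//       /*fc2_use_per_expert_act_scale=*/true);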
@ -646,7 +650,7 @@ public:
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids);
int* num_active_experts_per, int* active_expert_global_ids, void const* fc2_prequant_scale = nullptr);
static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@ -803,6 +807,16 @@ private:
bool min_latency_mode, bool use_awq);
private:
static bool useAwq(cutlass_kernels::QuantParams const& quant_params)
{
return quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
}
static bool usePrequantScaleKernel(cutlass_kernels::QuantParams const& quant_params)
{
return useAwq(quant_params) && !std::is_same_v<T, WeightType>;
}
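// Hedged reading of the two helpers above (illustrative only, not from this commit): useAwq()
// requires AWQ activation scales on both GEMMs and excludes WFP4A16, while
// usePrequantScaleKernel() additionally requires mixed activation/weight types, which holds for
// W4AFP8 (FP8 activations, packed INT4 weights). A caller could branch along these lines:
//   if (usePrequantScaleKernel(quant_params))
//       /* fused path: prequant scaling folded into the activation / FC2 epilogue kernels */;
//   else if (useAwq(quant_params))
//       /* unfused path: run the standalone applyPrequantScale() smoothing pass first */;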
bool mayHaveDifferentGEMMOutputType() const
{
// We just check if its supported because we need to know when calculating workspace size
@ -813,13 +827,13 @@ private:
bool mayHaveFinalizeFused() const
{
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
&& !use_w4_groupwise;
&& !use_wfp4a16;
}
static bool mayHaveFinalizeFused(int sm)
{
using RunnerType = decltype(moe_gemm_runner_);
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise;
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_wfp4a16;
}
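// Effect of relaxing the check from use_w4_groupwise to use_wfp4a16 (editorial note with an
// assumed instantiation): an SM90 W4AFP8 runner that supports TMA warp specialization now
// reports the fused-finalize epilogue as available, e.g.
//   bool const can_fuse = mayHaveFinalizeFused(/*sm=*/90); // true for W4AFP8, false for WFP4A16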
// TODO: This should eventually take the quant params to give more flexibility
@ -866,7 +880,8 @@ private:
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);
cudaStream_t stream, QuantParams const& quant_params, int64_t* expert_first_token_offset = nullptr,
int const num_experts_per_node = 0);
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;


@ -28,8 +28,9 @@ namespace cutlass_kernels_oss
{
using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
cutlass::WeightOnlyQuantOp QuantOp>
void sm90_generic_mixed_moe_gemm_kernelLauncher(
tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,


@ -45,6 +45,7 @@
#include "cutlass/util/tensor_view_io.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
#include "cutlass_extensions/gemm_configs.h"
@ -71,11 +72,12 @@ namespace cutlass_kernels_oss
using namespace tensorrt_llm::kernels::cutlass_kernels;
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
using namespace cute;
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
cutlass::WeightOnlyQuantOp QuantOp>
void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
@ -85,6 +87,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE,
"Unimplemented fusion provided to TMA WS Mixed MoE gemm launcher");
constexpr static bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE;
// A matrix configuration
using ElementA = typename TllmToCutlassTypeAdapter<T>::type;
@ -129,6 +134,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
using ElementBias = ElementFinalOutput;
using ElementRouterScales = float;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
@ -136,6 +144,11 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = CTAShape; // Threadblock-level tile size
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using EpilogueFusionOp = cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter<
typename cutlass::layout::LayoutTranspose<LayoutD>::type, ElementFinalOutput, ElementAccumulator, ElementBias,
ElementRouterScales>;
using KernelSchedule
= std::conditional_t<std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecializedPingpong>,
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
@ -145,12 +158,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>; // Epilogue to launch
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
AlignmentC, void, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD, EpilogueSchedule,
EpilogueFusionOp>::CollectiveOp;
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
AlignmentC, ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveEpilogue
= std::conditional_t<IsFinalizeFusion, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
// =========================================================== MIXED INPUT WITH SCALES
// =========================================================================== The Scale information must get paired
// with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the
@ -175,20 +197,56 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha = use_wfp4a16 ? 1 : 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
fusion_args.beta_ptr_array = nullptr;
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count = sm_count_;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueScalars = decltype(EpilogueArguments{}.thread);
EpilogueScalars epilogue_scalars = [&]
{
if constexpr (IsFinalizeFusion)
{
auto epi_params = hopper_inputs.fused_finalize_epilogue;
return EpilogueScalars{ElementAccumulator(1), nullptr, hopper_inputs.alpha_scale_ptr_array,
Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */
reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), Stride<_1, _0, int64_t>{}, /* bias */
epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */
reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output),
epi_params.stride_final_output_transposed, epi_params.ptr_source_token_index,
epi_params.num_rows_in_final_output, epi_params.shape_override, epi_params.use_reduction};
}
else
{
return EpilogueScalars{};
}
}();
EpilogueArguments epilogue_args = [&]
{
if constexpr (IsFinalizeFusion)
{
return EpilogueArguments{epilogue_scalars, nullptr, nullptr, nullptr, nullptr};
}
else
{
fusion_args.alpha = use_wfp4a16 ? 1 : 0;
fusion_args.beta = 0;
fusion_args.alpha_ptr = nullptr;
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
fusion_args.beta_ptr_array = nullptr;
// One alpha and beta per each group
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
return EpilogueArguments{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
reinterpret_cast<StrideD*>(hopper_inputs.stride_d)};
}
}();
arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped,
{inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr},
{reinterpret_cast<ElementB const**>(hopper_inputs.ptr_weight),
@ -197,10 +255,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
reinterpret_cast<StrideA*>(hopper_inputs.stride_act),
reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
reinterpret_cast<StrideS*>(hopper_inputs.int4_groupwise_params.stride_s_a), group_size},
{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
reinterpret_cast<StrideD*>(hopper_inputs.stride_d)},
hw_info};
epilogue_args, hw_info};
assert(group_size == int(inputs.groupwise_quant_group_size));
if (workspace_size != nullptr)


@ -792,25 +792,37 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
TLLM_CHECK_WITH_INFO(
inputs.gemm_config.is_tma_warp_specialized, "w4afp8 is only supported for TMA warp specialization");
// EpilogueTag is ignored
#define SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(SCALE_FACTOR) \
if (hopper_inputs.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) \
{ \
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, \
cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, \
SCALE_FACTOR>(inputs, hopper_inputs, multi_processor_count_, nullptr); \
} \
else \
{ \
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, \
cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, \
SCALE_FACTOR>(inputs, hopper_inputs, multi_processor_count_, nullptr); \
}
if (inputs.k % 512 == 0)
{
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 4>(inputs, hopper_inputs, multi_processor_count_, nullptr);
SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(4)
}
else if (inputs.k % 256 == 0)
{
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 2>(inputs, hopper_inputs, multi_processor_count_, nullptr);
SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(2)
}
else if (inputs.k % 128 == 0)
{
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr);
SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(1)
}
else
{
TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k);
}
#undef SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE
return;
}
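// Worked example of the K-tile selection above (editorial sketch; the K values are assumed):
// K = 7168 hits the % 512 branch (PackedScalesNum = 4), K = 2816 the % 256 branch
// (PackedScalesNum = 2), and K = 1408 the % 128 branch (PackedScalesNum = 1); within each
// branch the runtime hopper_inputs.fusion flag selects the FINALIZE or NONE instantiation.
static_assert(7168 % 512 == 0 && 2816 % 512 != 0 && 2816 % 256 == 0 && 1408 % 256 != 0
        && 1408 % 128 == 0,
    "editorial check of the example K sizes for the three PackedScalesNum branches");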
@ -820,7 +832,8 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization");
// EpilogueTag is ignored
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr);
cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, 1>(
inputs, hopper_inputs, multi_processor_count_, nullptr);
return;
}
#endif


@ -37,6 +37,7 @@
#include "cutlass/tensor_ref.h"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/detail/collective/mixed_input_utils.hpp"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
@ -67,11 +68,12 @@ using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
using namespace cute;
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
typename ClusterShape>
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename CTAShape, typename ClusterShape>
void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
{
@ -88,7 +90,7 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
{
if constexpr ((get<0>(CTAShape{}) == 128) && get<1>(CTAShape{}) == 128)
{
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, CTAShape,
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
ClusterShape, cutlass::gemm::KernelTmaWarpSpecializedPingpong,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
@ -96,7 +98,7 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
}
else
{
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, CTAShape,
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
ClusterShape, cutlass::gemm::KernelTmaWarpSpecializedCooperative,
cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
@ -106,9 +108,10 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
break;
case tkc::MainloopScheduleType::PINGPONG:
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, ClusterShape,
cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(inputs, hopper_inputs, sm_count_, workspace_size);
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
ClusterShape, cutlass::gemm::KernelTmaWarpSpecializedPingpong,
cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
inputs, hopper_inputs, sm_count_, workspace_size);
break;
default:
TLLM_THROW(
@ -122,7 +125,8 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
#endif
}
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape>
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
typename CTAShape>
void sm90_dispatch_moe_mixed_dtype_gemm_config(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
{
@ -130,26 +134,27 @@ void sm90_dispatch_moe_mixed_dtype_gemm_config(GroupedGemmInput<T, WeightType, G
switch (inputs.gemm_config.cluster_shape)
{
case tkc::ClusterShape::ClusterShape_1x1x1:
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_1, _1, _1>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
Shape<_1, _1, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::ClusterShape::ClusterShape_2x1x1:
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_2, _1, _1>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
Shape<_2, _1, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::ClusterShape::ClusterShape_1x2x1:
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_1, _2, _1>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
Shape<_1, _2, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::ClusterShape::ClusterShape_2x2x1:
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_2, _2, _1>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
Shape<_2, _2, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
default: TLLM_THROW("[Mixed dtype MoE GEMM][dispatch_CGA_config] Config is invalid for mixed type GEMM."); break;
}
}
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, int PackedScalesNum>
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
int PackedScalesNum>
void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
@ -168,49 +173,49 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
switch (inputs.gemm_config.tile_config_sm90)
{
case tkc::CutlassTileConfigSM90::CtaShape64x16x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _16, _Ktile>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_64, _16, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x32x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _32, _Ktile>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_64, _32, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x64x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _64, _Ktile>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_64, _64, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x128x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag,
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_64, _Ntile, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
// case tkc::CutlassTileConfigSM90::CtaShape64x256x128B:
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _256,
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
case tkc::CutlassTileConfigSM90::CtaShape128x16x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _16, _Ktile>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_128, _16, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x32x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _32, _Ktile>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_128, _32, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x64x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _64, _Ktile>>(
inputs, hopper_inputs, sm_count_, workspace_size);
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_128, _64, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x128x128B:
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag,
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
Shape<_128, _128, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
break;
// case tkc::CutlassTileConfigSM90::CtaShape128x256x128B:
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _256,
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION, Shape<_128,
// _256, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
// case tkc::CutlassTileConfigSM90::CtaShape256x128x128B:
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _256,
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION, Shape<_128,
// _256, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
// case tkc::CutlassTileConfigSM90::CtaShape256x256x128B:
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_256, _256,
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION, Shape<_256,
// _256, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
case tkc::CutlassTileConfigSM90::Undefined:
TLLM_THROW("[Mixed dtype MoE GEMM][sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass] gemm config undefined.");
break;
@ -236,12 +241,15 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_
using _Ktile = Int<Ktile>;
#ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS
GroupedGemmInput<T, WeightType, OutputType, OutputType> inputs{};
constexpr bool use_wfp4a16 = std::is_same_v<WeightType, cutlass::float_e2m1_t>;
constexpr int group_size = use_wfp4a16 ? cutlass::gemm::collective::detail::mxfp4_group_size
: cutlass::gemm::collective::detail::int4_group_size;
GroupedGemmInput<T, WeightType, OutputType, OutputType> inputs{.groupwise_quant_group_size = group_size};
inputs.num_experts = num_experts;
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, OutputType,
tensorrt_llm::cutlass_extensions::EpilogueOpDefault, Shape<_128, _64, _Ktile>, Shape<_1, _1, _1>,
cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
tensorrt_llm::cutlass_extensions::EpilogueOpDefault, EpilogueFusion::NONE, Shape<_128, _64, _Ktile>,
Shape<_1, _1, _1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative,
cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
inputs, TmaWarpSpecializedGroupedGemmInput{}, sm_count_, &count);
#endif
return count;


@ -37,6 +37,7 @@
// Order matters here, packed_stride.hpp is missing cute and convolution includes
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
@ -76,15 +77,23 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels::cutlass_kernels
{
/**
* Takes the input maps and prepares the expanded maps for min latency
* @param num_active_experts_per_node: Number of active experts on current node
* @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by
* the token. Only the first num_active_experts_per_ rows are valid
* @param active_expert_global_ids: The global expert id for each activated expert
* Only the first num_active_experts_per_ values are valid
* @param expert_first_token_offset: Store the first token offset for each expert
*/
// Forced vectorized load
template <typename T>
__device__ __forceinline__ T loadVec(T const* ptr)
{
T result;
cutlass::arch::global_load<T, sizeof(T)>(result, ptr, true);
return result;
}
// Forced vectorized store
template <typename T>
__device__ __forceinline__ void storeVec(T* ptr, T const& value)
{
cutlass::arch::global_store<T, sizeof(T)>(value, ptr, true);
}
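// Usage sketch for the helpers above (illustrative function, not part of this diff; assumes
// 16-byte aligned pointers): moving one 16-byte vector (e.g. eight halves) with a single forced
// LDG.128 / STG.128 pair, the same way the activation kernel below uses loadVec/storeVec.
template <typename Elem>
__device__ __forceinline__ void copyVec16BExample(Elem const* src, Elem* dst, int64_t vec_idx)
{
    using Vec = cutlass::Array<Elem, 16 / sizeof(Elem)>; // one 16-byte vector per call
    Vec const v = loadVec(reinterpret_cast<Vec const*>(src) + vec_idx);
    storeVec(reinterpret_cast<Vec*>(dst) + vec_idx, v);
}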
template <typename T, int BLOCK_SIZE>
__device__ __forceinline__ void initTensor(T* value, int const tid, int const total_num, T const init_value)
{
@ -156,6 +165,15 @@ __device__ __forceinline__ void setActiveNum(int& num_active, int& num_active_of
num_active_offset_end = num_active_offset_start + num_active;
}
/**
* Takes the input maps and prepares the expanded maps for min latency
* @param num_active_experts_per_node: Number of active experts on current node
* @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by
* the token. Only the first num_active_experts_per_ rows are valid
* @param active_expert_global_ids: The global expert id for each activated expert
* Only the first num_active_experts_per_ values are valid
* @param expert_first_token_offset: Store the first token offset for each expert
*/
template <int BLOCK_SIZE>
__global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_per_node, float* experts_to_token_scores,
int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts,
@ -1593,7 +1611,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
auto func = [&]()
{
#ifdef ENABLE_FP8
// Always MXFP8
// Always MXFP8 and W4A8_AWQ
if constexpr (std::is_same_v<ExpandedActivationsType, __nv_fp8_e4m3>
&& !std::is_same_v<InputActivationsType, __nv_fp8_e4m3>)
{
@ -1950,7 +1968,7 @@ constexpr static int ACTIVATION_THREADS_PER_BLOCK = 256;
template <class ActivationOutputType, class GemmOutputType, class ActFn>
__global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutputType const* gemm_result,
int64_t const* expert_first_token_offset, int64_t inter_size, int64_t num_experts_per_node,
ActivationParams activation_type)
ActivationParams activation_type, GemmOutputType const* prequant_scale, bool use_per_expert_prequant_scale)
{
int64_t const tid = threadIdx.x;
int64_t const token = blockIdx.x;
@ -1978,16 +1996,24 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput
float gate_alpha = 1.0f;
float gate_bias = 0.0f;
float gate_limit = std::numeric_limits<float>::infinity();
int expert = 0;
if (use_per_expert_prequant_scale || activation_type.swiglu_alpha || activation_type.swiglu_beta
|| activation_type.swiglu_limit)
{
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t) token + 1) - 1;
}
if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit)
{
int expert
= findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t) token + 1) - 1;
gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
gate_bias = activation_type.swiglu_beta ? activation_type.swiglu_beta[expert] : 0.0f;
gate_limit = activation_type.swiglu_limit ? activation_type.swiglu_limit[expert]
: std::numeric_limits<float>::infinity();
}
auto prequant_scale_vec = prequant_scale ? reinterpret_cast<GemmResultElem const*>(
prequant_scale + (use_per_expert_prequant_scale ? expert * inter_size : 0))
: nullptr;
ActFn fn{};
fn.alpha = gate_alpha;
fn.beta = gate_bias;
@ -1998,6 +2024,13 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput
// BF16 isn't supported, use FP32 for activation function
auto gate_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + inter_size_vec]);
auto gate_act = fn(gate_value, linear_value);
// Apply prequant scale if provided
if (prequant_scale_vec)
{
gate_act = gate_act * arrayConvert<GemmResultElem, ComputeElem>(prequant_scale_vec[elem_index]);
}
output_vec[elem_index] = arrayConvert<ComputeElem, OutputElem>(gate_act);
}
}
@ -2005,7 +2038,8 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput
template <typename ActivationOutputType, typename GemmOutputType>
void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_result,
int64_t const* expert_first_token_offset, int64_t inter_size, int64_t num_tokens, int64_t num_experts_per_node,
ActivationParams activation_type, cudaStream_t stream)
ActivationParams activation_type, cudaStream_t stream, bool use_per_expert_prequant_scale = false,
GemmOutputType const* prequant_scale = nullptr)
{
int64_t const blocks = num_tokens;
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
@ -2018,18 +2052,19 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
? &doGatedActivationKernel<ActivationOutputType, GemmOutputType, SwigluBiasAdaptor>
: nullptr;
TLLM_CHECK_WITH_INFO(fn != nullptr, "Invalid activation type");
fn<<<blocks, threads, 0, stream>>>(
output, gemm_result, expert_first_token_offset, inter_size, num_experts_per_node, activation_type);
fn<<<blocks, threads, 0, stream>>>(output, gemm_result, expert_first_token_offset, inter_size, num_experts_per_node,
activation_type, prequant_scale, use_per_expert_prequant_scale);
}
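// Hypothetical launch (editorial sketch; argument names are placeholders): when the FC1 output
// still needs AWQ smoothing before FC2 quantization, the two new trailing parameters fold the
// (optionally per-expert) prequant scale into the gated activation instead of a separate pass:
//   doGatedActivation(fc1_act_out, fc1_gemm_out, expert_first_token_offset, inter_size,
//       expanded_num_tokens, num_experts_per_node, activation_params, stream,
//       /*use_per_expert_prequant_scale=*/true, fc2_prequant_scale);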
// ============================== Activation =================================
template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int kProcessRows>
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant,
ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset,
int num_experts_per_node, int64_t inter_size, float const* fc2_act_global_scale, bool use_per_expert_act_scale,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params)
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params,
GemmOutputType const* prequant_scale, int64_t const num_valid_tokens)
{
#ifdef ENABLE_FP4
constexpr bool IsNVFP4 = std::is_same_v<T, __nv_fp4_e2m1>
@ -2041,7 +2076,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
constexpr bool IsMXFP8 = cute::dependent_false<T>;
#endif
int64_t const tid = threadIdx.x;
constexpr bool IsGated = ActFn::IS_GLU;
size_t gated_size_mul = IsGated ? 2 : 1;
size_t gated_off = IsGated ? inter_size : 0;
@ -2059,81 +2093,116 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
: TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX;
int64_t const padded_inter_size = ceilDiv(inter_size, min_k_dim_alignment) * min_k_dim_alignment;
int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];
// 2D grid: blockIdx.x for tokens in groups of kProcessRows, blockIdx.y for columns
int64_t const row_offset = blockIdx.x * kProcessRows;
int64_t const col_offset = blockIdx.y * blockDim.y + threadIdx.y;
assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
// Early exit if this thread is out of bounds for columns
if (col_offset >= num_elems_in_col)
return;
// Precompute K-dimension padding range for merged SF padding
[[maybe_unused]] int64_t const k_padding_start = inter_size / VecSize;
[[maybe_unused]] int64_t const k_padding_end = padded_inter_size / VecSize;
[[maybe_unused]] bool const do_k_padding
= (IsNVFP4 || IsMXFP8) && col_offset >= k_padding_start && col_offset < k_padding_end;
auto main_loop = [&](auto has_prequant_scale, auto has_bias)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
cudaGridDependencySynchronize();
#endif
for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x)
{
size_t gemm_result_offset = token * inter_size * gated_size_mul;
size_t output_offset = token * inter_size;
constexpr bool has_prequant_scale_v = decltype(has_prequant_scale)::value;
constexpr bool has_bias_v = decltype(has_bias)::value;
int64_t expert = 0;
float gate_alpha = 1.0f;
float gate_beta = 0.0f;
float gate_limit = std::numeric_limits<float>::infinity();
if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale || activation_params.swiglu_alpha
|| activation_params.swiglu_beta || activation_params.swiglu_limit)
// Process kProcessRows consecutive tokens, promoting per-expert data reuse
#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
// TODO this is almost certainly faster as a linear scan
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1;
int64_t const token = row_offset + i;
gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
gate_limit = activation_params.swiglu_limit ? activation_params.swiglu_limit[expert]
: std::numeric_limits<float>::infinity();
}
// Early exit for this row if out of bounds
if (token >= num_valid_tokens)
break;
size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f;
size_t gemm_result_offset = token * inter_size * gated_size_mul;
size_t output_offset = token * inter_size;
// Some globals for FP4
float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
int64_t expert = 0;
float gate_alpha = 1.0f;
float gate_beta = 0.0f;
float gate_limit = std::numeric_limits<float>::infinity();
if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale || activation_params.swiglu_alpha
|| activation_params.swiglu_beta || activation_params.swiglu_limit)
{
// TODO this is almost certainly faster as a linear scan
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1;
size_t bias_offset = 0;
if (bias_ptr)
{
bias_offset = (bias_is_broadcast ? expert * inter_size * gated_size_mul : gemm_result_offset);
}
gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
gate_limit = activation_params.swiglu_limit ? activation_params.swiglu_limit[expert]
: std::numeric_limits<float>::infinity();
}
using BiasElem = cutlass::Array<ScaleBiasType, ACTIVATION_ELEM_PER_THREAD>;
using GemmResultElem = cutlass::Array<GemmOutputType, ACTIVATION_ELEM_PER_THREAD>;
using OutputElem = std::conditional_t<IsNVFP4, uint32_t,
std::conditional_t<IsMXFP8, uint64_t, cutlass::Array<T, ACTIVATION_ELEM_PER_THREAD>>>;
using ComputeElem = cutlass::Array<float, ACTIVATION_ELEM_PER_THREAD>;
// Aliases gemm_result for non-gated, non-fp8 cases
auto gemm_result_vec = reinterpret_cast<GemmResultElem const*>(gemm_result + gemm_result_offset);
auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
int64_t const start_offset = tid;
int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f;
ActFn fn{};
fn.alpha = gate_alpha;
fn.beta = gate_beta;
fn.limit = gate_limit;
for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
auto fc1_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
// Some globals for FP4
float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
size_t bias_offset = 0;
if (bias_ptr)
{
fc1_value = fc1_value + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index + gated_off_vec]);
bias_offset = (bias_is_broadcast ? expert * inter_size * gated_size_mul : gemm_result_offset);
}
using BiasElem = cutlass::Array<ScaleBiasType, ACTIVATION_ELEM_PER_THREAD>;
using GemmResultElem = cutlass::Array<GemmOutputType, ACTIVATION_ELEM_PER_THREAD>;
using OutputElem = std::conditional_t<IsNVFP4, uint32_t,
std::conditional_t<IsMXFP8, uint64_t, cutlass::Array<T, ACTIVATION_ELEM_PER_THREAD>>>;
using ComputeElem = cutlass::Array<float, ACTIVATION_ELEM_PER_THREAD>;
// Aliases gemm_result for non-gated, non-fp8 cases
auto gemm_result_vec = reinterpret_cast<GemmResultElem const*>(gemm_result + gemm_result_offset);
auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
auto prequant_scale_vec = prequant_scale
? reinterpret_cast<GemmResultElem const*>(prequant_scale + expert * inter_size)
: nullptr;
ActFn fn{};
fn.alpha = gate_alpha;
fn.beta = gate_beta;
fn.limit = gate_limit;
// Each thread handles one vector at col_offset
int64_t const elem_index = col_offset;
// Use loadVec to force LDG.128 vectorized loads
auto fc1_value
= arrayConvert<GemmResultElem, ComputeElem>(loadVec(&gemm_result_vec[elem_index + gated_off_vec]));
if constexpr (has_bias_v)
{
fc1_value = fc1_value
+ arrayConvert<BiasElem, ComputeElem>(loadVec(&bias_ptr_vec[elem_index + gated_off_vec]));
}
auto gate_act = [&]()
{
if constexpr (IsGated)
{
auto linear_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index]);
if (bias_ptr_vec)
auto linear_value
= arrayConvert<GemmResultElem, ComputeElem>(loadVec(&gemm_result_vec[elem_index]));
if constexpr (has_bias_v)
{
linear_value = linear_value + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index]);
linear_value
= linear_value + arrayConvert<BiasElem, ComputeElem>(loadVec(&bias_ptr_vec[elem_index]));
}
return fn(fc1_value, linear_value);
}
@ -2145,6 +2214,13 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
auto post_act_val = gate_act * quant_scale;
// Apply prequant scale (shape [experts_per_rank, intermediate_size]) if provided
if constexpr (has_prequant_scale_v)
{
post_act_val = post_act_val
* arrayConvert<GemmResultElem, ComputeElem>(loadVec(&prequant_scale_vec[elem_index]));
}
if constexpr (IsNVFP4 || IsMXFP8)
{
// We use GemmOutputType as the intermediate compute type as that should always be unquantized
@ -2155,74 +2231,98 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
static_assert(
sizeof(res) == sizeof(*output_vec), "Quantized value must be the same size as the output");
output_vec[elem_index] = res;
// Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan values in the
// padded SF atom. Only process padding for valid tokens in this block's row range.
if (do_k_padding)
{
writeSF<VecSize, VecSize>(num_tokens_before_expert, expert, /*source_row*/ -1, token, col_offset,
padded_inter_size, fc2_act_sf_flat,
/* input_sf */ nullptr); // Pass nullptr input_sf so we write 0
}
}
else
{
output_vec[elem_index] = arrayConvert<ComputeElem, OutputElem>(post_act_val);
// Use storeVec to force STG.128 vectorized store
storeVec(&output_vec[elem_index], arrayConvert<ComputeElem, OutputElem>(post_act_val));
}
}
// Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan values in the padded
// SF atom
if constexpr (IsNVFP4 || IsMXFP8)
{
// Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector
size_t padding_start_offset = inter_size / VecSize + start_offset;
size_t padding_elems_in_col = padded_inter_size / VecSize;
for (int64_t elem_index = padding_start_offset; elem_index < padding_elems_in_col; elem_index += stride)
{
writeSF<VecSize, VecSize>(num_tokens_before_expert, expert, /*source_row*/ -1, token, elem_index,
padded_inter_size, fc2_act_sf_flat, /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF
// atom
if constexpr (IsNVFP4 || IsMXFP8)
{
int64_t const start_offset = threadIdx.x;
int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
// Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector
int64_t const padded_num_elems_in_col = padded_inter_size / VecSize;
assert(padded_inter_size % VecSize == 0);
constexpr int64_t min_num_tokens_alignment = IsNVFP4
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4
: TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX;
static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0,
"Min num tokens alignment must be a power of two");
// Since we don't know a priori how much padding is needed we assume the max per expert
// NOTE: we don't (min_num_tokens_alignment-1) to have power of two divisions
int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node;
for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; padding_token += gridDim.x)
// Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded
// SF atom. This is handled by dedicated blocks beyond num_valid_tokens range.
if constexpr (IsNVFP4 || IsMXFP8)
{
int64_t expert = padding_token / min_num_tokens_alignment;
int64_t num_tokens_before_expert = expert_first_token_offset[expert];
int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1];
int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert;
int64_t padding_to_expert
= TmaWarpSpecializedGroupedGemmInput::alignToSfDim(tokens_to_expert, min_num_tokens_alignment)
- tokens_to_expert;
int64_t expert_pad_idx = padding_token % min_num_tokens_alignment;
if (expert_pad_idx < padding_to_expert)
constexpr int64_t min_num_tokens_alignment = IsNVFP4
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4
: TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX;
static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0,
"Min num tokens alignment must be a power of two");
int64_t const padded_num_elems_in_col = padded_inter_size / VecSize;
int64_t const sf_col_offset = blockIdx.y * blockDim.x + threadIdx.x;
if (sf_col_offset >= padded_num_elems_in_col)
return;
// Only blocks that handle padding tokens participate
int64_t const num_token_blocks = (num_valid_tokens + kProcessRows - 1) / kProcessRows;
if (blockIdx.x < num_token_blocks)
return;
// This block handles N-dimension padding
int64_t const padding_block_idx = blockIdx.x - num_token_blocks;
#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; elem_index += stride)
int64_t const padding_token = padding_block_idx * kProcessRows + i;
int64_t const num_padding_tokens = min_num_tokens_alignment * num_experts_per_node;
if (padding_token >= num_padding_tokens)
return;
int64_t expert = padding_token / min_num_tokens_alignment;
int64_t num_tokens_before_expert = expert_first_token_offset[expert];
int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1];
int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert;
int64_t padding_to_expert
= TmaWarpSpecializedGroupedGemmInput::alignToSfDim(tokens_to_expert, min_num_tokens_alignment)
- tokens_to_expert;
int64_t expert_pad_idx = padding_token % min_num_tokens_alignment;
if (expert_pad_idx < padding_to_expert)
{
// The SF buffer is padded to a multiple of MinNDimAlignment for each expert
// This means we can safely write to offset num_tokens_after_expert + padded_token, since the next
// expert will leave space for the padding
// This means we can safely write to offset num_tokens_after_expert + padded_token, since the
// next expert will leave space for the padding
writeSF<VecSize, VecSize>(num_tokens_before_expert, expert, /*source_row*/ -1,
num_tokens_after_expert + expert_pad_idx, elem_index, padded_inter_size, fc2_act_sf_flat,
/* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0
num_tokens_after_expert + expert_pad_idx, sf_col_offset, padded_inter_size, fc2_act_sf_flat,
/* input_sf */ nullptr); // Pass nullptr input_sf so we write 0
}
}
}
}; // end of lambda main_loop
// Instantiate four code paths for the different combinations of prequant_scale and bias_ptr
// This ensures only the relevant code runs in the tight per-token loop
if (prequant_scale && bias_ptr)
{
main_loop(/* has_prequant_scale */ std::true_type{}, /* has_bias */ std::true_type{});
}
else if (!prequant_scale && !bias_ptr)
{
main_loop(/* has_prequant_scale */ std::false_type{}, /* has_bias */ std::false_type{});
}
else if (!prequant_scale && bias_ptr)
{
main_loop(/* has_prequant_scale */ std::false_type{}, /* has_bias */ std::true_type{});
}
else // prequant_scale && !bias_ptr
{
main_loop(/* has_prequant_scale */ std::true_type{}, /* has_bias */ std::false_type{});
}
}
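// The four explicit calls above are compile-time tag dispatch: each combination of
// prequant_scale/bias becomes its own kernel instantiation, so the dead branch disappears from
// the per-token hot path. A minimal standalone sketch of the same pattern (illustrative only,
// not from this commit):
template <typename HasBias>
__device__ __forceinline__ float addBiasIfPresentExample(float value, float bias, HasBias)
{
    if constexpr (HasBias::value)
    {
        return value + bias; // only compiled into the std::true_type instantiation
    }
    else
    {
        return value; // the std::false_type instantiation never touches `bias`
    }
}
// Callers select the variant with a tag, e.g. addBiasIfPresentExample(v, b, std::true_type{});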
@ -2230,99 +2330,143 @@ template <class T, class GemmOutputType, class ScaleBiasType>
void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias,
bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size,
int64_t expanded_num_tokens, ActivationParams activation_type, QuantParams const& quant_params,
bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream,
GemmOutputType const* prequant_scale = nullptr)
{
#ifdef ENABLE_FP4
constexpr int64_t min_num_tokens_alignment = std::is_same_v<T, __nv_fp4_e2m1>
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4
: TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX;
int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node;
constexpr bool IsNVFP4 = std::is_same_v<T, __nv_fp4_e2m1>;
constexpr bool IsMXFP8 = std::is_same_v<T, __nv_fp8_e4m3>;
#else
int64_t num_padding_tokens = 0;
constexpr bool IsNVFP4 = false;
constexpr bool IsMXFP8 = false;
#endif
// ACTIVATION_ELEM_PER_THREAD must match kernel's computation
constexpr int64_t ACTIVATION_ELEM_PER_THREAD = (IsNVFP4 || IsMXFP8)
? CVT_ELTS_PER_THREAD
: (128 / std::min(sizeof_bits<T>::value, sizeof_bits<GemmOutputType>::value));
auto fn = [&]()
int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
auto doActivationKernelLauncher = [&](auto num_rows_per_cta)
{
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
// common.h
auto fn
= [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*, ScaleBiasType const*,
bool, int64_t const*, int, int64_t, float const*, bool,
TmaWarpSpecializedGroupedGemmInput::ElementSF*, ActivationParams)
constexpr int num_rows_per_cta_v = num_rows_per_cta.value;
// For NVFP4/MXFPX SFs
int64_t num_padding_tokens = 0;
auto fn = [&]()
{
switch (activation_type.activation_type)
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
// common.h
auto fn
= [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*,
ScaleBiasType const*, bool, int64_t const*, int, int64_t,
float const*, bool, TmaWarpSpecializedGroupedGemmInput::ElementSF*,
ActivationParams, GemmOutputType const*, int64_t)
{
case ActivationType::Identity:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value>;
case ActivationType::Gelu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
case ActivationType::Relu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>;
case ActivationType::Silu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
case ActivationType::Swiglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
case ActivationType::Geglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
case ActivationType::SwigluBias:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value>;
case ActivationType::Relu2:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value>;
default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
}
};
auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
auto MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{};
auto NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{};
switch (activation_type.activation_type)
{
case ActivationType::Identity:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Gelu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Relu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Silu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Swiglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Geglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::SwigluBias:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value, num_rows_per_cta_v>;
case ActivationType::Relu2:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
}
};
auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
auto MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{};
auto NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{};
#ifdef ENABLE_FP4
if constexpr (std::is_same_v<T, __nv_fp4_e2m1>)
{
TLLM_CHECK_WITH_INFO(
quant_params.fp4.fc2.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4");
return fn(NVFP4);
}
else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>)
{
return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE);
}
else
if constexpr (std::is_same_v<T, __nv_fp4_e2m1>)
{
num_padding_tokens = TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 * num_experts_per_node;
TLLM_CHECK_WITH_INFO(
quant_params.fp4.fc2.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4");
return fn(NVFP4);
}
else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>)
{
num_padding_tokens = quant_params.mxfp8_mxfp4.fc2.weight_block_scale
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX * num_experts_per_node
: 0;
return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE);
}
else
#endif
{
return fn(NONE);
}
}();
{
return fn(NONE);
}
}();
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
int32_t const blocks
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
// X dimension for tokens in groups of num_rows_per_cta_v
// Y dimension for columns
int64_t const num_token_blocks = (expanded_num_tokens + num_rows_per_cta_v - 1) / num_rows_per_cta_v;
int64_t const num_padding_blocks = (num_padding_tokens + num_rows_per_cta_v - 1) / num_rows_per_cta_v;
// Add extra blocks in X dimension for token-wise padding in FP4/MXFP8 modes
int32_t const grid_x = static_cast<int32_t>(num_token_blocks + num_padding_blocks);
int32_t const grid_y = static_cast<int32_t>(
(num_elems_in_col + ACTIVATION_THREADS_PER_BLOCK - 1) / ACTIVATION_THREADS_PER_BLOCK);
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset,
num_experts_per_node, inter_size, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale,
fc2_act_sf_flat, activation_type);
cudaLaunchConfig_t config;
config.gridDim = dim3(grid_x, grid_y, 1);
config.blockDim = dim3(1, threads, 1);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast,
expert_first_token_offset, num_experts_per_node, inter_size, quant_params.fp4.fc2.act_global_scale,
use_per_expert_act_scale, fc2_act_sf_flat, activation_type, prequant_scale, expanded_num_tokens);
}; // end lambda doActivationKernelLauncher
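// Illustrative sizing (numbers assumed for the example): with num_elems_in_col = 14336 and 8 expanded
// tokens the product is ~114k, above 256 * 256 but below 256 * 512, so the 2-rows-per-CTA variant is
// chosen; larger workloads fall through to 4 rows per CTA.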
// 256 threads per block * 256 blocks / 1 row per block can be handled by 1-2 waves depending on SM arch
if (num_elems_in_col * expanded_num_tokens < 256 * 256)
{
doActivationKernelLauncher(std::integral_constant<int, 1>());
}
// 256 threads per block * 512 blocks / 2 rows per block can be handled by 1-2 waves depending on SM arch
else if (num_elems_in_col * expanded_num_tokens < 256 * 512)
{
doActivationKernelLauncher(std::integral_constant<int, 2>());
}
// Regular case
else
{
doActivationKernelLauncher(std::integral_constant<int, 4>());
}
}
// ============================== Lora Add Bias =================================
@ -2877,11 +3021,10 @@ template <class T, class WeightType, class OutputType, class InputType, class Sc
T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
int64_t* expert_first_token_offset, int const num_experts_per_node)
QuantParams const& quant_params, int64_t* expert_first_token_offset, int const num_experts_per_node)
{
T const* gemm_input;
bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
if (use_prequant_scale_kernel)
if (usePrequantScaleKernel(quant_params))
{
TLLM_CHECK_WITH_INFO(
(!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
@ -2926,7 +3069,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
int64_t const inter_size, int const num_experts_per_node, ActivationParams fc1_activation_type,
float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids)
int* active_expert_global_ids, void const* fc2_prequant_scale)
{
if (fp8_blockscale_gemm_runner)
@ -2988,16 +3131,30 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
sync_check_cuda_error(stream);
// TODO: when bias_is_broadcast is false, fuse bias to gemm
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc2.use_per_expert_act_scale
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.use_per_expert_act_scale
: use_fp8 ? quant_params.fp8.fc2_use_per_expert_act_scale
: Self::useAwq(quant_params) ? quant_params.groupwise.fc2.use_per_expert_act_scale
: false;
doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast,
expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows, fc1_activation_type,
quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream);
// Activation -> (BackboneType) -> Prequant -> (T == ActType)
// When fusing activation and prequant, the output type is directly T == ActType
// Else, the output type is BackboneType
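// E.g. on the W4AFP8 AWQ path (illustrative, not the only case): BackBoneType is half/bf16 and T is
// fp8, so fusing the prequant multiply writes fp8 directly and avoids a separate elementwise pass over
// the FC1 activations.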
if (fc2_prequant_scale)
{
doActivation<T, UnfusedGemmOutputType>(reinterpret_cast<T*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases,
bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows,
fc1_activation_type, quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream,
static_cast<UnfusedGemmOutputType const*>(fc2_prequant_scale));
}
else
{
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases,
bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows,
fc1_activation_type, quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream);
}
sync_check_cuda_error(stream);
}
@ -3087,9 +3244,12 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
if (!use_ampere_activation_fusion)
{
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
bool const use_per_expert_act_scale
= Self::useAwq(quant_params) ? quant_params.groupwise.fc2.use_per_expert_act_scale : false;
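// Note: the FC2 prequant scales are passed here, presumably so doGatedActivation can fold the AWQ
// smoothing multiply into the gated-activation epilogue, mirroring the fused non-gated path above.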
doGatedActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(intermediate_result), expert_first_token_offset, inter_size,
expanded_num_rows, num_experts_per_node, fc1_activation_type, stream);
expanded_num_rows, num_experts_per_node, fc1_activation_type, stream, use_per_expert_act_scale,
static_cast<UnfusedGemmOutputType const*>(fc2_prequant_scale));
sync_check_cuda_error(stream);
}
@ -3187,7 +3347,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
loraBiasApplyFunc(static_cast<UnfusedGemmOutputType*>(gemm_output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), nullptr,
static_cast<ScaleBiasType const*>(fc2_lora), false, expert_first_token_offset, num_experts_per_node,
hidden_size, expanded_num_rows, ActivationParams(ActivationType::Identity), {}, false, nullptr, stream);
hidden_size, expanded_num_rows, ActivationParams(ActivationType::Identity), {}, false, nullptr, stream,
/*prequant_scale=*/nullptr);
sync_check_cuda_error(stream);
}
@ -3560,7 +3721,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
fc2_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received quant scale for FC2");
}
bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
bool use_awq = useAwq(quant_params);
int const num_experts_per_node = full_num_experts / parallelism_config.ep_size;
configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token,
@ -3602,11 +3763,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert, hidden_size,
inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream,
*gemm1_config_, true, min_latency_params.num_active_experts_per_node,
min_latency_params.active_expert_global_ids);
min_latency_params.active_expert_global_ids, /*fc2_prequant_scale=*/nullptr);
sync_check_cuda_error(stream);
auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, quant_params);
Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, final_output, nullptr,
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
@ -3621,7 +3782,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
else
{
bool fused_prologue_result = false;
if (!use_w4_groupwise)
if (!use_wfp4a16)
{
// WAR: fusedBuildExpertMapsSortFirstToken kernel will lead to illegal memory access for W4AFP8
fused_prologue_result = fusedBuildExpertMapsSortFirstToken(token_selected_experts,
@ -3659,6 +3820,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
// Only NVFP4xNVFP4 supports FC1 per-expert act scale
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc1.use_per_expert_act_scale : false;
T* gemm1_input_expand = use_w4afp8 ? reinterpret_cast<T*>(smoothed_act_) : reinterpret_cast<T*>(permuted_data_);
// Expand input and maybe apply prequant scale for AWQ
expandInputRowsKernelLauncher(input_activations, gemm1_input_expand, token_topk_unpermuted_scales,
permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token,
num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_,
@ -3695,15 +3857,22 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
if constexpr (!use_w4afp8)
{
gemm1_input = applyPrequantScale(smoothed_act_, permuted_data_, quant_params.groupwise.fc1.act_scales,
num_valid_tokens_ptr, expanded_num_rows, hidden_size, use_awq, stream);
num_valid_tokens_ptr, expanded_num_rows, hidden_size, use_awq, stream, quant_params);
}
sync_check_cuda_error(stream);
Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, fc1_result_, glu_inter_result_,
// Opportunistically apply the FC2 prequant scaling in the FC1 doActivation kernel when applicable
bool const fuse_fc2_prequant_scale = use_awq && is_gated_activation;
void const* fc2_prequant_scale_ptr = fuse_fc2_prequant_scale ? quant_params.groupwise.fc2.act_scales : nullptr;
// The FC2 activation buffer must match the TMA descriptor bound in setupTmaWarpSpecializedInputs()
T* gemm1_output = fuse_fc2_prequant_scale ? reinterpret_cast<T*>(smoothed_act_) : fc1_result_;
Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, gemm1_output, glu_inter_result_,
expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr,
fc1_int_scales, fc1_fp8_dequant, use_wfp4afp8 ? fc2_wfp4afp8_quant_scale : fc2_fp8_quant,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows,
expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, fc1_activation_type,
alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_, false, nullptr, nullptr);
alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_, false, nullptr, nullptr,
fc2_prequant_scale_ptr);
sync_check_cuda_error(stream);
if (use_lora)
@ -3713,10 +3882,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
sync_check_cuda_error(stream);
}
auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
num_experts_per_node);
sync_check_cuda_error(stream);
// When fusing, data is already in smoothed_act_; otherwise run applyPrequantScale to get it there
T const* gemm2_input{reinterpret_cast<T const*>(smoothed_act_)};
if (!fuse_fc2_prequant_scale)
{
// Writes the prequant-scaled activations into smoothed_act_
gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, quant_params,
expert_first_token_offset_, num_experts_per_node);
sync_check_cuda_error(stream);
}
Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
@ -3755,10 +3930,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
}
auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale
: use_w4afp8 ? quant_params.groupwise.fc1.alpha
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale
: use_fp8 ? fp8_dequant1
: nullptr;
auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale
: use_w4afp8 ? quant_params.groupwise.fc2.alpha
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.global_scale
: use_fp8 ? fp8_dequant2
: nullptr;
@ -3834,7 +4011,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
return std::make_pair(gemm1_tma_ws_input, gemm2_tma_ws_input);
}
bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
bool const use_awq = useAwq(quant_params);
bool is_gated_activation = isGatedActivation(fc1_activation_type);
int64_t const fc1_out_size = is_gated_activation ? inter_size * 2 : inter_size;
@ -3843,7 +4020,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
bool const has_intermediate = has_different_gemm_output_type || is_gated_activation;
auto* gemm1_output = has_intermediate ? glu_inter_result_ : static_cast<void*>(fc1_result_);
bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
bool const use_prequant_scale_kernel = usePrequantScaleKernel(quant_params);
auto gemm2_input = use_prequant_scale_kernel ? smoothed_act_ : fc1_result_;
if (min_latency_mode)
@ -3880,8 +4057,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool using_fused_finalize
= use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora;
bool using_fused_finalize = use_fused_finalize_ && gemm2_using_finalize_fusion && !use_wfp4a16 && !use_lora;
TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion,
"GEMM2 tactic requests finalize fusion, but the runner is not configured to use it");
if (using_fused_finalize)
@ -4120,7 +4296,7 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
size_t output_size1 = inter_size * num_expanded_tokens * dtype_bytes;
size_t input_size2 = inter_size * num_expanded_tokens * dtype_bytes;
size_t output_size2 = hidden_size * output_bytes;
size_t output_size2 = hidden_size * num_expanded_tokens * output_bytes;
size_t input_size = mGemmToProfile == GemmToProfile::GEMM_1 ? input_size1 : input_size2;
size_t output_size = mGemmToProfile == GemmToProfile::GEMM_1 ? output_size1 : output_size2;
@ -4439,14 +4615,14 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
&& mWType == nvinfer1::DataType::kUINT8);
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
|| mGemmToProfile != GemmToProfile::GEMM_2;
bool const finalize_fusion_not_supported
= !mInterface->use_fused_finalize_ || mMinLatencyMode || use_wfp4a16 || mGemmToProfile != GemmToProfile::GEMM_2;
if (use_finalize_fusion && finalize_fusion_not_supported)
{
return;
}
if (use_w4_groupwise && !swap_ab)
if (use_wfp4a16 && !swap_ab)
{
return;
}
@ -4535,12 +4711,15 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
}
else
{
auto fc1_alpha = use_w4afp8 ? mQuantParams.groupwise.fc1.alpha : mQuantParams.fp8.dequant_fc1;
auto fc2_alpha = use_w4afp8 ? mQuantParams.groupwise.fc2.alpha : mQuantParams.fp8.dequant_fc2;
std::tie(gemm1_tma_ws_input, gemm2_tma_ws_input) = mInterface->computeStridesTmaWarpSpecializedDispatch(
expert_first_token_offset, gemm1_tma_ws_input, gemm2_tma_ws_input, num_tokens, num_tokens * mK,
fc1_output_size, mExpertHiddenSize, mExpertHiddenSize, mExpertInterSize, mNumExpertsPerNode, input,
input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2,
fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate,
token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, stream);
input, weights_sel, weights_sel, fc1_alpha, fc2_alpha, fp4_act_scale_flat, fp4_act_scale_flat,
mQuantParams, nullptr, nullptr, intermediate, intermediate, token_topk_unpermuted_scales,
permuted_row_to_unpermuted_row, stream);
}
sync_check_cuda_error(stream);
}

View File

@ -41,9 +41,9 @@ EpiTag = {
EpiFusion = {
TrtLlm_EpilogueFusion.epilogue_fusion_none:
"tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE",
"tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE",
TrtLlm_EpilogueFusion.epilogue_fusion_finalize:
"tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE",
"tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE",
}
EpiFusionSuffixes = {
@ -229,9 +229,11 @@ const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zer
or operation.weight_type != e2m1):
# Mixed MoE GEMM
weight_tag = CudaTypeName[operation.weight_type]
assert operation.epi_fusion is not None
epi_fusion = EpiFusion[operation.epi_fusion]
instantiation = f"""
template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
{epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
"""
else:
@ -590,6 +592,11 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]
M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM
N_TILES = [16, 32, 64, 128]
K_TILES = [128, 256, 512]
@ -607,13 +614,13 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
cga_shapes = list(product([1, 2], [1, 2], [1]))
partial_args_int4 = product(supported_dtypes_int4, quant_ops, epi_tags,
cta_shapes_mnk_int4, cga_shapes)
epi_fusions, cta_shapes_mnk_int4, cga_shapes)
partial_args_fp4 = product(supported_dtypes_fp4, quant_ops, epi_tags,
cta_shapes_mnk_fp4, cga_shapes)
epi_fusions, cta_shapes_mnk_fp4, cga_shapes)
partial_args = chain(partial_args_int4, partial_args_fp4)
operations = list()
for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args:
for dtype_combo, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
use_coop = cta_shape_mnk[0] >= 128
mainloop_schedules = [
KernelScheduleType.TmaWarpSpecializedCooperative,
@ -626,7 +633,7 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
== KernelScheduleType.TmaWarpSpecializedCooperative):
continue
moe_gemm_operation = TrtLlm_GemmLauncher(GemmKind.Grouped, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
operations.append(moe_gemm_operation)
return operations

View File

@ -1124,6 +1124,9 @@ private:
auto& fc1_alpha = quant_scales.value()[6];
auto& fc2_alpha = quant_scales.value()[7];
int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
// Whether per-expert activation scales are provided
bool fc1_use_per_expert_act_scale = fc1_act_scales.numel() > hidden_size;
bool fc2_use_per_expert_act_scale = fc2_act_scales.numel() > inter_size;
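// E.g. (shapes assumed for illustration): an fc1_act_scales tensor of shape [num_experts, hidden_size]
// has numel > hidden_size, so per-expert scales are enabled; a shared [1, hidden_size] tensor does not.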
return kernels::QuantParams::GroupWise(group_size,
static_cast<void const*>(fc1_weight_scales.data_ptr()),
static_cast<void const*>(fc2_weight_scales.data_ptr()),
@ -1132,7 +1135,8 @@ private:
static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr),
fc1_use_per_expert_act_scale, fc2_use_per_expert_act_scale);
}
else
{

View File

@ -168,7 +168,10 @@ protected:
constexpr static bool ANY_FP4 = WEIGHT_FP4 || ACT_FP4;
constexpr static bool ANY_FPX = ANY_FP4 || FP8;
constexpr static bool INT_QUANT = !std::is_same_v<GemmDataType, WeightType> && std::is_integral_v<WeightType>;
constexpr static bool W4A8_AWQ
= std::is_same_v<GemmDataType, SafeFP8> && std::is_same_v<WeightType, cutlass::uint4b_t>;
constexpr static bool INT_QUANT
= !std::is_same_v<GemmDataType, WeightType> && std::is_integral_v<WeightType> && !W4A8_AWQ;
constexpr static int64_t WEIGHT_ELEM_PER_BYTE = (INT4 || WEIGHT_FP4) ? 2 : 1;
using InputType = std::conditional_t<NVFP4 || MXFP8_MXFP4, OutputType, GemmDataType>;
using WeightStorage = std::conditional_t<WEIGHT_ELEM_PER_BYTE == 2, uint8_t, WeightType>;
@ -184,7 +187,7 @@ protected:
using DataType = std::conditional_t<NVFP4 || MXFP8_MXFP4, OutputType, GemmDataType>;
// FP8_MXFP4 quantizes just the weights on the fly
using WeightRawType = std::conditional_t<FP8_MXFP4, OutputType, DataType>;
using WeightRawType = std::conditional_t<FP8_MXFP4 || W4A8_AWQ, OutputType, DataType>;
static BufferManager::CudaStreamPtr mStream;
static std::unique_ptr<BufferManager> mBufferManager;
@ -221,7 +224,7 @@ protected:
static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
#endif
bool should_skip_no_device = mDeviceCount <= 0;
bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
bool should_skip_unsupported_fp8 = getSMVersion() < 89 && (FP8 || W4A8_AWQ);
bool should_skip_unsupported_fp4 = (getSMVersion() < 100) && ANY_FP4;
return should_skip_no_device || should_skip_unsupported_fp8 || should_skip_unsupported_fp4;
}
@ -321,8 +324,9 @@ protected:
float* mSwigluBeta{};
float* mSwigluLimit{};
DataType* mExpertIntScale1{};
DataType* mExpertIntScale2{};
using scale_type = std::conditional_t<W4A8_AWQ, WeightScale, DataType>;
scale_type* mExpertIntScale1{};
scale_type* mExpertIntScale2{};
float mFP8WeightScalar1{1.f};
float mFP8WeightScalar2{1.f};
@ -375,7 +379,7 @@ protected:
bool mIsGated = false;
int64_t mGatedMultiplier = 1;
int64_t mGroupSize = -1;
int64_t mGroupSize = W4A8_AWQ ? 128 : -1;
ActivationType mActType = ActivationType::Relu;
@ -453,7 +457,7 @@ protected:
total_size += weight_size / 2;
}
// Quantized data types use a second scratch buffer for the weights before quantizing
if (ANY_FPX || INT_QUANT)
if (ANY_FPX || INT_QUANT || W4A8_AWQ)
{
total_size += weight_elems * sizeof(DataType);
}
@ -535,6 +539,14 @@ protected:
mExpertIntScale1 = allocBuffer<DataType>(mNumExperts * gated_inter);
mExpertIntScale2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
}
else if constexpr (W4A8_AWQ)
{
mExpertWeight1 = allocBuffer<WeightStorage>(expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE);
mExpertWeight2 = allocBuffer<WeightStorage>(expert_matrix_size / WEIGHT_ELEM_PER_BYTE);
mExpertIntScale1 = allocBuffer<WeightScale>(expert_matrix_size * mGatedMultiplier / mGroupSize);
mExpertIntScale2 = allocBuffer<WeightScale>(expert_matrix_size / mGroupSize);
}
else if constexpr (ANY_FP4)
{
mExpertWeight1 = allocBuffer<WeightStorage>(
@ -638,12 +650,22 @@ protected:
doIntQuant(quant_type, shape1, mRawExpertWeight1, mExpertIntScale1, mExpertWeight1);
doIntQuant(quant_type, shape2, mRawExpertWeight2, mExpertIntScale2, mExpertWeight2);
}
else if constexpr (W4A8_AWQ)
{
cutlass_kernels::QuantType quant_type = cutlass_kernels::QuantType::W4_AFP8;
std::vector<size_t> shape1{(size_t) mNumExperts, (size_t) mHiddenSize, (size_t) gated_inter};
std::vector<size_t> shape2{(size_t) mNumExperts, (size_t) mInterSize, (size_t) mHiddenSize};
doIntQuant(quant_type, shape1, mRawExpertWeight1, mExpertIntScale1, mExpertWeight1);
doIntQuant(quant_type, shape2, mRawExpertWeight2, mExpertIntScale2, mExpertWeight2);
}
check_cuda_error(cudaStreamSynchronize(stream));
}
void doIntQuant(cutlass_kernels::QuantType quant_type, std::vector<size_t> shape, DataType* inputs,
DataType* scales, uint8_t* outputs)
void doIntQuant(cutlass_kernels::QuantType quant_type, std::vector<size_t> shape, WeightRawType* inputs,
scale_type* scales, uint8_t* outputs)
{
// Runs on the CPU, must be after stream sync
if constexpr (INT_QUANT)
@ -652,17 +674,119 @@ protected:
size_t elems = std::reduce(shape.begin(), shape.end(), 1, std::multiplies{});
std::vector<int8_t> h_out(elems);
std::vector<DataType> h_input(elems);
std::vector<DataType> h_scales(shape[0] * shape[2]);
std::vector<WeightRawType> h_input(elems);
std::vector<scale_type> h_scales(shape[0] * shape[2]);
check_cuda_error(cudaMemcpy(h_input.data(), inputs, elems * sizeof(DataType), cudaMemcpyDeviceToHost));
check_cuda_error(cudaMemcpy(h_input.data(), inputs, elems * sizeof(WeightRawType), cudaMemcpyDeviceToHost));
cutlass_kernels::symmetric_quantize(h_out.data(), h_scales.data(), h_input.data(), shape, quant_type, true);
check_cuda_error(cudaMemcpy(
outputs, h_out.data(), elems * sizeof(int8_t) / WEIGHT_ELEM_PER_BYTE, cudaMemcpyHostToDevice));
check_cuda_error(
cudaMemcpy(scales, h_scales.data(), h_scales.size() * sizeof(DataType), cudaMemcpyHostToDevice));
cudaMemcpy(scales, h_scales.data(), h_scales.size() * sizeof(scale_type), cudaMemcpyHostToDevice));
}
else if constexpr (W4A8_AWQ)
{
check_cuda_error(cudaStreamSynchronize(mStream->get()));
assert(shape[1] % mGroupSize == 0);
size_t elems = std::reduce(shape.begin(), shape.end(), 1, std::multiplies{});
std::vector<int8_t> h_out(elems * sizeof(int8_t) / WEIGHT_ELEM_PER_BYTE);
std::vector<WeightRawType> h_input(elems);
std::vector<scale_type> h_scales(elems / mGroupSize);
check_cuda_error(cudaMemcpy(h_input.data(), inputs, elems * sizeof(WeightRawType), cudaMemcpyDeviceToHost));
const size_t num_experts = shape[0];
int const input_mat_size = shape[1] * shape[2];
int const bits_per_weight_element = 4;
int const quantized_mat_size = input_mat_size * bits_per_weight_element / 8;
float const quant_range_scale = 1.f / float(1 << (bits_per_weight_element - 1));
for (int expert = 0; expert < num_experts; ++expert)
{
WeightRawType const* current_weight = h_input.data() + expert * input_mat_size;
int8_t* current_quantized_weight = h_out.data() + expert * quantized_mat_size;
scale_type* current_scales = h_scales.data() + expert * input_mat_size / mGroupSize;
for (int ii = 0; ii < input_mat_size / mGroupSize; ++ii)
{
float scale = 0.f;
WeightRawType const* current_weight_group = current_weight + ii * mGroupSize;
for (int jj = 0; jj < mGroupSize; ++jj)
{
scale = std::max(scale, std::abs(float(current_weight_group[jj])));
}
scale *= quant_range_scale;
current_scales[ii] = scale_type(scale);
}
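// Worked example (values assumed): a group whose largest magnitude is 6.0 gets scale = 6.0 / 8 = 0.75,
// so its weights quantize below to round(w / 0.75) clamped to [-8, 7].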
for (int ii = 0; ii < input_mat_size / mGroupSize; ++ii)
{
WeightRawType const* current_weight_group = current_weight + ii * mGroupSize;
int8_t* current_quantized_weight_group
= current_quantized_weight + ii * mGroupSize * bits_per_weight_element / 8;
float const scale = float(current_scales[ii]);
for (int jj = 0; jj < mGroupSize / 2; ++jj)
{
// We will pack two int4 elements per iteration of the inner loop.
float const weight_elt0 = float(current_weight_group[jj * 2]);
float const weight_elt1 = float(current_weight_group[jj * 2 + 1]);
float const scaled_weight0 = (scale != 0.0f) ? round(weight_elt0 / scale) : 0.0f;
float const scaled_weight1 = (scale != 0.0f) ? round(weight_elt1 / scale) : 0.0f;
int int_weight0 = int(scaled_weight0);
int int_weight1 = int(scaled_weight1);
const int8_t clipped_weight0 = std::max(-8, std::min(7, int_weight0));
const int8_t clipped_weight1 = std::max(-8, std::min(7, int_weight1));
// Mask off the sign-extension bits of the first int4 (hence the 0x0F mask), shift the second int4
// into the upper nibble, and OR the two together into the packed byte.
current_quantized_weight_group[jj] = (clipped_weight0 & 0x0F) | (clipped_weight1 << 4);
}
}
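// Packing example (values assumed): clipped weights (3, -2) become the byte 0xE3: low nibble 0x3,
// high nibble 0xE (-2 in 4-bit two's complement).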
// WAR: For a diagonal matrix, each column has only one nonzero element, and its group's scale is
// computed as though that element quantizes to 8. Quantization then clips it to 7, so scale by 8/7
// to compensate for the clipping error.
for (int ii = 0; ii < input_mat_size / mGroupSize; ++ii)
{
current_scales[ii] = scale_type(float(current_scales[ii]) * 8 / 7);
}
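// Worked example for the WAR (values assumed): a lone element v has scale |v| / 8; round(v / scale) = 8
// is clipped to 7, and rescaling the group scale by 8 / 7 makes 7 * scale dequantize back to |v|.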
int interleave = 1;
int const sm = getSMVersion();
if (sm == 90)
{
interleave = shape[1] % 512 == 0 ? 4 : shape[1] % 256 == 0 ? 2 : 1;
}
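// Illustrative interleave values (assumed): on SM90, K = 7168 gives 4 (divisible by 512), K = 4352
// gives 2 (divisible by 256 but not 512), and any K not divisible by 256 keeps 1.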
// Permute scales: from [N, K/mGroupSize/interleave, interleave] to [K/mGroupSize/interleave, N,
// interleave]
int const dim0 = shape[2]; // N
int const dim1 = shape[1] / mGroupSize / interleave; // K/mGroupSize/interleave
int const dim2 = interleave;
std::vector<scale_type> temp_scales(input_mat_size / mGroupSize);
for (int n = 0; n < dim0; ++n)
{
for (int k = 0; k < dim1; ++k)
{
for (int i = 0; i < dim2; ++i)
{
// src index: [n, k, i] in layout [N, K/mGroupSize/interleave, interleave]
int src_idx = n * (dim1 * dim2) + k * dim2 + i;
// dst index: [k, n, i] in layout [K/mGroupSize/interleave, N, interleave]
int dst_idx = k * (dim0 * dim2) + n * dim2 + i;
temp_scales[dst_idx] = current_scales[src_idx];
}
}
}
std::copy(temp_scales.begin(), temp_scales.end(), current_scales);
}
check_cuda_error(cudaMemcpy(
outputs, h_out.data(), elems * sizeof(int8_t) / WEIGHT_ELEM_PER_BYTE, cudaMemcpyHostToDevice));
check_cuda_error(
cudaMemcpy(scales, h_scales.data(), h_scales.size() * sizeof(scale_type), cudaMemcpyHostToDevice));
}
}
@ -1229,6 +1353,31 @@ protected:
ASSERT_TRUE(scale1_ptr && scale2_ptr);
quant_params = QuantParams::Int(scale1_ptr, scale2_ptr);
}
else if constexpr (W4A8_AWQ)
{
auto input_scale1 = allocBuffer<scale_type>(mNumExperts * mHiddenSize * mGatedMultiplier);
auto input_scale2 = allocBuffer<scale_type>(mNumExperts * mInterSize);
std::vector<scale_type> h_input_scale1(mNumExperts * mHiddenSize * mGatedMultiplier, 1.0f);
std::vector<scale_type> h_input_scale2(mNumExperts * mInterSize, 1.0f);
check_cuda_error(cudaMemcpy(input_scale1, h_input_scale1.data(),
mNumExperts * mHiddenSize * mGatedMultiplier * sizeof(scale_type), cudaMemcpyHostToDevice));
check_cuda_error(cudaMemcpy(input_scale2, h_input_scale2.data(),
mNumExperts * mInterSize * sizeof(scale_type), cudaMemcpyHostToDevice));
auto alpha1_ptrs = allocBuffer<float>(mNumExperts);
auto alpha2_ptrs = allocBuffer<float>(mNumExperts);
for (int i = 0; i < mNumExperts; i++)
{
float alpha1_value = 1.0f;
float alpha2_value = 1.0f;
check_cuda_error(cudaMemcpy(alpha1_ptrs + i, &alpha1_value, sizeof(float), cudaMemcpyHostToDevice));
check_cuda_error(cudaMemcpy(alpha2_ptrs + i, &alpha2_value, sizeof(float), cudaMemcpyHostToDevice));
}
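// Note on the test setup: with unit prequant scales and unit alphas the AWQ smoothing and per-expert
// dequant are numerically no-ops, so the existing reference computation remains valid.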
ASSERT_TRUE(scale1_ptr && scale2_ptr);
quant_params = QuantParams::GroupWise(mGroupSize, scale1_ptr, scale2_ptr, input_scale1, input_scale2,
nullptr, nullptr, alpha1_ptrs, alpha2_ptrs);
}
else if (FP8)
{
ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr);
@ -1628,6 +1777,11 @@ using Types = ::testing::Types<
#endif
#endif
#ifdef ENABLE_BF16
#ifdef ENABLE_FP8
WeightParams<SafeFP8, cutlass::uint4b_t, __nv_bfloat16, void, __nv_bfloat16>,
#endif
#endif
WeightParams<half>, WeightParams<float>
// , WeightParams<half, uint8_t>, WeightParams<half, cutlass::uint4b_t>
@ -1675,6 +1829,12 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
initLocals(hidden_size, num_experts, k, num_tokens);
if (mGroupSize > 0 && (mHiddenSize % mGroupSize != 0 || mInterSize % mGroupSize != 0))
{
GTEST_SKIP() << "Skipping due to unsupported groupwise configuration";
return;
}
auto test_archs = getAllTileConfigsToTest();
for (auto [gemm1, gemm2] : test_archs)
{
@ -1903,6 +2063,13 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteDeepSeekV3)
size_t inter_size = 2048;
this->mInterSizeFraction = float(inter_size) / hidden_size;
if (this->W4A8_AWQ)
{
// TODO: Implement W4A8_AWQ for PermuteDeepSeekV3
GTEST_SKIP() << "W4A8_AWQ is not implemented for PermuteDeepSeekV3";
return;
}
if (!this->checkSufficientTestMemory(100, hidden_size, 256, 8))
{
GTEST_SKIP() << "Insufficient free memory for test";
@ -1956,6 +2123,13 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
}
}
if (W4A8_AWQ)
{
// TODO: Implement W4A8_AWQ for ParallelismTest
GTEST_SKIP() << "W4A8_AWQ is not implemented for ParallelismTest";
return;
}
ASSERT_LE(ep_size, num_experts);
if (tp_size == 1)
{