mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5726962][feat] Apply fusion for W4AFP8_AWQ MoE (#9838)
Signed-off-by: Min Yu <171526537+yumin066@users.noreply.github.com> Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com> Co-authored-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
This commit is contained in:
parent
6b8ae6fa81
commit
9cae7277ea
@ -315,7 +315,8 @@ struct QuantParams
|
||||
{
|
||||
struct GroupwiseGemmInputs
|
||||
{
|
||||
void const* act_scales = nullptr;
|
||||
bool use_per_expert_act_scale = false;
|
||||
void const* act_scales = nullptr; // (1 or num_experts_per_node, hidden_size or intermediate_size)
|
||||
void const* weight_scales = nullptr;
|
||||
void const* weight_zeros = nullptr;
|
||||
float const* alpha = nullptr;
|
||||
@ -401,12 +402,15 @@ struct QuantParams
|
||||
static QuantParams GroupWise(int group_size, void const* fc1_weight_scales, void const* fc2_weight_scales,
|
||||
void const* fc1_activation_scales = nullptr, void const* fc2_activation_scales = nullptr,
|
||||
void const* fc1_weight_zeros = nullptr, void const* fc2_weight_zeros = nullptr,
|
||||
float const* fc1_alpha = nullptr, float const* fc2_alpha = nullptr)
|
||||
float const* fc1_alpha = nullptr, float const* fc2_alpha = nullptr, bool fc1_use_per_expert_act_scale = false,
|
||||
bool fc2_use_per_expert_act_scale = false)
|
||||
{
|
||||
QuantParams qp;
|
||||
qp.groupwise.group_size = group_size;
|
||||
qp.groupwise.fc1 = {fc1_activation_scales, fc1_weight_scales, fc1_weight_zeros, fc1_alpha};
|
||||
qp.groupwise.fc2 = {fc2_activation_scales, fc2_weight_scales, fc2_weight_zeros, fc2_alpha};
|
||||
qp.groupwise.fc1
|
||||
= {fc1_use_per_expert_act_scale, fc1_activation_scales, fc1_weight_scales, fc1_weight_zeros, fc1_alpha};
|
||||
qp.groupwise.fc2
|
||||
= {fc2_use_per_expert_act_scale, fc2_activation_scales, fc2_weight_scales, fc2_weight_zeros, fc2_alpha};
|
||||
return qp;
|
||||
}
|
||||
|
||||
@ -646,7 +650,7 @@ public:
|
||||
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
|
||||
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
|
||||
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
|
||||
int* num_active_experts_per, int* active_expert_global_ids);
|
||||
int* num_active_experts_per, int* active_expert_global_ids, void const* fc2_prequant_scale = nullptr);
|
||||
|
||||
static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
|
||||
DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
|
||||
@ -803,6 +807,16 @@ private:
|
||||
bool min_latency_mode, bool use_awq);
|
||||
|
||||
private:
|
||||
static bool useAwq(cutlass_kernels::QuantParams const& quant_params)
|
||||
{
|
||||
return quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
|
||||
}
|
||||
|
||||
static bool usePrequantScaleKernel(cutlass_kernels::QuantParams const& quant_params)
|
||||
{
|
||||
return useAwq(quant_params) && !std::is_same_v<T, WeightType>;
|
||||
}
|
||||
|
||||
bool mayHaveDifferentGEMMOutputType() const
|
||||
{
|
||||
// We just check if its supported because we need to know when calculating workspace size
|
||||
@ -813,13 +827,13 @@ private:
|
||||
bool mayHaveFinalizeFused() const
|
||||
{
|
||||
return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
|
||||
&& !use_w4_groupwise;
|
||||
&& !use_wfp4a16;
|
||||
}
|
||||
|
||||
static bool mayHaveFinalizeFused(int sm)
|
||||
{
|
||||
using RunnerType = decltype(moe_gemm_runner_);
|
||||
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise;
|
||||
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_wfp4a16;
|
||||
}
|
||||
|
||||
// TODO: This should eventually take the quant params to give more flexibility
|
||||
@ -866,7 +880,8 @@ private:
|
||||
|
||||
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
|
||||
cudaStream_t stream, int64_t* expert_first_token_offset = nullptr, int const num_experts_per_node = 0);
|
||||
cudaStream_t stream, QuantParams const& quant_params, int64_t* expert_first_token_offset = nullptr,
|
||||
int const num_experts_per_node = 0);
|
||||
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
|
||||
std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
|
||||
|
||||
@ -28,8 +28,9 @@ namespace cutlass_kernels_oss
|
||||
{
|
||||
using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
|
||||
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
|
||||
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp>
|
||||
void sm90_generic_mixed_moe_gemm_kernelLauncher(
|
||||
tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
|
||||
|
||||
@ -45,6 +45,7 @@
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
@ -71,11 +72,12 @@ namespace cutlass_kernels_oss
|
||||
using namespace tensorrt_llm::kernels::cutlass_kernels;
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
|
||||
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp>
|
||||
void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
|
||||
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
|
||||
@ -85,6 +87,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// GEMM kernel configurations
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE,
|
||||
"Unimplemented fusion provided to TMA WS Mixed MoE gemm launcher");
|
||||
constexpr static bool IsFinalizeFusion = FUSION == EpilogueFusion::FINALIZE;
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = typename TllmToCutlassTypeAdapter<T>::type;
|
||||
@ -129,6 +134,9 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
using ElementFinalOutput = typename TllmToCutlassTypeAdapter<GemmOutputType>::type;
|
||||
using ElementBias = ElementFinalOutput;
|
||||
using ElementRouterScales = float;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
@ -136,6 +144,11 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = CTAShape; // Threadblock-level tile size
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
|
||||
|
||||
using EpilogueFusionOp = cutlass::epilogue::fusion::ScaledAccPerRowBiasPerColScaleScatter<
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type, ElementFinalOutput, ElementAccumulator, ElementBias,
|
||||
ElementRouterScales>;
|
||||
|
||||
using KernelSchedule
|
||||
= std::conditional_t<std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecializedPingpong>,
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong,
|
||||
@ -145,12 +158,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong,
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative>; // Epilogue to launch
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
|
||||
using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
|
||||
cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
|
||||
AlignmentC, void, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD, EpilogueSchedule,
|
||||
EpilogueFusionOp>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder<cutlass::arch::Sm90,
|
||||
cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type*,
|
||||
AlignmentC, ElementD, typename cutlass::layout::LayoutTranspose<LayoutD>::type*, AlignmentD,
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveEpilogue
|
||||
= std::conditional_t<IsFinalizeFusion, CollectiveEpilogueFinalize, CollectiveEpilogueDefault>;
|
||||
|
||||
// =========================================================== MIXED INPUT WITH SCALES
|
||||
// =========================================================================== The Scale information must get paired
|
||||
// with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the
|
||||
@ -175,20 +197,56 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
|
||||
Args arguments;
|
||||
|
||||
decltype(arguments.epilogue.thread) fusion_args;
|
||||
fusion_args.alpha = use_wfp4a16 ? 1 : 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count = sm_count_;
|
||||
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueScalars = decltype(EpilogueArguments{}.thread);
|
||||
EpilogueScalars epilogue_scalars = [&]
|
||||
{
|
||||
if constexpr (IsFinalizeFusion)
|
||||
{
|
||||
auto epi_params = hopper_inputs.fused_finalize_epilogue;
|
||||
return EpilogueScalars{ElementAccumulator(1), nullptr, hopper_inputs.alpha_scale_ptr_array,
|
||||
Stride<_0, _0, int64_t>{cute::_0{}, cute::_0{}, 1}, /* alpha */
|
||||
reinterpret_cast<ElementBias const* const*>(epi_params.ptr_bias), Stride<_1, _0, int64_t>{}, /* bias */
|
||||
epi_params.ptr_router_scales, Stride<_0, _1, int64_t>{}, /* scale */
|
||||
reinterpret_cast<ElementFinalOutput*>(epi_params.ptr_final_output),
|
||||
epi_params.stride_final_output_transposed, epi_params.ptr_source_token_index,
|
||||
epi_params.num_rows_in_final_output, epi_params.shape_override, epi_params.use_reduction};
|
||||
}
|
||||
else
|
||||
{
|
||||
return EpilogueScalars{};
|
||||
}
|
||||
}();
|
||||
|
||||
EpilogueArguments epilogue_args = [&]
|
||||
{
|
||||
if constexpr (IsFinalizeFusion)
|
||||
{
|
||||
return EpilogueArguments{epilogue_scalars, nullptr, nullptr, nullptr, nullptr};
|
||||
}
|
||||
else
|
||||
{
|
||||
fusion_args.alpha = use_wfp4a16 ? 1 : 0;
|
||||
fusion_args.beta = 0;
|
||||
fusion_args.alpha_ptr = nullptr;
|
||||
fusion_args.beta_ptr = nullptr;
|
||||
fusion_args.alpha_ptr_array = use_wfp4a16 ? nullptr : inputs.alpha_scales;
|
||||
fusion_args.beta_ptr_array = nullptr;
|
||||
// One alpha and beta per each group
|
||||
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
|
||||
fusion_args.dBeta = {cute::_0{}, cute::_0{}, use_wfp4a16 ? 0 : 1};
|
||||
|
||||
return EpilogueArguments{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
|
||||
reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
|
||||
reinterpret_cast<StrideD*>(hopper_inputs.stride_d)};
|
||||
}
|
||||
}();
|
||||
|
||||
arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{inputs.num_experts, hopper_inputs.int4_groupwise_params.shape.problem_shapes, nullptr},
|
||||
{reinterpret_cast<ElementB const**>(hopper_inputs.ptr_weight),
|
||||
@ -197,10 +255,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput<T, WeightType,
|
||||
reinterpret_cast<StrideA*>(hopper_inputs.stride_act),
|
||||
reinterpret_cast<ElementScalePacked const**>(hopper_inputs.int4_groupwise_params.ptr_s_a),
|
||||
reinterpret_cast<StrideS*>(hopper_inputs.int4_groupwise_params.stride_s_a), group_size},
|
||||
{fusion_args, reinterpret_cast<ElementC const**>(hopper_inputs.ptr_c),
|
||||
reinterpret_cast<StrideC*>(hopper_inputs.stride_c), reinterpret_cast<ElementD**>(hopper_inputs.ptr_d),
|
||||
reinterpret_cast<StrideD*>(hopper_inputs.stride_d)},
|
||||
hw_info};
|
||||
epilogue_args, hw_info};
|
||||
|
||||
assert(group_size == int(inputs.groupwise_quant_group_size));
|
||||
if (workspace_size != nullptr)
|
||||
|
||||
@ -792,25 +792,37 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
inputs.gemm_config.is_tma_warp_specialized, "w4afp8 is only supported for TMA warp specialization");
|
||||
// EpilogueTag is ignored
|
||||
#define SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(SCALE_FACTOR) \
|
||||
if (hopper_inputs.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE) \
|
||||
{ \
|
||||
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, \
|
||||
cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, \
|
||||
SCALE_FACTOR>(inputs, hopper_inputs, multi_processor_count_, nullptr); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType, \
|
||||
cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, \
|
||||
SCALE_FACTOR>(inputs, hopper_inputs, multi_processor_count_, nullptr); \
|
||||
}
|
||||
|
||||
if (inputs.k % 512 == 0)
|
||||
{
|
||||
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
|
||||
cutlass_extensions::EpilogueOpDefault, 4>(inputs, hopper_inputs, multi_processor_count_, nullptr);
|
||||
SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(4)
|
||||
}
|
||||
else if (inputs.k % 256 == 0)
|
||||
{
|
||||
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
|
||||
cutlass_extensions::EpilogueOpDefault, 2>(inputs, hopper_inputs, multi_processor_count_, nullptr);
|
||||
SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(2)
|
||||
}
|
||||
else if (inputs.k % 128 == 0)
|
||||
{
|
||||
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
|
||||
cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr);
|
||||
SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE(1)
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k);
|
||||
}
|
||||
#undef SM90_DISPATCH_MOE_MIXED_GEMM_TO_CUTLASS_SELECT_FINALIZE
|
||||
return;
|
||||
}
|
||||
|
||||
@ -820,7 +832,8 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
|
||||
inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization");
|
||||
// EpilogueTag is ignored
|
||||
cutlass_kernels_oss::sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass<T, WeightType, ScaleBiasType,
|
||||
cutlass_extensions::EpilogueOpDefault, 1>(inputs, hopper_inputs, multi_processor_count_, nullptr);
|
||||
cutlass_extensions::EpilogueOpDefault, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, 1>(
|
||||
inputs, hopper_inputs, multi_processor_count_, nullptr);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/detail/collective/mixed_input_utils.hpp"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
||||
#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
|
||||
@ -67,11 +68,12 @@ using tensorrt_llm::kernels::cutlass_kernels::GroupedGemmInput;
|
||||
using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape,
|
||||
typename ClusterShape>
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename CTAShape, typename ClusterShape>
|
||||
void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
|
||||
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
|
||||
{
|
||||
@ -88,7 +90,7 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
|
||||
{
|
||||
if constexpr ((get<0>(CTAShape{}) == 128) && get<1>(CTAShape{}) == 128)
|
||||
{
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, CTAShape,
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
ClusterShape, cutlass::gemm::KernelTmaWarpSpecializedPingpong,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
|
||||
@ -96,7 +98,7 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
|
||||
}
|
||||
else
|
||||
{
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, CTAShape,
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
ClusterShape, cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
|
||||
@ -106,9 +108,10 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
|
||||
|
||||
break;
|
||||
case tkc::MainloopScheduleType::PINGPONG:
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, ClusterShape,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
ClusterShape, cutlass::gemm::KernelTmaWarpSpecializedPingpong,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
default:
|
||||
TLLM_THROW(
|
||||
@ -122,7 +125,8 @@ void sm90_dispatch_mainloop_schedules(GroupedGemmInput<T, WeightType, GemmOutput
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, typename CTAShape>
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
typename CTAShape>
|
||||
void sm90_dispatch_moe_mixed_dtype_gemm_config(GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
|
||||
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
|
||||
{
|
||||
@ -130,26 +134,27 @@ void sm90_dispatch_moe_mixed_dtype_gemm_config(GroupedGemmInput<T, WeightType, G
|
||||
switch (inputs.gemm_config.cluster_shape)
|
||||
{
|
||||
case tkc::ClusterShape::ClusterShape_1x1x1:
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_1, _1, _1>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
Shape<_1, _1, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::ClusterShape::ClusterShape_2x1x1:
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_2, _1, _1>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
Shape<_2, _1, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::ClusterShape::ClusterShape_1x2x1:
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_1, _2, _1>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
Shape<_1, _2, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::ClusterShape::ClusterShape_2x2x1:
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, CTAShape, Shape<_2, _2, _1>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_mainloop_schedules<T, WeightType, GemmOutputType, EpilogueTag, FUSION, CTAShape,
|
||||
Shape<_2, _2, _1>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
default: TLLM_THROW("[Mixed dtype MoE GEMM][dispatch_CGA_config] Config is invalid for mixed type GEMM."); break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, int PackedScalesNum>
|
||||
template <typename T, typename WeightType, typename GemmOutputType, typename EpilogueTag, EpilogueFusion FUSION,
|
||||
int PackedScalesNum>
|
||||
void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
|
||||
GroupedGemmInput<T, WeightType, GemmOutputType, GemmOutputType> inputs,
|
||||
TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size)
|
||||
@ -168,49 +173,49 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(
|
||||
switch (inputs.gemm_config.tile_config_sm90)
|
||||
{
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x16x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _16, _Ktile>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_64, _16, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x32x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _32, _Ktile>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_64, _32, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x64x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _64, _Ktile>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_64, _64, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x128x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag,
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_64, _Ntile, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
// case tkc::CutlassTileConfigSM90::CtaShape64x256x128B:
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_64, _256,
|
||||
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x16x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _16, _Ktile>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_128, _16, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x32x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _32, _Ktile>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_128, _32, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x64x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _64, _Ktile>>(
|
||||
inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_128, _64, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x128x128B:
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag,
|
||||
sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION,
|
||||
Shape<_128, _128, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size);
|
||||
break;
|
||||
// case tkc::CutlassTileConfigSM90::CtaShape128x256x128B:
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _256,
|
||||
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION, Shape<_128,
|
||||
// _256, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
// case tkc::CutlassTileConfigSM90::CtaShape256x128x128B:
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_128, _256,
|
||||
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION, Shape<_128,
|
||||
// _256, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
// case tkc::CutlassTileConfigSM90::CtaShape256x256x128B:
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, Shape<_256, _256,
|
||||
// _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
// sm90_dispatch_moe_mixed_dtype_gemm_config<T, WeightType, GemmOutputType, EpilogueTag, FUSION, Shape<_256,
|
||||
// _256, _Ktile>>(inputs, hopper_inputs, sm_count_, workspace_size); break;
|
||||
case tkc::CutlassTileConfigSM90::Undefined:
|
||||
TLLM_THROW("[Mixed dtype MoE GEMM][sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
@ -236,12 +241,15 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_
|
||||
using _Ktile = Int<Ktile>;
|
||||
|
||||
#ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS
|
||||
GroupedGemmInput<T, WeightType, OutputType, OutputType> inputs{};
|
||||
constexpr bool use_wfp4a16 = std::is_same_v<WeightType, cutlass::float_e2m1_t>;
|
||||
constexpr int group_size = use_wfp4a16 ? cutlass::gemm::collective::detail::mxfp4_group_size
|
||||
: cutlass::gemm::collective::detail::int4_group_size;
|
||||
GroupedGemmInput<T, WeightType, OutputType, OutputType> inputs{.groupwise_quant_group_size = group_size};
|
||||
inputs.num_experts = num_experts;
|
||||
sm90_generic_mixed_moe_gemm_kernelLauncher<T, WeightType, OutputType,
|
||||
tensorrt_llm::cutlass_extensions::EpilogueOpDefault, Shape<_128, _64, _Ktile>, Shape<_1, _1, _1>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
|
||||
tensorrt_llm::cutlass_extensions::EpilogueOpDefault, EpilogueFusion::NONE, Shape<_128, _64, _Ktile>,
|
||||
Shape<_1, _1, _1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>(
|
||||
inputs, TmaWarpSpecializedGroupedGemmInput{}, sm_count_, &count);
|
||||
#endif
|
||||
return count;
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
// Order matters here, packed_stride.hpp is missing cute and convolution includes
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@ -76,15 +77,23 @@ TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels::cutlass_kernels
|
||||
{
|
||||
/**
|
||||
* Takes the input maps and prepares the expanded maps for min latency
|
||||
* @param num_active_experts_per_node: Number of active experts on current node
|
||||
* @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by
|
||||
* the token. Only the first num_active_experts_per_ rows are valid
|
||||
* @param active_expert_global_ids: The global expert id for each activated expert
|
||||
* Only the first num_active_experts_per_ values are valid
|
||||
* @param expert_first_token_offset: Store the first token offset for each expert
|
||||
*/
|
||||
|
||||
// Forced vectorized load
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T loadVec(T const* ptr)
|
||||
{
|
||||
T result;
|
||||
cutlass::arch::global_load<T, sizeof(T)>(result, ptr, true);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Forced vectorized store
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void storeVec(T* ptr, T const& value)
|
||||
{
|
||||
cutlass::arch::global_store<T, sizeof(T)>(value, ptr, true);
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
__device__ __forceinline__ void initTensor(T* value, int const tid, int const total_num, T const init_value)
|
||||
{
|
||||
@ -156,6 +165,15 @@ __device__ __forceinline__ void setActiveNum(int& num_active, int& num_active_of
|
||||
num_active_offset_end = num_active_offset_start + num_active;
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes the input maps and prepares the expanded maps for min latency
|
||||
* @param num_active_experts_per_node: Number of active experts on current node
|
||||
* @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by
|
||||
* the token. Only the first num_active_experts_per_ rows are valid
|
||||
* @param active_expert_global_ids: The global expert id for each activated expert
|
||||
* Only the first num_active_experts_per_ values are valid
|
||||
* @param expert_first_token_offset: Store the first token offset for each expert
|
||||
*/
|
||||
template <int BLOCK_SIZE>
|
||||
__global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_per_node, float* experts_to_token_scores,
|
||||
int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts,
|
||||
@ -1593,7 +1611,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
|
||||
auto func = [&]()
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
// Always MXFP8
|
||||
// Always MXFP8 and W4A8_AWQ
|
||||
if constexpr (std::is_same_v<ExpandedActivationsType, __nv_fp8_e4m3>
|
||||
&& !std::is_same_v<InputActivationsType, __nv_fp8_e4m3>)
|
||||
{
|
||||
@ -1950,7 +1968,7 @@ constexpr static int ACTIVATION_THREADS_PER_BLOCK = 256;
|
||||
template <class ActivationOutputType, class GemmOutputType, class ActFn>
|
||||
__global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutputType const* gemm_result,
|
||||
int64_t const* expert_first_token_offset, int64_t inter_size, int64_t num_experts_per_node,
|
||||
ActivationParams activation_type)
|
||||
ActivationParams activation_type, GemmOutputType const* prequant_scale, bool use_per_expert_prequant_scale)
|
||||
{
|
||||
int64_t const tid = threadIdx.x;
|
||||
int64_t const token = blockIdx.x;
|
||||
@ -1978,16 +1996,24 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput
|
||||
float gate_alpha = 1.0f;
|
||||
float gate_bias = 0.0f;
|
||||
float gate_limit = std::numeric_limits<float>::infinity();
|
||||
int expert = 0;
|
||||
if (use_per_expert_prequant_scale || activation_type.swiglu_alpha || activation_type.swiglu_beta
|
||||
|| activation_type.swiglu_limit)
|
||||
{
|
||||
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t) token + 1) - 1;
|
||||
}
|
||||
if (activation_type.swiglu_alpha || activation_type.swiglu_beta || activation_type.swiglu_limit)
|
||||
{
|
||||
int expert
|
||||
= findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, (int64_t) token + 1) - 1;
|
||||
gate_alpha = activation_type.swiglu_alpha ? activation_type.swiglu_alpha[expert] : 1.0f;
|
||||
gate_bias = activation_type.swiglu_beta ? activation_type.swiglu_beta[expert] : 0.0f;
|
||||
gate_limit = activation_type.swiglu_limit ? activation_type.swiglu_limit[expert]
|
||||
: std::numeric_limits<float>::infinity();
|
||||
}
|
||||
|
||||
auto prequant_scale_vec = prequant_scale ? reinterpret_cast<GemmResultElem const*>(
|
||||
prequant_scale + (use_per_expert_prequant_scale ? expert * inter_size : 0))
|
||||
: nullptr;
|
||||
|
||||
ActFn fn{};
|
||||
fn.alpha = gate_alpha;
|
||||
fn.beta = gate_bias;
|
||||
@ -1998,6 +2024,13 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput
|
||||
// BF16 isn't supported, use FP32 for activation function
|
||||
auto gate_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + inter_size_vec]);
|
||||
auto gate_act = fn(gate_value, linear_value);
|
||||
|
||||
// Apply prequant scale if provided
|
||||
if (prequant_scale_vec)
|
||||
{
|
||||
gate_act = gate_act * arrayConvert<GemmResultElem, ComputeElem>(prequant_scale_vec[elem_index]);
|
||||
}
|
||||
|
||||
output_vec[elem_index] = arrayConvert<ComputeElem, OutputElem>(gate_act);
|
||||
}
|
||||
}
|
||||
@ -2005,7 +2038,8 @@ __global__ void doGatedActivationKernel(ActivationOutputType* output, GemmOutput
|
||||
template <typename ActivationOutputType, typename GemmOutputType>
|
||||
void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_result,
|
||||
int64_t const* expert_first_token_offset, int64_t inter_size, int64_t num_tokens, int64_t num_experts_per_node,
|
||||
ActivationParams activation_type, cudaStream_t stream)
|
||||
ActivationParams activation_type, cudaStream_t stream, bool use_per_expert_prequant_scale = false,
|
||||
GemmOutputType const* prequant_scale = nullptr)
|
||||
{
|
||||
int64_t const blocks = num_tokens;
|
||||
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
|
||||
@ -2018,18 +2052,19 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_
|
||||
? &doGatedActivationKernel<ActivationOutputType, GemmOutputType, SwigluBiasAdaptor>
|
||||
: nullptr;
|
||||
TLLM_CHECK_WITH_INFO(fn != nullptr, "Invalid activation type");
|
||||
fn<<<blocks, threads, 0, stream>>>(
|
||||
output, gemm_result, expert_first_token_offset, inter_size, num_experts_per_node, activation_type);
|
||||
fn<<<blocks, threads, 0, stream>>>(output, gemm_result, expert_first_token_offset, inter_size, num_experts_per_node,
|
||||
activation_type, prequant_scale, use_per_expert_prequant_scale);
|
||||
}
|
||||
|
||||
// ============================== Activation =================================
|
||||
|
||||
template <class T, class GemmOutputType, class ScaleBiasType, class ActFn,
|
||||
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType>
|
||||
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType, int kProcessRows>
|
||||
__global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant,
|
||||
ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset,
|
||||
int num_experts_per_node, int64_t inter_size, float const* fc2_act_global_scale, bool use_per_expert_act_scale,
|
||||
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params)
|
||||
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, ActivationParams activation_params,
|
||||
GemmOutputType const* prequant_scale, int64_t const num_valid_tokens)
|
||||
{
|
||||
#ifdef ENABLE_FP4
|
||||
constexpr bool IsNVFP4 = std::is_same_v<T, __nv_fp4_e2m1>
|
||||
@ -2041,7 +2076,6 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
|
||||
constexpr bool IsMXFP8 = cute::dependent_false<T>;
|
||||
#endif
|
||||
|
||||
int64_t const tid = threadIdx.x;
|
||||
constexpr bool IsGated = ActFn::IS_GLU;
|
||||
size_t gated_size_mul = IsGated ? 2 : 1;
|
||||
size_t gated_off = IsGated ? inter_size : 0;
|
||||
@ -2059,81 +2093,116 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
|
||||
: TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX;
|
||||
int64_t const padded_inter_size = ceilDiv(inter_size, min_k_dim_alignment) * min_k_dim_alignment;
|
||||
|
||||
int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];
|
||||
// 2D grid: blockIdx.x for tokens in groups of kProcessRows, blockIdx.y for columns
|
||||
int64_t const row_offset = blockIdx.x * kProcessRows;
|
||||
int64_t const col_offset = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
|
||||
int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
|
||||
assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
|
||||
int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
|
||||
|
||||
// Early exit if this thread is out of bounds for columns
|
||||
if (col_offset >= num_elems_in_col)
|
||||
return;
|
||||
|
||||
// Precompute K-dimension padding range for merged SF padding
|
||||
[[maybe_unused]] int64_t const k_padding_start = inter_size / VecSize;
|
||||
[[maybe_unused]] int64_t const k_padding_end = padded_inter_size / VecSize;
|
||||
[[maybe_unused]] bool const do_k_padding
|
||||
= (IsNVFP4 || IsMXFP8) && col_offset >= k_padding_start && col_offset < k_padding_end;
|
||||
|
||||
auto main_loop = [&](auto has_prequant_scale, auto has_bias)
|
||||
{
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x)
|
||||
{
|
||||
size_t gemm_result_offset = token * inter_size * gated_size_mul;
|
||||
size_t output_offset = token * inter_size;
|
||||
constexpr bool has_prequant_scale_v = decltype(has_prequant_scale)::value;
|
||||
constexpr bool has_bias_v = decltype(has_bias)::value;
|
||||
|
||||
int64_t expert = 0;
|
||||
float gate_alpha = 1.0f;
|
||||
float gate_beta = 0.0f;
|
||||
float gate_limit = std::numeric_limits<float>::infinity();
|
||||
if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale || activation_params.swiglu_alpha
|
||||
|| activation_params.swiglu_beta || activation_params.swiglu_limit)
|
||||
// Process kProcessRows consecutive tokens, promoting per-expert data reuse
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kProcessRows; ++i)
|
||||
{
|
||||
// TODO this is almost certainly faster as a linear scan
|
||||
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1;
|
||||
int64_t const token = row_offset + i;
|
||||
|
||||
gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
|
||||
gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
|
||||
gate_limit = activation_params.swiglu_limit ? activation_params.swiglu_limit[expert]
|
||||
: std::numeric_limits<float>::infinity();
|
||||
}
|
||||
// Early exit for this row if out of bounds
|
||||
if (token >= num_valid_tokens)
|
||||
break;
|
||||
|
||||
size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
|
||||
float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f;
|
||||
size_t gemm_result_offset = token * inter_size * gated_size_mul;
|
||||
size_t output_offset = token * inter_size;
|
||||
|
||||
// Some globals for FP4
|
||||
float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
|
||||
int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
|
||||
int64_t expert = 0;
|
||||
float gate_alpha = 1.0f;
|
||||
float gate_beta = 0.0f;
|
||||
float gate_limit = std::numeric_limits<float>::infinity();
|
||||
if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale || activation_params.swiglu_alpha
|
||||
|| activation_params.swiglu_beta || activation_params.swiglu_limit)
|
||||
{
|
||||
// TODO this is almost certainly faster as a linear scan
|
||||
expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1;
|
||||
|
||||
size_t bias_offset = 0;
|
||||
if (bias_ptr)
|
||||
{
|
||||
bias_offset = (bias_is_broadcast ? expert * inter_size * gated_size_mul : gemm_result_offset);
|
||||
}
|
||||
gate_alpha = activation_params.swiglu_alpha ? activation_params.swiglu_alpha[expert] : 1.0f;
|
||||
gate_beta = activation_params.swiglu_beta ? activation_params.swiglu_beta[expert] : 0.0f;
|
||||
gate_limit = activation_params.swiglu_limit ? activation_params.swiglu_limit[expert]
|
||||
: std::numeric_limits<float>::infinity();
|
||||
}
|
||||
|
||||
using BiasElem = cutlass::Array<ScaleBiasType, ACTIVATION_ELEM_PER_THREAD>;
|
||||
using GemmResultElem = cutlass::Array<GemmOutputType, ACTIVATION_ELEM_PER_THREAD>;
|
||||
using OutputElem = std::conditional_t<IsNVFP4, uint32_t,
|
||||
std::conditional_t<IsMXFP8, uint64_t, cutlass::Array<T, ACTIVATION_ELEM_PER_THREAD>>>;
|
||||
using ComputeElem = cutlass::Array<float, ACTIVATION_ELEM_PER_THREAD>;
|
||||
// Aliases gemm_result for non-gated, non-fp8 cases
|
||||
auto gemm_result_vec = reinterpret_cast<GemmResultElem const*>(gemm_result + gemm_result_offset);
|
||||
auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
|
||||
auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
|
||||
int64_t const start_offset = tid;
|
||||
int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
|
||||
assert(inter_size % ACTIVATION_ELEM_PER_THREAD == 0);
|
||||
int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;
|
||||
assert(gated_off % ACTIVATION_ELEM_PER_THREAD == 0);
|
||||
int64_t const gated_off_vec = gated_off / ACTIVATION_ELEM_PER_THREAD;
|
||||
size_t act_scale_idx = use_per_expert_act_scale ? expert : 0;
|
||||
float const quant_scale = fp8_quant ? fp8_quant[act_scale_idx] : 1.f;
|
||||
|
||||
ActFn fn{};
|
||||
fn.alpha = gate_alpha;
|
||||
fn.beta = gate_beta;
|
||||
fn.limit = gate_limit;
|
||||
for (int64_t elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
|
||||
{
|
||||
auto fc1_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index + gated_off_vec]);
|
||||
// Some globals for FP4
|
||||
float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f;
|
||||
int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0;
|
||||
|
||||
size_t bias_offset = 0;
|
||||
if (bias_ptr)
|
||||
{
|
||||
fc1_value = fc1_value + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index + gated_off_vec]);
|
||||
bias_offset = (bias_is_broadcast ? expert * inter_size * gated_size_mul : gemm_result_offset);
|
||||
}
|
||||
|
||||
using BiasElem = cutlass::Array<ScaleBiasType, ACTIVATION_ELEM_PER_THREAD>;
|
||||
using GemmResultElem = cutlass::Array<GemmOutputType, ACTIVATION_ELEM_PER_THREAD>;
|
||||
using OutputElem = std::conditional_t<IsNVFP4, uint32_t,
|
||||
std::conditional_t<IsMXFP8, uint64_t, cutlass::Array<T, ACTIVATION_ELEM_PER_THREAD>>>;
|
||||
using ComputeElem = cutlass::Array<float, ACTIVATION_ELEM_PER_THREAD>;
|
||||
// Aliases gemm_result for non-gated, non-fp8 cases
|
||||
auto gemm_result_vec = reinterpret_cast<GemmResultElem const*>(gemm_result + gemm_result_offset);
|
||||
auto output_vec = reinterpret_cast<OutputElem*>(safe_inc_ptr(output, output_offset));
|
||||
auto bias_ptr_vec = reinterpret_cast<BiasElem const*>(bias_ptr + bias_offset);
|
||||
auto prequant_scale_vec = prequant_scale
|
||||
? reinterpret_cast<GemmResultElem const*>(prequant_scale + expert * inter_size)
|
||||
: nullptr;
|
||||
|
||||
ActFn fn{};
|
||||
fn.alpha = gate_alpha;
|
||||
fn.beta = gate_beta;
|
||||
fn.limit = gate_limit;
|
||||
|
||||
// Each thread handles one vector at col_offset
|
||||
int64_t const elem_index = col_offset;
|
||||
|
||||
// Use loadVec to force LDG.128 vectorized loads
|
||||
auto fc1_value
|
||||
= arrayConvert<GemmResultElem, ComputeElem>(loadVec(&gemm_result_vec[elem_index + gated_off_vec]));
|
||||
if constexpr (has_bias_v)
|
||||
{
|
||||
fc1_value = fc1_value
|
||||
+ arrayConvert<BiasElem, ComputeElem>(loadVec(&bias_ptr_vec[elem_index + gated_off_vec]));
|
||||
}
|
||||
|
||||
auto gate_act = [&]()
|
||||
{
|
||||
if constexpr (IsGated)
|
||||
{
|
||||
auto linear_value = arrayConvert<GemmResultElem, ComputeElem>(gemm_result_vec[elem_index]);
|
||||
if (bias_ptr_vec)
|
||||
auto linear_value
|
||||
= arrayConvert<GemmResultElem, ComputeElem>(loadVec(&gemm_result_vec[elem_index]));
|
||||
if constexpr (has_bias_v)
|
||||
{
|
||||
linear_value = linear_value + arrayConvert<BiasElem, ComputeElem>(bias_ptr_vec[elem_index]);
|
||||
linear_value
|
||||
= linear_value + arrayConvert<BiasElem, ComputeElem>(loadVec(&bias_ptr_vec[elem_index]));
|
||||
}
|
||||
return fn(fc1_value, linear_value);
|
||||
}
|
||||
@ -2145,6 +2214,13 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
|
||||
|
||||
auto post_act_val = gate_act * quant_scale;
|
||||
|
||||
// Apply prequant scale (shape [experts_per_rank, intermediate_size]) if provided
|
||||
if constexpr (has_prequant_scale_v)
|
||||
{
|
||||
post_act_val = post_act_val
|
||||
* arrayConvert<GemmResultElem, ComputeElem>(loadVec(&prequant_scale_vec[elem_index]));
|
||||
}
|
||||
|
||||
if constexpr (IsNVFP4 || IsMXFP8)
|
||||
{
|
||||
// We use GemmOutputType as the intermediate compute type as that should always be unquantized
|
||||
@ -2155,74 +2231,98 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
|
||||
static_assert(
|
||||
sizeof(res) == sizeof(*output_vec), "Quantized value must be the same size as the output");
|
||||
output_vec[elem_index] = res;
|
||||
|
||||
// Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan values in the
|
||||
// padded SF atom. Only process padding for valid tokens in this block's row range.
|
||||
if (do_k_padding)
|
||||
{
|
||||
writeSF<VecSize, VecSize>(num_tokens_before_expert, expert, /*source_row*/ -1, token, col_offset,
|
||||
padded_inter_size, fc2_act_sf_flat,
|
||||
/* input_sf */ nullptr); // Pass nullptr input_sf so we write 0
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
output_vec[elem_index] = arrayConvert<ComputeElem, OutputElem>(post_act_val);
|
||||
// Use storeVec to force STG.128 vectorized store
|
||||
storeVec(&output_vec[elem_index], arrayConvert<ComputeElem, OutputElem>(post_act_val));
|
||||
}
|
||||
}
|
||||
|
||||
// Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan values in the padded
|
||||
// SF atom
|
||||
if constexpr (IsNVFP4 || IsMXFP8)
|
||||
{
|
||||
// Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector
|
||||
size_t padding_start_offset = inter_size / VecSize + start_offset;
|
||||
size_t padding_elems_in_col = padded_inter_size / VecSize;
|
||||
for (int64_t elem_index = padding_start_offset; elem_index < padding_elems_in_col; elem_index += stride)
|
||||
{
|
||||
writeSF<VecSize, VecSize>(num_tokens_before_expert, expert, /*source_row*/ -1, token, elem_index,
|
||||
padded_inter_size, fc2_act_sf_flat, /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
|
||||
// Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF
|
||||
// atom
|
||||
if constexpr (IsNVFP4 || IsMXFP8)
|
||||
{
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = ACTIVATION_THREADS_PER_BLOCK;
|
||||
// Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector
|
||||
int64_t const padded_num_elems_in_col = padded_inter_size / VecSize;
|
||||
assert(padded_inter_size % VecSize == 0);
|
||||
|
||||
constexpr int64_t min_num_tokens_alignment = IsNVFP4
|
||||
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4
|
||||
: TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX;
|
||||
static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0,
|
||||
"Min num tokens alignment must be a power of two");
|
||||
// Since we don't know a priori how much padding is needed we assume the max per expert
|
||||
// NOTE: we don't (min_num_tokens_alignment-1) to have power of two divisions
|
||||
int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node;
|
||||
|
||||
for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; padding_token += gridDim.x)
|
||||
// Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded
|
||||
// SF atom. This is handled by dedicated blocks beyond num_valid_tokens range.
|
||||
if constexpr (IsNVFP4 || IsMXFP8)
|
||||
{
|
||||
int64_t expert = padding_token / min_num_tokens_alignment;
|
||||
int64_t num_tokens_before_expert = expert_first_token_offset[expert];
|
||||
int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1];
|
||||
int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert;
int64_t padding_to_expert
= TmaWarpSpecializedGroupedGemmInput::alignToSfDim(tokens_to_expert, min_num_tokens_alignment)
- tokens_to_expert;
int64_t expert_pad_idx = padding_token % min_num_tokens_alignment;
if (expert_pad_idx < padding_to_expert)
constexpr int64_t min_num_tokens_alignment = IsNVFP4
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4
: TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX;
static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0,
"Min num tokens alignment must be a power of two");

int64_t const padded_num_elems_in_col = padded_inter_size / VecSize;
int64_t const sf_col_offset = blockIdx.y * blockDim.x + threadIdx.x;

if (sf_col_offset >= padded_num_elems_in_col)
return;

// Only blocks that handle padding tokens participate
int64_t const num_token_blocks = (num_valid_tokens + kProcessRows - 1) / kProcessRows;
if (blockIdx.x < num_token_blocks)
return;

// This block handles N-dimension padding
int64_t const padding_block_idx = blockIdx.x - num_token_blocks;

#pragma unroll
for (int i = 0; i < kProcessRows; ++i)
{
for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; elem_index += stride)
int64_t const padding_token = padding_block_idx * kProcessRows + i;
int64_t const num_padding_tokens = min_num_tokens_alignment * num_experts_per_node;

if (padding_token >= num_padding_tokens)
return;

int64_t expert = padding_token / min_num_tokens_alignment;
int64_t num_tokens_before_expert = expert_first_token_offset[expert];
int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1];
int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert;
int64_t padding_to_expert
= TmaWarpSpecializedGroupedGemmInput::alignToSfDim(tokens_to_expert, min_num_tokens_alignment)
- tokens_to_expert;
int64_t expert_pad_idx = padding_token % min_num_tokens_alignment;
if (expert_pad_idx < padding_to_expert)
{
// The SF buffer is padded to a multiple of MinNDimAlignment for each expert
// This means we can safely write to offset num_tokens_after_expert + padded_token, since the next
// expert will leave space for the padding
// This means we can safely write to offset num_tokens_after_expert + padded_token, since the
// next expert will leave space for the padding
writeSF<VecSize, VecSize>(num_tokens_before_expert, expert, /*source_row*/ -1,
num_tokens_after_expert + expert_pad_idx, elem_index, padded_inter_size, fc2_act_sf_flat,
/* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0
num_tokens_after_expert + expert_pad_idx, sf_col_offset, padded_inter_size, fc2_act_sf_flat,
/* input_sf */ nullptr); // Pass nullptr input_sf so we write 0
}
}
}
}; // end of lambda main_loop

// Instantiate four code paths for the different combinations of prequant_scale and bias_ptr
// This ensures the relevant code path runs in a tight loop
if (prequant_scale && bias_ptr)
{
main_loop(/* has_prequant_scale */ std::true_type{}, /* has_bias */ std::true_type{});
}
else if (!prequant_scale && !bias_ptr)
{
main_loop(/* has_prequant_scale */ std::false_type{}, /* has_bias */ std::false_type{});
}
else if (!prequant_scale && bias_ptr)
{
main_loop(/* has_prequant_scale */ std::false_type{}, /* has_bias */ std::true_type{});
}
else // prequant_scale && !bias_ptr
{
main_loop(/* has_prequant_scale */ std::true_type{}, /* has_bias */ std::false_type{});
}
}
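The dispatch above is a tag-dispatch idiom: std::true_type/std::false_type are passed into a generic lambda so that if constexpr removes the untaken branches at compile time, and each of the four flag combinations becomes its own tight loop with no per-element runtime checks. A minimal, self-contained sketch of the same pattern (the names and the toy loop are illustrative, not taken from the kernel):

#include <cstdio>
#include <type_traits>
#include <vector>

int main()
{
    std::vector<float> data(8, 1.0f);
    float const scale = 0.5f;
    float const bias = 2.0f;

    // Generic lambda specialized at compile time by the tag arguments.
    auto main_loop = [&](auto has_scale, auto has_bias)
    {
        for (auto& v : data)
        {
            if constexpr (decltype(has_scale)::value)
                v *= scale; // compiled in only when has_scale is std::true_type
            if constexpr (decltype(has_bias)::value)
                v += bias; // compiled in only when has_bias is std::true_type
        }
    };

    bool const use_scale = true;
    bool const use_bias = false;

    // Runtime selection of one of the four pre-instantiated code paths.
    if (use_scale && use_bias)
        main_loop(std::true_type{}, std::true_type{});
    else if (use_scale && !use_bias)
        main_loop(std::true_type{}, std::false_type{});
    else if (!use_scale && use_bias)
        main_loop(std::false_type{}, std::true_type{});
    else
        main_loop(std::false_type{}, std::false_type{});

    std::printf("data[0] = %f\n", data[0]);
    return 0;
}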
@ -2230,99 +2330,143 @@ template <class T, class GemmOutputType, class ScaleBiasType>
void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias,
bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size,
int64_t expanded_num_tokens, ActivationParams activation_type, QuantParams const& quant_params,
bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream)
bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream,
GemmOutputType const* prequant_scale = nullptr)
{

#ifdef ENABLE_FP4
constexpr int64_t min_num_tokens_alignment = std::is_same_v<T, __nv_fp4_e2m1>
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4
: TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX;
int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node;
constexpr bool IsNVFP4 = std::is_same_v<T, __nv_fp4_e2m1>;
constexpr bool IsMXFP8 = std::is_same_v<T, __nv_fp8_e4m3>;
#else
int64_t num_padding_tokens = 0;
constexpr bool IsNVFP4 = false;
constexpr bool IsMXFP8 = false;
#endif
// ACTIVATION_ELEM_PER_THREAD must match kernel's computation
constexpr int64_t ACTIVATION_ELEM_PER_THREAD = (IsNVFP4 || IsMXFP8)
? CVT_ELTS_PER_THREAD
: (128 / std::min(sizeof_bits<T>::value, sizeof_bits<GemmOutputType>::value));

auto fn = [&]()
int64_t const num_elems_in_col = inter_size / ACTIVATION_ELEM_PER_THREAD;

auto doActivationKernelLauncher = [&](auto num_rows_per_cta)
{
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
// common.h
auto fn
= [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*, ScaleBiasType const*,
bool, int64_t const*, int, int64_t, float const*, bool,
TmaWarpSpecializedGroupedGemmInput::ElementSF*, ActivationParams)
constexpr int num_rows_per_cta_v = num_rows_per_cta.value;

// For NVFP4/MXFPX SFs
int64_t num_padding_tokens = 0;
auto fn = [&]()
{
switch (activation_type.activation_type)
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
// common.h
auto fn
= [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*,
ScaleBiasType const*, bool, int64_t const*, int, int64_t,
float const*, bool, TmaWarpSpecializedGroupedGemmInput::ElementSF*,
ActivationParams, GemmOutputType const*, int64_t)
{
case ActivationType::Identity:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value>;
case ActivationType::Gelu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
case ActivationType::Relu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>;
case ActivationType::Silu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
case ActivationType::Swiglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
case ActivationType::Geglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
case ActivationType::SwigluBias:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value>;
case ActivationType::Relu2:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value>;
default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
}
};
auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
auto MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{};
auto NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{};
switch (activation_type.activation_type)
{
case ActivationType::Identity:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Gelu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Relu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Silu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Swiglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::Geglu:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
case ActivationType::SwigluBias:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value, num_rows_per_cta_v>;
case ActivationType::Relu2:
return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value,
num_rows_per_cta_v>;
default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
}
};
auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
auto MXFPX = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX>{};
auto NONE = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE>{};
#ifdef ENABLE_FP4
if constexpr (std::is_same_v<T, __nv_fp4_e2m1>)
{
TLLM_CHECK_WITH_INFO(
quant_params.fp4.fc2.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4");
return fn(NVFP4);
}
else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>)
{
return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE);
}
else
if constexpr (std::is_same_v<T, __nv_fp4_e2m1>)
{
num_padding_tokens = TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 * num_experts_per_node;
TLLM_CHECK_WITH_INFO(
quant_params.fp4.fc2.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4");
return fn(NVFP4);
}
else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>)
{
num_padding_tokens = quant_params.mxfp8_mxfp4.fc2.weight_block_scale
? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX * num_experts_per_node
: 0;
return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE);
}
else
#endif
{
return fn(NONE);
}
}();
{
return fn(NONE);
}
}();
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
int32_t const blocks
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
// X dimension for tokens in groups of num_rows_per_cta_v
// Y dimension for columns
int64_t const num_token_blocks = (expanded_num_tokens + num_rows_per_cta_v - 1) / num_rows_per_cta_v;
int64_t const num_padding_blocks = (num_padding_tokens + num_rows_per_cta_v - 1) / num_rows_per_cta_v;
// Add extra blocks in X dimension for token-wise padding in FP4/MXFP8 modes
int32_t const grid_x = static_cast<int32_t>(num_token_blocks + num_padding_blocks);
int32_t const grid_y = static_cast<int32_t>(
(num_elems_in_col + ACTIVATION_THREADS_PER_BLOCK - 1) / ACTIVATION_THREADS_PER_BLOCK);
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset,
num_experts_per_node, inter_size, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale,
fc2_act_sf_flat, activation_type);
cudaLaunchConfig_t config;
config.gridDim = dim3(grid_x, grid_y, 1);
config.blockDim = dim3(1, threads, 1);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast,
expert_first_token_offset, num_experts_per_node, inter_size, quant_params.fp4.fc2.act_global_scale,
use_per_expert_act_scale, fc2_act_sf_flat, activation_type, prequant_scale, expanded_num_tokens);
}; // end lambda doActivationKernelLauncher

// 256 threads per block * 256 blocks / 1 rows per block can be handled by 1-2 waves depending on SM arch
if (num_elems_in_col * expanded_num_tokens < 256 * 256)
{
doActivationKernelLauncher(std::integral_constant<int, 1>());
}
// 256 threads per block * 512 blocks / 2 rows per block can be handled by 1-2 waves depending on SM arch
else if (num_elems_in_col * expanded_num_tokens < 256 * 512)
{
doActivationKernelLauncher(std::integral_constant<int, 2>());
}
// Regular case
else
{
doActivationKernelLauncher(std::integral_constant<int, 4>());
}
}
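The launcher above first decides how many token rows each CTA should process (1, 2, or 4, based on the total element count), and only then shapes a 2-D grid: X covers the token blocks plus the extra padding blocks, Y covers the columns, with a fixed thread count per block. A small host-side sketch of that arithmetic, using assumed sizes and a 256-thread block just to make the numbers concrete:

#include <cstdint>
#include <cstdio>

int main()
{
    // Illustrative values only; the real launcher derives these from the MoE problem.
    int64_t const expanded_num_tokens = 3000;
    int64_t const num_padding_tokens = 128; // e.g. per-expert SF padding in FP4/MXFP8 modes
    int64_t const num_elems_in_col = 512;   // inter_size / elements-per-thread
    int32_t const threads_per_block = 256;

    // Pick rows per CTA from the total amount of work, mirroring the thresholds above.
    int64_t const work = num_elems_in_col * expanded_num_tokens;
    int const rows_per_cta = work < 256 * 256 ? 1 : work < 256 * 512 ? 2 : 4;

    // X dimension: token blocks plus the extra blocks that only write padding.
    int64_t const num_token_blocks = (expanded_num_tokens + rows_per_cta - 1) / rows_per_cta;
    int64_t const num_padding_blocks = (num_padding_tokens + rows_per_cta - 1) / rows_per_cta;
    int64_t const grid_x = num_token_blocks + num_padding_blocks;

    // Y dimension: columns handled threads_per_block at a time.
    int64_t const grid_y = (num_elems_in_col + threads_per_block - 1) / threads_per_block;

    std::printf("rows_per_cta=%d grid=(%lld, %lld) block=(1, %d)\n", rows_per_cta,
        (long long) grid_x, (long long) grid_y, threads_per_block);
    return 0;
}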
// ============================== Lora Add Bias =================================
@ -2877,11 +3021,10 @@ template <class T, class WeightType, class OutputType, class InputType, class Sc
T const* CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Enable>::applyPrequantScale(
void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr,
int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream,
int64_t* expert_first_token_offset, int const num_experts_per_node)
QuantParams const& quant_params, int64_t* expert_first_token_offset, int const num_experts_per_node)
{
T const* gemm_input;
bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
if (use_prequant_scale_kernel)
if (usePrequantScaleKernel(quant_params))
{
TLLM_CHECK_WITH_INFO(
(!std::is_same_v<T, WeightType>), "Prequant scales are only used for different weight/activation type!");
@ -2926,7 +3069,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
int64_t const inter_size, int const num_experts_per_node, ActivationParams fc1_activation_type,
float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids)
int* active_expert_global_ids, void const* fc2_prequant_scale)
{

if (fp8_blockscale_gemm_runner)
@ -2988,16 +3131,30 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
sync_check_cuda_error(stream);

// TODO: when bias_is_broadcast is false, fuse bias to gemm
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc2.use_per_expert_act_scale
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.use_per_expert_act_scale
: use_fp8 ? quant_params.fp8.fc2_use_per_expert_act_scale
: Self::useAwq(quant_params) ? quant_params.groupwise.fc2.use_per_expert_act_scale
: false;

doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast,
expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows, fc1_activation_type,
quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream);
// Activation -> (BackboneType) -> Prequant -> (T == ActType)
// When fusing activation and prequant, the output type is directly T == ActType
// Else, the output type is BackboneType
if (fc2_prequant_scale)
{
doActivation<T, UnfusedGemmOutputType>(reinterpret_cast<T*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases,
bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows,
fc1_activation_type, quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream,
static_cast<UnfusedGemmOutputType const*>(fc2_prequant_scale));
}
else
{
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
doActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), fc2_fp8_quant, fc1_expert_biases,
bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows,
fc1_activation_type, quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream);
}

sync_check_cuda_error(stream);
}
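For W4AFP8 AWQ, the activations feeding each grouped GEMM are multiplied by a per-channel prequant (smoothing) scale before the low-bit GEMM consumes them; the branch above lets the FC1 activation kernel fold the FC2 prequant scale into the same pass so a separate applyPrequantScale launch can be skipped. A minimal CPU sketch of the elementwise operation being fused (buffer names and sizes are illustrative):

#include <cstdio>
#include <vector>

// Multiply each activation row by a per-channel smoothing scale, as AWQ-style
// prequantization does before the quantized GEMM reads the activations.
void applyPrequantScaleRef(float* act, float const* scales, int rows, int cols)
{
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            act[r * cols + c] *= scales[c];
}

int main()
{
    int const rows = 2, cols = 4; // illustrative token count and hidden size
    std::vector<float> act = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> scales = {0.5f, 1.0f, 2.0f, 0.25f};
    applyPrequantScaleRef(act.data(), scales.data(), rows, cols);
    std::printf("act[0..3] = %g %g %g %g\n", act[0], act[1], act[2], act[3]);
    return 0;
}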
@ -3087,9 +3244,12 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
if (!use_ampere_activation_fusion)
{
using GatedActOutputType = std::conditional_t<use_w4afp8, BackBoneType, T>;
bool const use_per_expert_act_scale
= Self::useAwq(quant_params) ? quant_params.groupwise.fc2.use_per_expert_act_scale : false;
doGatedActivation<GatedActOutputType, UnfusedGemmOutputType>(reinterpret_cast<GatedActOutputType*>(output),
static_cast<UnfusedGemmOutputType const*>(intermediate_result), expert_first_token_offset, inter_size,
expanded_num_rows, num_experts_per_node, fc1_activation_type, stream);
expanded_num_rows, num_experts_per_node, fc1_activation_type, stream, use_per_expert_act_scale,
static_cast<UnfusedGemmOutputType const*>(fc2_prequant_scale));

sync_check_cuda_error(stream);
}
@ -3187,7 +3347,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
loraBiasApplyFunc(static_cast<UnfusedGemmOutputType*>(gemm_output),
static_cast<UnfusedGemmOutputType const*>(gemm_output), nullptr,
static_cast<ScaleBiasType const*>(fc2_lora), false, expert_first_token_offset, num_experts_per_node,
hidden_size, expanded_num_rows, ActivationParams(ActivationType::Identity), {}, false, nullptr, stream);
hidden_size, expanded_num_rows, ActivationParams(ActivationType::Identity), {}, false, nullptr, stream,
/*prequant_scale=*/nullptr);
sync_check_cuda_error(stream);
}

@ -3560,7 +3721,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
fc2_fp8_dequant == nullptr, "Scales are ignored for fp32/fp16/bf16 but received quant scale for FC2");
}

bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
bool use_awq = useAwq(quant_params);
int const num_experts_per_node = full_num_experts / parallelism_config.ep_size;

configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token,
@ -3602,11 +3763,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert, hidden_size,
inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream,
*gemm1_config_, true, min_latency_params.num_active_experts_per_node,
min_latency_params.active_expert_global_ids);
min_latency_params.active_expert_global_ids, /*fc2_prequant_scale=*/nullptr);
sync_check_cuda_error(stream);

auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream);
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, quant_params);
Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, final_output, nullptr,
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
@ -3621,7 +3782,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
else
{
bool fused_prologue_result = false;
if (!use_w4_groupwise)
if (!use_wfp4a16)
{
// WAR: fusedBuildExpertMapsSortFirstToken kernel will lead to illegal memory access for W4AFP8
fused_prologue_result = fusedBuildExpertMapsSortFirstToken(token_selected_experts,
@ -3659,6 +3820,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
// Only NVFP4xNVFP4 supports FC1 per-expert act scale
bool use_per_expert_act_scale = use_fp4 ? quant_params.fp4.fc1.use_per_expert_act_scale : false;
T* gemm1_input_expand = use_w4afp8 ? reinterpret_cast<T*>(smoothed_act_) : reinterpret_cast<T*>(permuted_data_);
// Expand input and maybe apply prequant scale for AWQ
expandInputRowsKernelLauncher(input_activations, gemm1_input_expand, token_topk_unpermuted_scales,
permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token,
num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_,
@ -3695,15 +3857,22 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
if constexpr (!use_w4afp8)
{
gemm1_input = applyPrequantScale(smoothed_act_, permuted_data_, quant_params.groupwise.fc1.act_scales,
num_valid_tokens_ptr, expanded_num_rows, hidden_size, use_awq, stream);
num_valid_tokens_ptr, expanded_num_rows, hidden_size, use_awq, stream, quant_params);
}
sync_check_cuda_error(stream);
Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, fc1_result_, glu_inter_result_,
// Opportunistically apply FC2 prequant scaling in FC1 doActivation kernel if applicable
bool const fuse_fc2_prequant_scale = use_awq && is_gated_activation;
void const* fc2_prequant_scale_ptr = fuse_fc2_prequant_scale ? quant_params.groupwise.fc2.act_scales : nullptr;
// Match the FC2 act buffer bound to respective TMA desc defined in setupTmaWarpSpecializedInputs()
T* gemm1_output = fuse_fc2_prequant_scale ? reinterpret_cast<T*>(smoothed_act_) : fc1_result_;
Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, gemm1_output, glu_inter_result_,
expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr,
fc1_int_scales, fc1_fp8_dequant, use_wfp4afp8 ? fc2_wfp4afp8_quant_scale : fc2_fp8_quant,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows,
expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, fc1_activation_type,
alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_, false, nullptr, nullptr);
alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_, false, nullptr, nullptr,
fc2_prequant_scale_ptr);
sync_check_cuda_error(stream);

if (use_lora)
@ -3713,10 +3882,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
sync_check_cuda_error(stream);
}

auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, expert_first_token_offset_,
num_experts_per_node);
sync_check_cuda_error(stream);
// When fusing, data is already in smoothed_act_; otherwise run applyPrequantScale to get it there
T const* gemm2_input{reinterpret_cast<T const*>(smoothed_act_)};
if (!fuse_fc2_prequant_scale)
{
// Outputs smoothed_act_
gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
num_valid_tokens_ptr, expanded_num_rows, inter_size, use_awq, stream, quant_params,
expert_first_token_offset_, num_experts_per_node);
sync_check_cuda_error(stream);
}
Self::gemm2(moe_gemm_runner_, blockscale_gemm_runner, gemm2_input, fc2_result_, final_output,
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
@ -3755,10 +3930,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
}

auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale
: use_w4afp8 ? quant_params.groupwise.fc1.alpha
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc1.global_scale
: use_fp8 ? fp8_dequant1
: nullptr;
auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale
: use_w4afp8 ? quant_params.groupwise.fc2.alpha
: use_wfp4afp8 ? quant_params.fp8_mxfp4.fc2.global_scale
: use_fp8 ? fp8_dequant2
: nullptr;
@ -3834,7 +4011,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
return std::make_pair(gemm1_tma_ws_input, gemm2_tma_ws_input);
}
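The routing above is the core of the fusion: when AWQ is active and the activation is gated, GEMM1's activation step writes already-smoothed data into smoothed_act_ and GEMM2 consumes it directly; otherwise the standalone applyPrequantScale pass still runs afterwards. A condensed host-side sketch of that decision, with stand-in buffers and a reference SiLU that are illustrative rather than the runner's real members:

#include <cmath>
#include <cstdio>
#include <vector>

// Reference SiLU used only to make the sketch self-contained.
static float silu(float x)
{
    return x / (1.0f + std::exp(-x));
}

int main()
{
    bool const use_awq = true;
    bool const is_gated_activation = true;
    bool const fuse_fc2_prequant_scale = use_awq && is_gated_activation; // same condition as above

    std::vector<float> gemm1_out = {1.0f, -2.0f, 0.5f, 3.0f}; // raw GEMM1 output (illustrative)
    std::vector<float> fc2_prequant_scale = {0.5f, 0.5f, 2.0f, 2.0f};
    std::vector<float> smoothed_act(gemm1_out.size()); // buffer the second GEMM reads

    if (fuse_fc2_prequant_scale)
    {
        // Fused path: one pass applies the activation and the FC2 prequant smoothing together.
        for (size_t i = 0; i < gemm1_out.size(); ++i)
            smoothed_act[i] = silu(gemm1_out[i]) * fc2_prequant_scale[i];
    }
    else
    {
        // Unfused path: activation first, then a separate applyPrequantScale-style pass.
        std::vector<float> fc1_result(gemm1_out.size());
        for (size_t i = 0; i < gemm1_out.size(); ++i)
            fc1_result[i] = silu(gemm1_out[i]);
        for (size_t i = 0; i < fc1_result.size(); ++i)
            smoothed_act[i] = fc1_result[i] * fc2_prequant_scale[i];
    }

    std::printf("smoothed_act[0] = %g\n", smoothed_act[0]);
    return 0;
}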
bool use_awq = quant_params.groupwise.fc1.act_scales && quant_params.groupwise.fc2.act_scales && !use_wfp4a16;
bool const use_awq = useAwq(quant_params);

bool is_gated_activation = isGatedActivation(fc1_activation_type);
int64_t const fc1_out_size = is_gated_activation ? inter_size * 2 : inter_size;
@ -3843,7 +4020,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
bool const has_intermediate = has_different_gemm_output_type || is_gated_activation;
auto* gemm1_output = has_intermediate ? glu_inter_result_ : static_cast<void*>(fc1_result_);

bool use_prequant_scale_kernel = use_awq && !std::is_same_v<T, WeightType>;
bool const use_prequant_scale_kernel = usePrequantScaleKernel(quant_params);
auto gemm2_input = use_prequant_scale_kernel ? smoothed_act_ : fc1_result_;

if (min_latency_mode)
@ -3880,8 +4057,7 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool using_fused_finalize
= use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora;
bool using_fused_finalize = use_fused_finalize_ && gemm2_using_finalize_fusion && !use_wfp4a16 && !use_lora;
TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion,
"GEMM2 tactic requests finalize fusion, but the runner is not configured to use it");
if (using_fused_finalize)
@ -4120,7 +4296,7 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
size_t output_size1 = inter_size * num_expanded_tokens * dtype_bytes;

size_t input_size2 = inter_size * num_expanded_tokens * dtype_bytes;
size_t output_size2 = hidden_size * output_bytes;
size_t output_size2 = hidden_size * num_expanded_tokens * output_bytes;

size_t input_size = mGemmToProfile == GemmToProfile::GEMM_1 ? input_size1 : input_size2;
size_t output_size = mGemmToProfile == GemmToProfile::GEMM_1 ? output_size1 : output_size2;
@ -4439,14 +4615,14 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
&& mWType == nvinfer1::DataType::kUINT8);
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
|| mGemmToProfile != GemmToProfile::GEMM_2;
bool const finalize_fusion_not_supported
= !mInterface->use_fused_finalize_ || mMinLatencyMode || use_wfp4a16 || mGemmToProfile != GemmToProfile::GEMM_2;
if (use_finalize_fusion && finalize_fusion_not_supported)
{
return;
}

if (use_w4_groupwise && !swap_ab)
if (use_wfp4a16 && !swap_ab)
{
return;
}
@ -4535,12 +4711,15 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr
}
else
{
auto fc1_alpha = use_w4afp8 ? mQuantParams.groupwise.fc1.alpha : mQuantParams.fp8.dequant_fc1;
auto fc2_alpha = use_w4afp8 ? mQuantParams.groupwise.fc2.alpha : mQuantParams.fp8.dequant_fc2;

std::tie(gemm1_tma_ws_input, gemm2_tma_ws_input) = mInterface->computeStridesTmaWarpSpecializedDispatch(
expert_first_token_offset, gemm1_tma_ws_input, gemm2_tma_ws_input, num_tokens, num_tokens * mK,
fc1_output_size, mExpertHiddenSize, mExpertHiddenSize, mExpertInterSize, mNumExpertsPerNode, input,
input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2,
fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate,
token_topk_unpermuted_scales, permuted_row_to_unpermuted_row, stream);
input, weights_sel, weights_sel, fc1_alpha, fc2_alpha, fp4_act_scale_flat, fp4_act_scale_flat,
mQuantParams, nullptr, nullptr, intermediate, intermediate, token_topk_unpermuted_scales,
permuted_row_to_unpermuted_row, stream);
}
sync_check_cuda_error(stream);
}
@ -41,9 +41,9 @@ EpiTag = {

EpiFusion = {
TrtLlm_EpilogueFusion.epilogue_fusion_none:
"tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE",
"tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE",
TrtLlm_EpilogueFusion.epilogue_fusion_finalize:
"tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE",
"tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE",
}

EpiFusionSuffixes = {
@ -229,9 +229,11 @@ const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zer
or operation.weight_type != e2m1):
# Mixed MoE GEMM
weight_tag = CudaTypeName[operation.weight_type]
assert operation.epi_fusion is not None
epi_fusion = EpiFusion[operation.epi_fusion]
instantiation = f"""
template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
{epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
"""
else:
@ -590,6 +592,11 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):

epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]

epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]

M_TILES = [64, 128] # Currently M tile must be 128 for Grouped GEMM
N_TILES = [16, 32, 64, 128]
K_TILES = [128, 256, 512]
@ -607,13 +614,13 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
cga_shapes = list(product([1, 2], [1, 2], [1]))

partial_args_int4 = product(supported_dtypes_int4, quant_ops, epi_tags,
cta_shapes_mnk_int4, cga_shapes)
epi_fusions, cta_shapes_mnk_int4, cga_shapes)
partial_args_fp4 = product(supported_dtypes_fp4, quant_ops, epi_tags,
cta_shapes_mnk_fp4, cga_shapes)
epi_fusions, cta_shapes_mnk_fp4, cga_shapes)
partial_args = chain(partial_args_int4, partial_args_fp4)

operations = list()
for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args:
for dtype_combo, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
use_coop = cta_shape_mnk[0] >= 128
mainloop_schedules = [
KernelScheduleType.TmaWarpSpecializedCooperative,
@ -626,7 +633,7 @@ def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
== KernelScheduleType.TmaWarpSpecializedCooperative):
continue
moe_gemm_operation = TrtLlm_GemmLauncher(GemmKind.Grouped, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule, epi_fusion)
operations.append(moe_gemm_operation)
return operations

@ -1124,6 +1124,9 @@ private:
auto& fc1_alpha = quant_scales.value()[6];
auto& fc2_alpha = quant_scales.value()[7];
int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size;
// Whether it is per-expert activation scale
bool fc1_use_per_expert_act_scale = fc1_act_scales.numel() > hidden_size;
bool fc2_use_per_expert_act_scale = fc2_act_scales.numel() > inter_size;
return kernels::QuantParams::GroupWise(group_size,
static_cast<void const*>(fc1_weight_scales.data_ptr()),
static_cast<void const*>(fc2_weight_scales.data_ptr()),
@ -1132,7 +1135,8 @@ private:
static_cast<void const*>(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros.data_ptr() : nullptr),
static_cast<void const*>(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() : nullptr),
static_cast<float const*>(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr),
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr));
static_cast<float const*>(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr),
fc1_use_per_expert_act_scale, fc2_use_per_expert_act_scale);
}
else
{
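The binding above infers per-expert activation scales purely from tensor size: a shared scale has hidden_size (or inter_size) elements, while a per-expert scale is num_experts times larger, so numel() > hidden_size is enough to tell them apart. A small sketch of the same check and of how a consumer would then pick the scale row for a given expert (all names illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    int64_t const hidden_size = 4;
    int64_t const num_experts_per_node = 3;

    // act_scales laid out as either one shared row or one row per expert.
    std::vector<float> act_scales(num_experts_per_node * hidden_size, 1.0f);

    bool const use_per_expert_act_scale
        = static_cast<int64_t>(act_scales.size()) > hidden_size; // same numel() test as above

    int64_t const expert = 2;
    // Per-expert: offset into the expert's row; shared: always row 0.
    float const* expert_scales
        = act_scales.data() + (use_per_expert_act_scale ? expert * hidden_size : 0);

    std::printf("per_expert=%d first scale for expert %lld = %g\n",
        int(use_per_expert_act_scale), (long long) expert, expert_scales[0]);
    return 0;
}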
@ -168,7 +168,10 @@ protected:
constexpr static bool ANY_FP4 = WEIGHT_FP4 || ACT_FP4;
constexpr static bool ANY_FPX = ANY_FP4 || FP8;

constexpr static bool INT_QUANT = !std::is_same_v<GemmDataType, WeightType> && std::is_integral_v<WeightType>;
constexpr static bool W4A8_AWQ
= std::is_same_v<GemmDataType, SafeFP8> && std::is_same_v<WeightType, cutlass::uint4b_t>;
constexpr static bool INT_QUANT
= !std::is_same_v<GemmDataType, WeightType> && std::is_integral_v<WeightType> && !W4A8_AWQ;
constexpr static int64_t WEIGHT_ELEM_PER_BYTE = (INT4 || WEIGHT_FP4) ? 2 : 1;
using InputType = std::conditional_t<NVFP4 || MXFP8_MXFP4, OutputType, GemmDataType>;
using WeightStorage = std::conditional_t<WEIGHT_ELEM_PER_BYTE == 2, uint8_t, WeightType>;
@ -184,7 +187,7 @@ protected:
using DataType = std::conditional_t<NVFP4 || MXFP8_MXFP4, OutputType, GemmDataType>;

// FP8_MXFP4 quantizes just the weights on the fly
using WeightRawType = std::conditional_t<FP8_MXFP4, OutputType, DataType>;
using WeightRawType = std::conditional_t<FP8_MXFP4 || W4A8_AWQ, OutputType, DataType>;

static BufferManager::CudaStreamPtr mStream;
static std::unique_ptr<BufferManager> mBufferManager;
@ -221,7 +224,7 @@ protected:
static_assert(!FP8, "FP8 Tests enabled on unsupported CUDA version");
#endif
bool should_skip_no_device = mDeviceCount <= 0;
bool should_skip_unsupported_fp8 = getSMVersion() < 89 && FP8;
bool should_skip_unsupported_fp8 = getSMVersion() < 89 && (FP8 || W4A8_AWQ);
bool should_skip_unsupported_fp4 = (getSMVersion() < 100) && ANY_FP4;
return should_skip_no_device || should_skip_unsupported_fp8 || should_skip_unsupported_fp4;
}
@ -321,8 +324,9 @@ protected:
float* mSwigluBeta{};
float* mSwigluLimit{};

DataType* mExpertIntScale1{};
DataType* mExpertIntScale2{};
using scale_type = std::conditional_t<W4A8_AWQ, WeightScale, DataType>;
scale_type* mExpertIntScale1{};
scale_type* mExpertIntScale2{};

float mFP8WeightScalar1{1.f};
float mFP8WeightScalar2{1.f};
@ -375,7 +379,7 @@ protected:

bool mIsGated = false;
int64_t mGatedMultiplier = 1;
int64_t mGroupSize = -1;
int64_t mGroupSize = W4A8_AWQ ? 128 : -1;

ActivationType mActType = ActivationType::Relu;

@ -453,7 +457,7 @@ protected:
total_size += weight_size / 2;
}
// Quantized data types use a second scratch buffer for the weights before quantizing
if (ANY_FPX || INT_QUANT)
if (ANY_FPX || INT_QUANT || W4A8_AWQ)
{
total_size += weight_elems * sizeof(DataType);
}
@ -535,6 +539,14 @@ protected:
mExpertIntScale1 = allocBuffer<DataType>(mNumExperts * gated_inter);
mExpertIntScale2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
}
else if constexpr (W4A8_AWQ)
{
mExpertWeight1 = allocBuffer<WeightStorage>(expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE);
mExpertWeight2 = allocBuffer<WeightStorage>(expert_matrix_size / WEIGHT_ELEM_PER_BYTE);

mExpertIntScale1 = allocBuffer<WeightScale>(expert_matrix_size * mGatedMultiplier / mGroupSize);
mExpertIntScale2 = allocBuffer<WeightScale>(expert_matrix_size / mGroupSize);
}
else if constexpr (ANY_FP4)
{
mExpertWeight1 = allocBuffer<WeightStorage>(
@ -638,12 +650,22 @@ protected:
doIntQuant(quant_type, shape1, mRawExpertWeight1, mExpertIntScale1, mExpertWeight1);
doIntQuant(quant_type, shape2, mRawExpertWeight2, mExpertIntScale2, mExpertWeight2);
}
else if constexpr (W4A8_AWQ)
{
cutlass_kernels::QuantType quant_type = cutlass_kernels::QuantType::W4_AFP8;

std::vector<size_t> shape1{(size_t) mNumExperts, (size_t) mHiddenSize, (size_t) gated_inter};
std::vector<size_t> shape2{(size_t) mNumExperts, (size_t) mInterSize, (size_t) mHiddenSize};

doIntQuant(quant_type, shape1, mRawExpertWeight1, mExpertIntScale1, mExpertWeight1);
doIntQuant(quant_type, shape2, mRawExpertWeight2, mExpertIntScale2, mExpertWeight2);
}

check_cuda_error(cudaStreamSynchronize(stream));
}

void doIntQuant(cutlass_kernels::QuantType quant_type, std::vector<size_t> shape, DataType* inputs,
DataType* scales, uint8_t* outputs)
void doIntQuant(cutlass_kernels::QuantType quant_type, std::vector<size_t> shape, WeightRawType* inputs,
scale_type* scales, uint8_t* outputs)
{
// Runs on the CPU, must be after stream sync
if constexpr (INT_QUANT)
@ -652,17 +674,119 @@ protected:
size_t elems = std::reduce(shape.begin(), shape.end(), 1, std::multiplies{});
std::vector<int8_t> h_out(elems);
std::vector<DataType> h_input(elems);
std::vector<DataType> h_scales(shape[0] * shape[2]);
std::vector<WeightRawType> h_input(elems);
std::vector<scale_type> h_scales(shape[0] * shape[2]);

check_cuda_error(cudaMemcpy(h_input.data(), inputs, elems * sizeof(DataType), cudaMemcpyDeviceToHost));
check_cuda_error(cudaMemcpy(h_input.data(), inputs, elems * sizeof(WeightRawType), cudaMemcpyDeviceToHost));

cutlass_kernels::symmetric_quantize(h_out.data(), h_scales.data(), h_input.data(), shape, quant_type, true);

check_cuda_error(cudaMemcpy(
outputs, h_out.data(), elems * sizeof(int8_t) / WEIGHT_ELEM_PER_BYTE, cudaMemcpyHostToDevice));
check_cuda_error(
cudaMemcpy(scales, h_scales.data(), h_scales.size() * sizeof(DataType), cudaMemcpyHostToDevice));
cudaMemcpy(scales, h_scales.data(), h_scales.size() * sizeof(scale_type), cudaMemcpyHostToDevice));
}
else if constexpr (W4A8_AWQ)
{
check_cuda_error(cudaStreamSynchronize(mStream->get()));
assert(shape[1] % mGroupSize == 0);

size_t elems = std::reduce(shape.begin(), shape.end(), 1, std::multiplies{});
std::vector<int8_t> h_out(elems * sizeof(int8_t) / WEIGHT_ELEM_PER_BYTE);
std::vector<WeightRawType> h_input(elems);
std::vector<scale_type> h_scales(elems / mGroupSize);
check_cuda_error(cudaMemcpy(h_input.data(), inputs, elems * sizeof(WeightRawType), cudaMemcpyDeviceToHost));

const size_t num_experts = shape[0];
int const input_mat_size = shape[1] * shape[2];
int const bits_per_weigtht_element = 4;

int const quantized_mat_size = input_mat_size * bits_per_weigtht_element / 8;
float const quant_range_scale = 1.f / float(1 << (bits_per_weigtht_element - 1));

for (int expert = 0; expert < num_experts; ++expert)
{
WeightRawType const* current_weight = h_input.data() + expert * input_mat_size;
int8_t* current_quantized_weight = h_out.data() + expert * quantized_mat_size;
scale_type* current_scales = h_scales.data() + expert * input_mat_size / mGroupSize;

for (int ii = 0; ii < input_mat_size / mGroupSize; ++ii)
{
float scale = 0.f;
WeightRawType const* current_weight_group = current_weight + ii * mGroupSize;
for (int jj = 0; jj < mGroupSize; ++jj)
{
scale = std::max(scale, std::abs(float(current_weight_group[jj])));
}
scale *= quant_range_scale;
current_scales[ii] = scale_type(scale);
}

for (int ii = 0; ii < input_mat_size / mGroupSize; ++ii)
{
WeightRawType const* current_weight_group = current_weight + ii * mGroupSize;
int8_t* current_quantized_weight_group
= current_quantized_weight + ii * mGroupSize * bits_per_weigtht_element / 8;
float const scale = float(current_scales[ii]);
for (int jj = 0; jj < mGroupSize / 2; ++jj)
{
// We will pack two int4 elements per iteration of the inner loop.
float const weight_elt0 = float(current_weight_group[jj * 2]);
float const weight_elt1 = float(current_weight_group[jj * 2 + 1]);
float const scaled_weight0 = (scale != 0.0f) ? round(weight_elt0 / scale) : 0.0f;
float const scaled_weight1 = (scale != 0.0f) ? round(weight_elt1 / scale) : 0.0f;
int int_weight0 = int(scaled_weight0);
int int_weight1 = int(scaled_weight1);
const int8_t clipped_weight0 = std::max(-8, std::min(7, int_weight0));
const int8_t clipped_weight1 = std::max(-8, std::min(7, int_weight1));
// Kill the sign extension bits (hence 0x0F mask) then shift to upper bits
// if packing the second int4 and or the bits into the final result.
current_quantized_weight_group[jj] = clipped_weight0 | (clipped_weight1 << 4);
}
}
// WAR: For a diagonal matrix in which each column has only one nonzero element, the scale value is
// calculated based on it being quantized to 8. However, after quantization, the nonzero value is
// clipped to 7. Adjust the scale value to fix the error.
for (int ii = 0; ii < input_mat_size / mGroupSize; ++ii)
{
current_scales[ii] = scale_type(float(current_scales[ii]) * 8 / 7);
}

int interleave = 1;
int const sm = getSMVersion();
if (sm == 90)
{
interleave = shape[1] % 512 == 0 ? 4 : shape[1] % 256 == 0 ? 2 : 1;
}

// Permute scales: from [N, K/mGroupSize/interleave, interleave] to [K/mGroupSize/interleave, N,
// interleave]
int const dim0 = shape[2]; // N
int const dim1 = shape[1] / mGroupSize / interleave; // K/mGroupSize/interleave
int const dim2 = interleave;

std::vector<scale_type> temp_scales(input_mat_size / mGroupSize);
for (int n = 0; n < dim0; ++n)
{
for (int k = 0; k < dim1; ++k)
{
for (int i = 0; i < dim2; ++i)
{
// src index: [n, k, i] in layout [N, K/mGroupSize/interleave, interleave]
int src_idx = n * (dim1 * dim2) + k * dim2 + i;
// dst index: [k, n, i] in layout [K/mGroupSize/interleave, N, interleave]
int dst_idx = k * (dim0 * dim2) + n * dim2 + i;
temp_scales[dst_idx] = current_scales[src_idx];
}
}
}
std::copy(temp_scales.begin(), temp_scales.end(), current_scales);
}

check_cuda_error(cudaMemcpy(
outputs, h_out.data(), elems * sizeof(int8_t) / WEIGHT_ELEM_PER_BYTE, cudaMemcpyHostToDevice));
check_cuda_error(
cudaMemcpy(scales, h_scales.data(), h_scales.size() * sizeof(scale_type), cudaMemcpyHostToDevice));
}
}

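The packing convention in the test's reference quantizer is the usual one for int4 weights: two signed nibbles per byte, the first element in the low nibble and the second in the high nibble, each clamped to [-8, 7] and tied to a group-wise scale. A self-contained round-trip sketch of that convention (independent of the test class, names illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Pack two signed int4 values (each in [-8, 7]) into one byte: low nibble first.
static uint8_t packInt4Pair(int v0, int v1)
{
    return static_cast<uint8_t>((v0 & 0x0F) | ((v1 & 0x0F) << 4));
}

// Unpack a nibble back to a signed int4 via sign extension.
static int unpackInt4(uint8_t byte, int which)
{
    int nibble = (which == 0) ? (byte & 0x0F) : (byte >> 4);
    return nibble >= 8 ? nibble - 16 : nibble;
}

int main()
{
    float const scale = 0.125f; // group-wise scale, i.e. max|w| / 8 in the reference quantizer
    float const w0 = -0.5f, w1 = 0.375f;

    int const q0 = std::clamp(static_cast<int>(std::round(w0 / scale)), -8, 7);
    int const q1 = std::clamp(static_cast<int>(std::round(w1 / scale)), -8, 7);
    uint8_t const packed = packInt4Pair(q0, q1);

    // Dequantize to check the round trip.
    float const d0 = unpackInt4(packed, 0) * scale;
    float const d1 = unpackInt4(packed, 1) * scale;
    std::printf("packed=0x%02X dequant=(%g, %g)\n", packed, d0, d1);
    return 0;
}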
@ -1229,6 +1353,31 @@ protected:
ASSERT_TRUE(scale1_ptr && scale2_ptr);
quant_params = QuantParams::Int(scale1_ptr, scale2_ptr);
}
else if constexpr (W4A8_AWQ)
{
auto input_scale1 = allocBuffer<scale_type>(mNumExperts * mHiddenSize * mGatedMultiplier);
auto input_scale2 = allocBuffer<scale_type>(mNumExperts * mInterSize);
std::vector<scale_type> h_input_scale1(mNumExperts * mHiddenSize * mGatedMultiplier, 1.0f);
std::vector<scale_type> h_input_scale2(mNumExperts * mInterSize, 1.0f);
check_cuda_error(cudaMemcpy(input_scale1, h_input_scale1.data(),
mNumExperts * mHiddenSize * mGatedMultiplier * sizeof(scale_type), cudaMemcpyHostToDevice));
check_cuda_error(cudaMemcpy(input_scale2, h_input_scale2.data(),
mNumExperts * mInterSize * sizeof(scale_type), cudaMemcpyHostToDevice));

auto alpha1_ptrs = allocBuffer<float>(mNumExperts);
auto alpha2_ptrs = allocBuffer<float>(mNumExperts);
for (int i = 0; i < mNumExperts; i++)
{
float alpha1_value = 1.0f;
float alpha2_value = 1.0f;
check_cuda_error(cudaMemcpy(alpha1_ptrs + i, &alpha1_value, sizeof(float), cudaMemcpyHostToDevice));
check_cuda_error(cudaMemcpy(alpha2_ptrs + i, &alpha2_value, sizeof(float), cudaMemcpyHostToDevice));
}

ASSERT_TRUE(scale1_ptr && scale2_ptr);
quant_params = QuantParams::GroupWise(mGroupSize, scale1_ptr, scale2_ptr, input_scale1, input_scale2,
nullptr, nullptr, alpha1_ptrs, alpha2_ptrs);
}
else if (FP8)
{
ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr);
@ -1628,6 +1777,11 @@ using Types = ::testing::Types<
#endif
#endif

#ifdef ENABLE_BF16
#ifdef ENABLE_FP8
WeightParams<SafeFP8, cutlass::uint4b_t, __nv_bfloat16, void, __nv_bfloat16>,
#endif
#endif
WeightParams<half>, WeightParams<float>

// , WeightParams<half, uint8_t>, WeightParams<half, cutlass::uint4b_t>
@ -1675,6 +1829,12 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(

initLocals(hidden_size, num_experts, k, num_tokens);

if (mGroupSize > 0 && (mHiddenSize % mGroupSize != 0 || mInterSize % mGroupSize != 0))
{
GTEST_SKIP() << "Skipping due to unsupported groupwise configuration";
return;
}

auto test_archs = getAllTileConfigsToTest();
for (auto [gemm1, gemm2] : test_archs)
{
@ -1903,6 +2063,13 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteDeepSeekV3)
size_t inter_size = 2048;
this->mInterSizeFraction = float(inter_size) / hidden_size;

if (this->W4A8_AWQ)
{
// TODO: Implement W4A8_AWQ for PermuteDeepSeekV3
GTEST_SKIP() << "W4A8_AWQ is not implemented for PermuteDeepSeekV3";
return;
}

if (!this->checkSufficientTestMemory(100, hidden_size, 256, 8))
{
GTEST_SKIP() << "Insufficient free memory for test";
@ -1956,6 +2123,13 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
}
}

if (W4A8_AWQ)
{
// TODO: Implement W4A8_AWQ for ParallelismTest
GTEST_SKIP() << "W4A8_AWQ is not implemented for ParallelismTest";
return;
}

ASSERT_LE(ep_size, num_experts);
if (tp_size == 1)
{