feat: Expose bias and FP8_MXFP4 MOE CUTLASS backend features to pytorch (#5410)

Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Daniel Stokes 2025-06-27 16:29:34 +12:00 committed by GitHub
parent ef43b95aa1
commit 83a1f60556
5 changed files with 165 additions and 37 deletions

View File

@@ -52,7 +52,7 @@ using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend
class FusedMoeRunner : public torch::CustomClassHolder
{
public:
template <typename Type, bool NeedQuant = false>
template <typename TypeAct, typename TypeWeight, bool NeedQuant = false>
std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(c10::ScalarType output_type)
{
switch (output_type)
@@ -66,21 +66,22 @@ public:
case c10::ScalarType::Half:
if constexpr (NeedQuant)
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, half>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half>>();
}
else
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, Type>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct>>();
}
#ifdef ENABLE_BF16
case c10::ScalarType::BFloat16:
if constexpr (NeedQuant)
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, __nv_bfloat16>>();
return std::make_unique<
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16>>();
}
else
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, Type>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, TypeAct>>();
}
#endif
default:
@@ -121,10 +122,16 @@ public:
#ifdef ENABLE_FP8
if (isFp8Quant())
{
mKernelRunner = switch_output_type<__nv_fp8_e4m3>(mOutputDtype);
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype);
}
#endif
#ifdef ENABLE_FP4
if (isWFp4AFp8Quant())
{
mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
}
if (isNvfp4Quant())
{
mInnerDimMultiplier = 16;
@@ -134,9 +141,9 @@ public:
#ifdef ENABLE_BF16
case c10::ScalarType::BFloat16:
#endif
mKernelRunner = switch_output_type<__nv_fp4_e2m1, true>(mOutputDtype);
mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, true>(mOutputDtype);
break;
default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, false>(mOutputDtype);
default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
}
}
#endif
@@ -204,11 +211,13 @@ public:
void operator=(FusedMoeRunner const&) = delete;
torch::Tensor runMoe(torch::Tensor const& input, torch::Tensor const& token_selected_experts,
torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& fc1_expert_weights,
torch::Tensor const& fc2_expert_weights, torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales,
torch::optional<torch::Tensor> input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall,
bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> profile_ids)
torch::optional<torch::Tensor> const& token_final_scales, torch::Tensor const& fc1_expert_weights,
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
torch::optional<torch::Tensor> const& fc2_expert_biases,
torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
{
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
@@ -230,6 +239,23 @@ public:
TORCH_CHECK(fc1_expert_weights.dim() == 3, "fc1_expert_weights must be 3D.");
TORCH_CHECK(fc2_expert_weights.dim() == 3, "fc2_expert_weights must be 3D.");
if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value())
{
CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype);
CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype);
TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D.");
TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D.");
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0],
"fc1_expert_weights and fc1_expert_biases must have the same number of experts.");
TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0],
"fc2_expert_weights and fc2_expert_biases must have the same number of experts.");
TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1],
"fc1_expert_biases should match fc1_expert_weights output shape.");
TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1],
"fc2_expert_biases should match fc2_expert_weights output shape.");
}
TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
"input and token_selected_experts must have the same num tokens.");
if (token_final_scales)
@@ -275,8 +301,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_weights.const_data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -286,8 +315,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_weights.const_data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -297,12 +329,13 @@ public:
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> runMoeMinLantency(torch::Tensor const& input,
torch::Tensor const& token_selected_experts, torch::optional<torch::Tensor> token_final_scales,
torch::Tensor const& fc1_expert_weights, torch::Tensor const& fc2_expert_weights,
torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales, torch::optional<torch::Tensor> input_sf,
int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
torch::optional<c10::ArrayRef<int64_t>> profile_ids)
torch::Tensor const& token_selected_experts, torch::optional<torch::Tensor> const& token_final_scales,
torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
{
std::lock_guard<std::mutex> lock(mMutex);
@@ -323,6 +356,23 @@ public:
TORCH_CHECK(fc1_expert_weights.dim() == 3, "fc1_expert_weights must be 3D.");
TORCH_CHECK(fc2_expert_weights.dim() == 3, "fc2_expert_weights must be 3D.");
if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value())
{
CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype);
CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype);
TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D.");
TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D.");
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0],
"fc1_expert_weights and fc1_expert_biases must have the same number of experts.");
TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0],
"fc2_expert_weights and fc2_expert_biases must have the same number of experts.");
TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1],
"fc1_expert_biases should match fc1_expert_weights output shape.");
TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1],
"fc2_expert_biases should match fc2_expert_weights output shape.");
}
TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
"input and token_selected_experts must have the same num tokens.");
if (token_final_scales)
@@ -378,8 +428,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_weights.const_data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -389,8 +442,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_weights.const_data_ptr(),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -406,10 +462,11 @@ public:
}
void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
torch::Tensor const& fc2_expert_weights, int64_t const top_k, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx, int64_t const profile_id,
bool const do_preparation)
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
int64_t const profile_id, bool const do_preparation)
{
std::lock_guard<std::mutex> lock(mMutex);
@@ -447,7 +504,7 @@ public:
static_cast<int>(tp_rank), static_cast<int>(ep_size), static_cast<int>(ep_rank),
static_cast<int>(cluster_size), static_cast<int>(cluster_rank));
bool const USE_BIAS = false;
bool const USE_BIAS = fc1_expert_biases.has_value() || fc2_expert_biases.has_value();
bool const USE_LORA = false;
auto activation_dtype = mUseW4A8GroupScaling ? at::ScalarType::Float8_e4m3fn : mActivationDtype;
activation_dtype = isNvfp4Quant() ? at::ScalarType::Long : activation_dtype;
@@ -592,6 +649,52 @@ private:
static_cast<float const*>(fc2_quant.data_ptr()), static_cast<float const*>(fc2_dequant.data_ptr()),
/* fp8 output quant scale */ nullptr, static_cast<float const*>(fc1_input_dequant.data_ptr()));
}
else if (isWFp4AFp8Quant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for WFP4AFP8 quantization");
TORCH_CHECK(quant_scales.value().size() == 5, "Expecting 5 quant scales for WFP4AFP8 quantization");
auto const fc1_weight_block = quant_scales.value()[0];
auto const fc1_global = quant_scales.value()[1];
auto const fc2_act_global = quant_scales.value()[2];
auto const fc2_weight_block = quant_scales.value()[3];
auto const fc2_global = quant_scales.value()[4];
// The fc1_weight_block / fc2_weight_block scales are packed into INT32 (4 FP8 block scales per INT32)
constexpr int FP8_PER_INT32 = 4;
CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int);
CHECK_INPUT(fc1_global, c10::ScalarType::Float);
CHECK_INPUT(fc2_act_global, c10::ScalarType::Float);
CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int);
CHECK_INPUT(fc2_global, c10::ScalarType::Float);
TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be 3D");
TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D");
TORCH_CHECK(fc2_act_global.dim() == 0, "fc2 act global must be a scalar tensor");
TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
&& fc1_weight_block.sizes()[1] == inter_size * 2
&& fc1_weight_block.sizes()[2] * FP8_PER_INT32
* TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
== hidden_size,
"fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
"block_scale_vector_size)");
TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size
&& fc2_weight_block.sizes()[2] * FP8_PER_INT32
* TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
== inter_size,
"fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // "
"block_scale_vector_size)");
TORCH_CHECK(fc2_global.sizes()[0] == num_experts_on_rank, "fc2 global size must be (num_experts_on_rank,)");
return kernels::QuantParams::FP8MXFP4(nullptr,
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()),
static_cast<float const*>(fc1_global.data_ptr()), static_cast<float const*>(fc2_act_global.data_ptr()),
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()),
static_cast<float const*>(fc2_global.data_ptr()));
}
else if (isNvfp4Quant())
{
TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
@@ -683,7 +786,8 @@ private:
bool isNvfp4Quant() const
{
return mWeightDtype == c10::ScalarType::Long;
return mWeightDtype == c10::ScalarType::Long
&& mActivationDtype != c10::ScalarType::Float8_e4m3fn; // FP8 activations with FP4 weights take the WFP4AFP8 path instead
}
bool isInt4Quant() const
@@ -695,6 +799,11 @@ private:
{
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
}
bool isWFp4AFp8Quant() const
{
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long;
}
};
} // namespace torch_ext
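
For reference, a minimal Python sketch (not part of this commit) of how the five quant_scales tensors for the new WFP4AFP8 path (FP8 activations, MXFP4 weights) could be laid out to satisfy the shape checks added above. The helper name, the placeholder fill values, and the MXFPX block-scale vector size of 32 are assumptions for illustration.

# Hedged sketch: quant_scales layout for the WFP4AFP8 path, mirroring the
# TORCH_CHECKs added in this commit. The block-scale vector size (32) is assumed.
import torch

def make_wfp4afp8_quant_scales(num_experts_on_rank, hidden_size, inter_size,
                               block_scale_vec_size=32):
    fp8_per_int32 = 4  # four 8-bit block scales are packed into each int32
    fc1_weight_block = torch.zeros(
        num_experts_on_rank, inter_size * 2,
        hidden_size // fp8_per_int32 // block_scale_vec_size, dtype=torch.int32)
    fc1_global = torch.ones(num_experts_on_rank, dtype=torch.float32)
    fc2_act_global = torch.tensor(1.0, dtype=torch.float32)  # 0-D scalar tensor
    fc2_weight_block = torch.zeros(
        num_experts_on_rank, hidden_size,
        inter_size // fp8_per_int32 // block_scale_vec_size, dtype=torch.int32)
    fc2_global = torch.ones(num_experts_on_rank, dtype=torch.float32)
    # Order must match the unpacking on the C++ side shown above:
    # [fc1_weight_block, fc1_global, fc2_act_global, fc2_weight_block, fc2_global]
    return [fc1_weight_block, fc1_global, fc2_act_global, fc2_weight_block, fc2_global]

The zero/one fills are placeholders only; real block scales and global scales come from the quantization step.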

View File

@@ -21,7 +21,9 @@ def trtllm_fused_moe(
selected_experts,
routing_weights,
w3_w1_stacked_weight,
None, # w3_w1_stacked_bias
w2_stacked_weight,
None, # w2_stacked_bias
x.dtype,
quant_scales,
tp_size=1,

View File

@@ -81,12 +81,13 @@ class MoERunner(TunableRunner):
tactic: int = -1,
do_preparation: bool = False,
):
x, fc1_expert_weights, fc2_expert_weights = inputs
# determine if we should use min latency mode according to the profiled seq len
x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
self.fused_moe_runner.run_gemm_profile(
x,
fc1_expert_weights,
fc1_expert_biases,
fc2_expert_weights,
fc2_expert_biases,
self.top_k,
self.tp_size,
self.tp_rank,
@@ -117,7 +118,9 @@ def fused_moe(
token_selected_experts: torch.Tensor,
token_final_scales: torch.Tensor,
fc1_expert_weights: torch.Tensor,
fc1_expert_biases: Optional[torch.Tensor],
fc2_expert_weights: torch.Tensor,
fc2_expert_biases: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor] = None,
@@ -159,7 +162,10 @@ def fused_moe(
"trtllm::fused_moe::gemm1",
[moe_runner],
MoERunner.tuning_config,
[input, fc1_expert_weights, fc2_expert_weights],
[
input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
fc2_expert_biases
],
gemm_idx=1,
)
@@ -167,7 +173,10 @@ def fused_moe(
"trtllm::fused_moe::gemm2",
[moe_runner],
MoERunner.tuning_config,
[input, fc1_expert_weights, fc2_expert_weights],
[
input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
fc2_expert_biases
],
gemm_idx=2,
)
@@ -177,7 +186,9 @@ def fused_moe(
token_selected_experts,
token_final_scales,
fc1_expert_weights,
fc1_expert_biases,
fc2_expert_weights,
fc2_expert_biases,
quant_scales,
input_sf,
tp_size,
@@ -200,7 +211,9 @@ def _(
token_selected_experts: torch.Tensor,
token_final_scales: torch.Tensor,
fc1_expert_weights: torch.Tensor,
fc1_expert_biases: Optional[torch.Tensor],
fc2_expert_weights: torch.Tensor,
fc2_expert_biases: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor] = None,
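
Below is a minimal sketch of the per-expert bias contract introduced by the new fc1_expert_biases / fc2_expert_biases arguments above. The stacked [num_experts, out_features, in_features] weight layout and the example sizes are assumptions; passing None for both biases preserves the previous behaviour.

# Hedged sketch: biases are 2-D [expert, output_feature], use the output dtype,
# and must match their weight tensor in dims 0 and 1 (see the C++ checks above).
import torch

num_experts, hidden_size, inter_size = 8, 1024, 4096
dtype = torch.bfloat16

fc1_expert_weights = torch.empty(num_experts, 2 * inter_size, hidden_size, dtype=dtype)
fc2_expert_weights = torch.empty(num_experts, hidden_size, inter_size, dtype=dtype)

fc1_expert_biases = torch.zeros(num_experts, 2 * inter_size, dtype=dtype)
fc2_expert_biases = torch.zeros(num_experts, hidden_size, dtype=dtype)

assert fc1_expert_biases.shape == fc1_expert_weights.shape[:2]
assert fc2_expert_biases.shape == fc2_expert_weights.shape[:2]
# These tensors (or None) are what callers now pass for the fc1_expert_biases /
# fc2_expert_biases parameters of fused_moe().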

View File

@@ -269,7 +269,9 @@ class CutlassFusedMoE(MoE):
token_selected_experts,
token_final_scales,
self.w3_w1_weight.view(weight_dtype),
None, # fc1_expert_biases
self.w2_weight.view(weight_dtype),
None, # fc2_expert_biases
output_dtype,
quant_scales=self.quant_scales,
input_sf=x_sf,

View File

@@ -592,7 +592,9 @@ class WideEPMoE(MoE):
token_selected_slots,
token_final_scales,
w3_w1_weight.view(weight_dtype),
None, # w3_w1_bias
w2_weight.view(weight_dtype),
None, # w2_bias
output_dtype,
quant_scales=quant_scales,
input_sf=x_sf,