[None][fix] Fix the performance issue of FP8 blockwise grouped GEMM when using attention DP (#8501)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Jinyang Yuan 2025-10-27 10:18:19 +08:00 committed by GitHub
parent e0728ba8a7
commit 0a0f93d4a8
12 changed files with 192 additions and 151 deletions
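The fix threads the number of valid (unpadded) rows through the MoE path: runMoe gains a num_valid_rows argument next to num_rows, and the grouped-GEMM helpers gain an expected_m / expected_tokens_per_expert argument derived from it, so the FP8 blockwise grouped GEMM is sized for the tokens that actually exist rather than the attention-DP padded row count. A minimal standalone sketch of that sizing computation, using the same ceiling division the diff introduces (the harness around it is illustrative only):

#include <cassert>
#include <cstdint>

// Mirrors the expression added in the MoE runner below: the expected number of
// tokens routed to each expert, computed from the valid (unpadded) row count.
int64_t expectedTokensPerExpert(int64_t num_valid_rows, int64_t experts_per_token, int64_t full_num_experts)
{
    // Ceiling division: ceil(num_valid_rows * experts_per_token / full_num_experts)
    return (num_valid_rows * experts_per_token + full_num_experts - 1) / full_num_experts;
}

int main()
{
    // Illustrative attention-DP case: 256 padded rows but only 64 valid ones,
    // with top-8 routing over 256 experts. Sizing from valid rows expects 4x
    // fewer tokens per expert than sizing from the padded count.
    assert(expectedTokensPerExpert(256, 8, 256) == 8); // padded row count
    assert(expectedTokensPerExpert(64, 8, 256) == 2);  // valid row count
    return 0;
}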

View File

@@ -994,8 +994,8 @@ public:
         mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
         mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
         ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
-        mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-        mHiddenSize, mInterSize, mNumExperts, mK,
+        mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
+        mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
         mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
         mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
         mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
@@ -1007,8 +1007,8 @@ public:
         mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
         mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
         ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
-        mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-        mHiddenSize, mInterSize, mNumExperts, mK,
+        mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
+        mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
         mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
         mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
         mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,

View File

@@ -88,8 +88,8 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::gemm(__nv_fp8
 template <typename ElementA, typename ElementB, typename ElementD>
 void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
-    void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
-    cudaStream_t stream, float const* scales_a, float const* scales_b)
+    void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t expected_m, size_t shape_n,
+    size_t shape_k, cudaStream_t stream, float const* scales_a, float const* scales_b)
 {
     constexpr bool internal_quantize_a = !std::is_same_v<ElementA, __nv_fp8_e4m3>;
     constexpr bool internal_quantize_b = !std::is_same_v<ElementB, __nv_fp8_e4m3>;
@@ -138,21 +138,21 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
     {
         fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
             reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
-            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
+            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
             max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
     }
     else if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
     {
         fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
             nullptr, fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets,
-            num_problems, expected_m_, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
+            num_problems, expected_m, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
             internal_quantize_a, internal_quantize_b);
     }
     else if constexpr (std::is_same_v<ElementA, __nv_fp8_e4m3> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
     {
         fp8_grouped_gemm_run(nullptr, fp8_mat_a, per_token_per_128c_scales,
             reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
-            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
+            reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
             max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
     }
     else
@@ -164,6 +164,15 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
 #endif
 }
+
+template <typename ElementA, typename ElementB, typename ElementD>
+void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
+    void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
+    cudaStream_t stream, float const* scales_a, float const* scales_b)
+{
+    moeGemm(mat_d, mat_a, mat_b, problem_m_offsets, num_problems, expected_m_, shape_n, shape_k, stream, scales_a,
+        scales_b);
+}
 
 template <typename ElementA, typename ElementB, typename ElementD>
 void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::strideBatchGemm(__nv_bfloat16* mat_d, int ld_d,
     int stride_d, __nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b, int stride_b,
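The hunk above keeps the legacy moeGemm signature as a thin forwarder to the extended one, so existing callers keep compiling while new callers pass expected_m explicitly. A condensed, self-contained sketch of that delegation pattern (the class below is a hypothetical stand-in, not the real runner):

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for CutlassFp8BlockScaleGemmRunner, showing only the
// overload-forwarding shape used in the diff above.
class GroupedGemmRunner
{
public:
    // New overload: the caller supplies the expected per-problem M dimension.
    void moeGemm(std::size_t num_problems, std::size_t expected_m, std::size_t shape_n, std::size_t shape_k)
    {
        std::printf("grouped GEMM: %zu problems, expected_m=%zu, n=%zu, k=%zu\n", num_problems, expected_m, shape_n,
            shape_k);
    }

    // Legacy overload: unchanged signature, forwards the stored default so
    // existing call sites behave as before.
    void moeGemm(std::size_t num_problems, std::size_t shape_n, std::size_t shape_k)
    {
        moeGemm(num_problems, expected_m_, shape_n, shape_k);
    }

private:
    std::size_t expected_m_ = 128; // configured elsewhere, e.g. by a workspace-setup call
};

int main()
{
    GroupedGemmRunner runner;
    runner.moeGemm(8, 7168, 2048);    // legacy path: uses the stored expected_m_
    runner.moeGemm(8, 2, 7168, 2048); // new path: expected_m derived from valid rows
    return 0;
}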

View File

@@ -40,6 +40,11 @@ public:
         cudaStream_t stream)
         = 0;
 
+    virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
+        size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
+        float const* scales_a = nullptr, float const* scales_b = nullptr)
+        = 0;
+
     virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
         size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
         float const* scales_b = nullptr)
@@ -95,6 +100,10 @@ public:
         int ld_d, int shape_m, int shape_n, int shape_k, float const* scales_a, float const* scales_b,
         cudaStream_t stream) override;
 
+    void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
+        size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
+        float const* scales_a = nullptr, float const* scales_b = nullptr) override;
+
     void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
         size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
         float const* scales_b = nullptr) override;

View File

@@ -459,8 +459,8 @@ public:
     virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
         int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
         void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
-        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
-        int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows,
+        int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
         int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params,
         bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params,
@@ -474,11 +474,11 @@ public:
         int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
         float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
         TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
-        int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-        int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
-        bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
-        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids)
+        int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
+        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
+        ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
+        bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids)
         = 0;
 
     virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -489,10 +489,10 @@ public:
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
-        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
-        bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
-        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
+        int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token,
+        float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale,
+        cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
         int* active_expert_global_ids)
         = 0;
@@ -618,8 +618,8 @@ public:
     void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
         int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
         void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
-        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
-        int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows,
+        int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
         int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params,
         bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params,
@@ -639,10 +639,11 @@ public:
         ScaleBiasType const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant,
         TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
         TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
-        int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-        int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
-        bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
+        int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
+        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
+        ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
+        cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
+        int* num_active_experts_per, int* active_expert_global_ids);
 
     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -654,11 +655,12 @@ public:
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
-        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
-        bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids);
+        int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
+        int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token,
+        float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream,
+        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+        int* active_expert_global_ids);
 
     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -667,20 +669,20 @@ public:
         int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
         float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
         TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
-        int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-        int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
-        bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
-        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids) override
+        int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
+        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
+        ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
+        bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
             static_cast<T*>(output), intermediate_result, expert_first_token_offset, tma_ws_input_template,
             static_cast<WeightType const*>(fc1_expert_weights), static_cast<ScaleBiasType const*>(fc1_expert_biases),
             num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
-            fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
-            num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids);
+            fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert,
+            hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array,
+            bias_is_broadcast, stream, config, min_latency_mode, num_active_experts_per, active_expert_global_ids);
     }
 
     void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -691,10 +693,10 @@ public:
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
-        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
-        bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
-        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
+        int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token,
+        float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale,
+        cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
         int* active_expert_global_ids) override
     {
@@ -705,9 +707,9 @@ public:
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
             token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
             permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-            hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token,
-            alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids);
+            expected_tokens_per_expert, hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node,
+            experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, enable_alltoall,
+            config, min_latency_mode, num_active_experts_per, active_expert_global_ids);
     }
 
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -844,8 +846,9 @@ private:
         void* const intermediate_result, int64_t const* const expert_first_token_offset,
         WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases,
         float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream);
+        int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params,
+        cudaStream_t stream);
 
     static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
         OutputType* const final_output, int64_t const* const expert_first_token_offset,
@@ -853,9 +856,10 @@ private:
         float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
-        int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream);
+        int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
+        int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
+        MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
+        cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,

View File

@@ -2848,9 +2848,9 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
     DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, void* const gemm_output,
     int64_t const* const expert_first_token_offset, WeightType const* const fc1_expert_weights,
     ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows,
-    int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-    int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params,
-    cudaStream_t stream)
+    int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
+    int64_t const inter_size, int const num_experts_per_node, ActivationParams fc1_activation_type,
+    QuantParams& quant_params, cudaStream_t stream)
 {
     bool const is_gated_activation = isGatedActivation(fc1_activation_type);
@@ -2859,7 +2859,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
     // NOTE: we assume gemm_runner.configureWorkspace has already been called.
     gemm_runner.moeGemm(gemm_output, input, fc1_expert_weights, expert_first_token_offset, num_experts_per_node,
-        shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc1_scales_ptrs);
+        expected_tokens_per_expert, shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc1_scales_ptrs);
     sync_check_cuda_error(stream);
 
     constexpr bool bias_is_broadcast = true;
@@ -2879,16 +2879,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
     float const* const unpermuted_final_scales, int const* const unpermuted_row_to_permuted_row,
     int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
     int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-    int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
-    int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,
-    bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream)
+    int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
+    int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
+    MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream)
 {
     int shape_n = hidden_size;
     int shape_k = inter_size;
 
     // NOTE: we assume gemm_runner.configureWorkspace has already been called.
     gemm_runner.moeGemm(gemm_output, input, fc2_expert_weights, expert_first_token_offset, num_experts_per_node,
-        shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc2_scales_ptrs);
+        expected_tokens_per_expert, shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc2_scales_ptrs);
     sync_check_cuda_error(stream);
@@ -2948,18 +2948,20 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     ScaleBiasType const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows,
-    int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-    int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
-    bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-    int* num_active_experts_per, int* active_expert_global_ids)
+    int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
+    int64_t const inter_size, int const num_experts_per_node, ActivationParams fc1_activation_type,
+    float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream,
+    cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+    int* active_expert_global_ids)
 {
     if (fp8_blockscale_gemm_runner)
     {
         TLLM_CHECK(!min_latency_mode);
         Self::BlockScaleFC1(*fp8_blockscale_gemm_runner, input, output, intermediate_result, expert_first_token_offset,
-            fc1_expert_weights, fc1_expert_biases, fc2_fp8_quant, num_rows, expanded_num_rows, hidden_size, inter_size,
-            num_experts_per_node, fc1_activation_type, quant_params, stream);
+            fc1_expert_weights, fc1_expert_biases, fc2_fp8_quant, num_rows, expanded_num_rows,
+            expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, fc1_activation_type,
+            quant_params, stream);
         return;
     }
@@ -3131,11 +3133,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     QuantParams quant_params, float const* const unpermuted_final_scales, float const* const permuted_final_scales,
     int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row,
     int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows,
-    int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const unpadded_hidden_size,
-    int64_t const inter_size, int const num_experts_per_node, int64_t const k, float const** alpha_scale_ptr_array,
-    bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-    bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-    int* num_active_experts_per, int* active_expert_global_ids)
+    int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
+    int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
+    float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream,
+    MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config,
+    bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids)
 {
     int64_t const* total_tokens_including_expert = expert_first_token_offset + 1;
@@ -3152,8 +3154,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
         Self::BlockScaleFC2(*fp8_blockscale_gemm_runner, input, gemm_output, final_output, expert_first_token_offset,
             fc2_expert_weights, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row,
             permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-            hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, k, parallelism_config, enable_alltoall,
-            quant_params, stream);
+            expected_tokens_per_expert, hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, k,
+            parallelism_config, enable_alltoall, quant_params, stream);
         return;
     }
@@ -3445,8 +3447,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf,
     int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void,
     void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, void const* fc2_expert_weights_void,
-    void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
-    int64_t const unpadded_hidden_size, int64_t const inter_size, int const full_num_experts,
+    void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows,
+    int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, int const full_num_experts,
     int const experts_per_token, char* workspace_ptr, void* final_output_void, int* unpermuted_row_to_permuted_row,
     MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params,
     bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params,
@@ -3597,6 +3599,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
     int64_t const* num_valid_tokens_ptr = needs_num_valid ? expert_first_token_offset_ + num_experts_per_node : nullptr;
     auto expanded_num_rows = num_rows * experts_per_token;
+    auto expected_tokens_per_expert = (num_valid_rows * experts_per_token + full_num_experts - 1) / full_num_experts;
 
     if (min_latency_mode)
     {
@@ -3622,9 +3625,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
             fc1_result_, glu_inter_result_, expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights,
             fc1_expert_biases, num_valid_tokens_ptr, fc1_int_scales, fc1_fp8_dequant,
             use_wfp4afp8 ? fc2_wfp4afp8_quant_scale : fc2_fp8_quant, input_sf /*input fp4 scale or expanded fp4 scale*/,
-            fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
-            num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_,
-            true, min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids);
+            fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert, hidden_size,
+            inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream,
+            *gemm1_config_, true, min_latency_params.num_active_experts_per_node,
+            min_latency_params.active_expert_global_ids);
         sync_check_cuda_error(stream);
 
         auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
@@ -3633,10 +3637,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
             expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
             fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
             permuted_token_final_scales_, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row_,
-            token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size,
-            unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array_fc2_,
-            use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall, *gemm2_config_, true,
-            min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids);
+            token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, expected_tokens_per_expert,
+            hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token,
+            alpha_scale_ptr_array_fc2_, use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall,
+            *gemm2_config_, true, min_latency_params.num_active_experts_per_node,
+            min_latency_params.active_expert_global_ids);
         sync_check_cuda_error(stream);
     }
     else
@@ -3722,9 +3727,9 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
         Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, fc1_result_, glu_inter_result_,
             expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr,
             fc1_int_scales, fc1_fp8_dequant, use_wfp4afp8 ? fc2_wfp4afp8_quant_scale : fc2_fp8_quant,
-            fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
-            num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_,
-            false, nullptr, nullptr);
+            fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows,
+            expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, fc1_activation_type,
+            alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_, false, nullptr, nullptr);
         sync_check_cuda_error(stream);
 
         if (use_lora)
@@ -3742,10 +3747,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
             expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
             fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
             permuted_token_final_scales_, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row_,
-            token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size,
-            unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array_fc2_,
-            use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall, *gemm2_config_, false, nullptr,
-            nullptr);
+            token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, expected_tokens_per_expert,
+            hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token,
+            alpha_scale_ptr_array_fc2_, use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall,
+            *gemm2_config_, false, nullptr, nullptr);
         sync_check_cuda_error(stream);
     }
 }
@@ -4673,6 +4678,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
         fp4_act_scale_flat,   //
         mQuantParams,         //
         original_num_tokens,  //
+        original_num_tokens,  //
         expanded_num_tokens,  //
         mExpertHiddenSize,    //
         mExpertInterSize,     //
@@ -4708,6 +4714,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
         token_selected_experts,                          //
         expert_first_token_offset + mNumExpertsPerNode,  //
         original_num_tokens,                             //
+        original_num_tokens,                             //
         expanded_num_tokens,                             //
         mExpertHiddenSize,                               //
         mExpertUnpaddedHiddenSize,                       //
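In the profiler hunks just above, original_num_tokens is passed for both row arguments, since profiling inputs carry no padding. Under attention DP the two counts diverge: the MoE can see a padded row count while only this rank's real tokens need work. A hedged sketch of how the counts relate (the padding scheme and variable names here are illustrative assumptions, not taken from the diff):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    // Illustrative attention-DP batch: ranks hold different token counts but
    // pad to a common size so tensor shapes match across ranks.
    std::vector<int64_t> const tokens_per_rank{64, 48, 96, 32};
    int64_t const num_valid_rows = tokens_per_rank[0]; // this rank's real tokens
    int64_t const num_rows = *std::max_element(tokens_per_rank.begin(), tokens_per_rank.end()); // padded size

    // Same ceiling division the diff adds: size the grouped GEMM from the
    // valid rows, not from the padded count.
    int64_t const experts_per_token = 8;
    int64_t const full_num_experts = 256;
    int64_t const expected_tokens_per_expert
        = (num_valid_rows * experts_per_token + full_num_experts - 1) / full_num_experts;

    assert(num_rows == 96);
    assert(expected_tokens_per_expert == 2); // vs 3 if sized from the padded 96 rows
    return 0;
}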

View File

@ -425,9 +425,9 @@ public:
virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows, int64_t const hidden_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora, void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
= 0; = 0;
@ -439,11 +439,11 @@ public:
int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant, int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
int* active_expert_global_ids, int start_expert) bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
= 0; = 0;
virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output, virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@ -454,11 +454,11 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row, float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row, int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, MOEParallelismConfig parallelism_config, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* active_expert_global_ids, int start_expert) int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
= 0; = 0;
virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@ -573,9 +573,9 @@ public:
void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows, int64_t const hidden_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora, void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
@@ -593,10 +593,11 @@ public:
     ScaleBiasType const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant,
     TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
-    int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-    int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
-    bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-    bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
+    int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
+    int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
+    ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
+    cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
+    int* num_active_experts_per, int* active_expert_global_ids, int start_expert);

 static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
     DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -608,10 +609,11 @@ public:
     float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
     int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
     int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-    int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-    int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-    cudaStream_t stream, MOEParallelismConfig parallelism_config, cutlass_extensions::CutlassGemmConfig config,
-    bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
+    int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
+    int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+    bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
+    cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+    int* active_expert_global_ids, int start_expert);

 // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
 void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -620,20 +622,21 @@ public:
     int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
     float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
     TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
-    int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
-    int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
-    bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
-    cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-    int* active_expert_global_ids, int start_expert) override
+    int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
+    int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
+    ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
+    bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
+    bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
 {
     auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
     return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
         static_cast<T*>(output), intermediate_result, expert_first_token_offset, tma_ws_input_template,
         static_cast<WeightType const*>(fc1_expert_weights), static_cast<ScaleBiasType const*>(fc1_expert_biases),
         num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
-        fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
-        num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-        min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
+        fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert,
+        hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array,
+        bias_is_broadcast, stream, config, min_latency_mode, num_active_experts_per, active_expert_global_ids,
+        start_expert);
 }
 void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -644,11 +647,11 @@ public:
     float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
     int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
     int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-    int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-    int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-    bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-    cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-    int* active_expert_global_ids, int start_expert) override
+    int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
+    int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+    bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
+    MOEParallelismConfig parallelism_config, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
+    int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
 {
     auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
     return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
@@ -657,9 +660,9 @@ public:
         static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
         token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
         permuted_row_to_unpermuted_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-        hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
-        stream, parallelism_config, config, min_latency_mode, num_active_experts_per, active_expert_global_ids,
-        start_expert);
+        expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, experts_per_token,
+        alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, config, min_latency_mode,
+        num_active_experts_per, active_expert_global_ids, start_expert);
 }
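Both `gemm1` and `gemm2` now thread an `expected_tokens_per_expert` hint down to the FP8 block-scale grouped GEMM. The exact per-expert row counts sit in device memory (behind `expert_first_token_offset`), so a host-side expectation allows launch-time decisions without synchronizing on the device. A hedged sketch of one way such a hint could steer tile selection; `selectTileM` is hypothetical and the real selection logic is not part of this diff:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical heuristic: pick a tile M from the host-visible hint instead
    // of reading per-expert row counts back from the device.
    static int selectTileM(int64_t expectedTokensPerExpert)
    {
        if (expectedTokensPerExpert <= 16)
            return 16;
        if (expectedTokensPerExpert <= 64)
            return 64;
        return 128;
    }

    int main()
    {
        // E.g. 73 valid tokens, top-8 routing, 32 experts on this node.
        int64_t const expandedRows = 73 * 8;
        int64_t const numExperts = 32;
        int64_t const expected = (expandedRows + numExperts - 1) / numExperts; // 19
        std::printf("expected M %lld -> tile M %d\n",
            static_cast<long long>(expected), selectTileM(expected));
        return 0;
    }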
 virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -790,16 +793,17 @@ private:
     void* const intermediate_result, int64_t const* const expert_first_token_offset,
     WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases,
     float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows,
-    int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-    ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream);
+    int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
+    int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params,
+    cudaStream_t stream);

 static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
     OutputType* const final_output, int64_t const* const expert_first_token_offset,
     WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
     float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
     int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows,
-    int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const unpadded_hidden_size,
-    int64_t const inter_size, int const num_experts_per_node, int64_t const k,
+    int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
+    int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
     MOEParallelismConfig parallelism_config, QuantParams& quant_params, cudaStream_t stream);

 T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,

View File

@@ -964,7 +964,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
     hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
     inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr,
     ActivationParams(mActivationType), inputs[getExpertWeights2Index()],
-    hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, mExpertHiddenSize,
+    hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, num_tokens, mExpertHiddenSize,
     mExpertHiddenSize /*TRT does not support padding, safe to assume padded/unpadded hidden sizes are the same*/,
     mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace),
     // Outputs
@@ -977,7 +977,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
     hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
     inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr,
     ActivationParams(mActivationType), inputs[getExpertWeights2Index()],
-    hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, mExpertHiddenSize,
+    hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, num_tokens, mExpertHiddenSize,
     mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace),
     // Outputs
     outputs[getOutputTensorIndex()], static_cast<int*>(workspace.src_to_dest_map), mParallelismConfig, hasLora(),

View File

@@ -230,6 +230,7 @@ torch::Tensor fp8_block_scaling_moe_gemm_hopper(torch::Tensor const& mat1, torch
     auto const num_problems = mat2.sizes()[0];
     auto const n = mat2.sizes()[1];
     auto const k = mat2.sizes()[2];
+    auto const expected_m = (m_total + num_problems - 1) / num_problems;
     TORCH_CHECK(k % 16 == 0, "K must be a multiple of 16, (K=", k, ")");
     TORCH_CHECK(n % 16 == 0, "N must be a multiple of 16, (N=", n, ")");
@@ -247,7 +248,8 @@ torch::Tensor fp8_block_scaling_moe_gemm_hopper(torch::Tensor const& mat1, torch
     void* workspace_ptr = workspace.data_ptr();
     gemm_runner->configureWorkspace(static_cast<char*>(workspace_ptr));
     gemm_runner->moeGemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(),
-        static_cast<int64_t*>(token_offset.data_ptr()), num_problems, n, k, stream, mat1ScalePtr, mat2ScalePtr);
+        static_cast<int64_t*>(token_offset.data_ptr()), num_problems, expected_m, n, k, stream, mat1ScalePtr,
+        mat2ScalePtr);
     return out;
 }
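The new `expected_m` is simply the ceiling of the average problem size: the total row count across all grouped-GEMM problems divided by the number of problems, rounded up. The integer idiom `(a + b - 1) / b` used above is the standard ceiling division for non-negative operands, as this self-contained check (with illustrative values) shows:

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    int main()
    {
        int64_t const m_total = 584;     // rows summed over all problems (example value)
        int64_t const num_problems = 32; // one GEMM problem per expert
        int64_t const expected_m = (m_total + num_problems - 1) / num_problems; // 19
        assert(expected_m
            == static_cast<int64_t>(std::ceil(
                static_cast<double>(m_total) / static_cast<double>(num_problems))));
        return 0;
    }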

View File

@@ -259,7 +259,7 @@ public:
     torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
     int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
     bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-    torch::optional<int64_t> const& unpadded_hidden_size)
+    torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens)
 {
     std::lock_guard<std::mutex> lock(mMutex);
     // Free the profile workspace to save memory
@@ -428,10 +428,11 @@ public:
         fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
         fc2_expert_weights.const_data_ptr(),
         fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
-        num_rows, hidden_size, unpadded_hidden_size_val, inter_size, num_experts_total,
-        static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace.data_ptr()),
-        output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
-        false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
+        num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size,
+        unpadded_hidden_size_val, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+        static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
+        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
+        mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
 #else
     mKernelRunner->runMoe(input.const_data_ptr(),
         input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
@@ -442,7 +443,8 @@ public:
         fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
         fc2_expert_weights.const_data_ptr(),
         fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
-        num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+        num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size, inter_size,
+        num_experts_total, static_cast<int>(experts_per_token),
         static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
         mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -461,7 +463,7 @@ public:
     torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
     int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
     bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
-    torch::optional<int64_t> const& unpadded_hidden_size)
+    torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens)
 {
     std::lock_guard<std::mutex> lock(mMutex);
@@ -588,10 +590,11 @@ public:
         fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
         fc2_expert_weights.const_data_ptr(),
         fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
-        num_rows, hidden_size, unpadded_hidden_size_val, inter_size, num_experts_total,
-        static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace.data_ptr()),
-        output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
-        false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
+        num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size,
+        unpadded_hidden_size_val, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+        static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
+        static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
+        mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
 #else
     mKernelRunner->runMoe(input.const_data_ptr(),
         input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
@@ -602,7 +605,8 @@ public:
         fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
         fc2_expert_weights.const_data_ptr(),
         fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
-        num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+        num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size, inter_size,
+        num_experts_total, static_cast<int>(experts_per_token),
         static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
         static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
         mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
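The new `num_valid_tokens` argument is optional on purpose: callers that do not run attention DP pass nothing and the runner falls back to `num_rows`, so existing behavior is unchanged. The `has_value() ? value() : num_rows` expression above is the same fallback that `std::optional::value_or` expresses (`torch::optional` mirrors `std::optional` here); a small sketch with illustrative values:

    #include <cstdint>
    #include <cstdio>
    #include <optional>

    int main()
    {
        int64_t const num_rows = 256;
        std::optional<int64_t> num_valid_tokens; // unset: non-attention-DP caller

        // Equivalent to: num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows
        std::printf("without hint: %lld\n",
            static_cast<long long>(num_valid_tokens.value_or(num_rows)));

        num_valid_tokens = 73; // attention-DP caller supplies the real count
        std::printf("with hint: %lld\n",
            static_cast<long long>(num_valid_tokens.value_or(num_rows)));
        return 0;
    }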

View File

@@ -1279,16 +1279,16 @@ protected:
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
     mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr,
         ActivationParams(mActType, mSwigluAlpha, mSwigluBeta, mSwigluLimit), weight2_ptr, bias2_ptr, quant_params,
-        mTotalTokens, mHiddenSize, mUnpaddedHiddenSize > 0 ? mUnpaddedHiddenSize : mHiddenSize,
+        mTotalTokens, mTotalTokens, mHiddenSize, mUnpaddedHiddenSize > 0 ? mUnpaddedHiddenSize : mHiddenSize,
         mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
         parallelism_config, enable_alltoall, mUseLora, lora_params, useFp8BlockScales, minLatencyMode,
         min_latency_params, stream);
 #else
     mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr,
         ActivationParams(mActType, mSwigluAlpha, mSwigluBeta, mSwigluLimit), weight2_ptr, bias2_ptr, quant_params,
-        mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace,
-        mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params, useFp8BlockScales,
-        minLatencyMode, min_latency_params, stream);
+        mTotalTokens, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK,
+        mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params,
+        useFp8BlockScales, minLatencyMode, min_latency_params, stream);
 #endif
     check_cuda_error(cudaStreamSynchronize(stream));

View File

@@ -240,6 +240,7 @@ def fused_moe(
         min_latency_mode,
         [gemm_tactic_1, gemm_tactic_2],
         unpadded_hidden_size,
+        tuner_num_tokens,
     )
     return output if min_latency_mode else [output]

View File

@@ -217,6 +217,7 @@ class CutlassMoEOp(MoEOp):
             min_latency_mode,
             self.gemm_tactics,
             unpadded_hidden_size,
+            tuner_num_tokens,
         )
         # Return output based on latency mode