[None][fix] Fix the performance issue of FP8 blockwise grouped GEMM when using attention DP (#8501)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Jinyang Yuan, 2025-10-27 10:18:19 +08:00 (committed by GitHub)
parent e0728ba8a7
commit 0a0f93d4a8
12 changed files with 192 additions and 151 deletions
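In short: with attention data parallelism every rank pads its local token count up to the per-iteration maximum, and the FP8 blockwise grouped GEMM previously derived its expected per-expert workload from state configured for that padded batch. This change threads the unpadded count (num_valid_rows / num_valid_tokens) from the public runMoe API down to moeGemm as an explicit expected_m hint, so the grouped GEMM is scheduled for the tokens that actually exist. A minimal sketch of the hint, mirroring the ceil division added later in this diff (the helper name is illustrative, not from the source):

    // Average number of tokens each expert is expected to process.
    // num_valid_rows excludes attention-DP padding; experts_per_token is the top-k.
    int64_t expectedTokensPerExpert(int64_t num_valid_rows, int64_t experts_per_token, int64_t num_experts)
    {
        // Ceil division, as in the new expected_tokens_per_expert line below.
        return (num_valid_rows * experts_per_token + num_experts - 1) / num_experts;
    }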

View File

@@ -994,8 +994,8 @@ public:
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
mHiddenSize, mInterSize, mNumExperts, mK,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
@@ -1007,8 +1007,8 @@ public:
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
mHiddenSize, mInterSize, mNumExperts, mK,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
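In this test harness there is no attention-DP padding, so mTotalTokens is passed for both the padded row count (num_rows) and the new valid-row count (num_valid_rows).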

View File

@@ -88,8 +88,8 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::gemm(__nv_fp8
template <typename ElementA, typename ElementB, typename ElementD>
void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
cudaStream_t stream, float const* scales_a, float const* scales_b)
void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t expected_m, size_t shape_n,
size_t shape_k, cudaStream_t stream, float const* scales_a, float const* scales_b)
{
constexpr bool internal_quantize_a = !std::is_same_v<ElementA, __nv_fp8_e4m3>;
constexpr bool internal_quantize_b = !std::is_same_v<ElementB, __nv_fp8_e4m3>;
@@ -138,21 +138,21 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
{
fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
}
else if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
{
fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
nullptr, fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets,
num_problems, expected_m_, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
num_problems, expected_m, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
internal_quantize_a, internal_quantize_b);
}
else if constexpr (std::is_same_v<ElementA, __nv_fp8_e4m3> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
{
fp8_grouped_gemm_run(nullptr, fp8_mat_a, per_token_per_128c_scales,
reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
}
else
@@ -164,6 +164,15 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
#endif
}
template <typename ElementA, typename ElementB, typename ElementD>
void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
cudaStream_t stream, float const* scales_a, float const* scales_b)
{
moeGemm(mat_d, mat_a, mat_b, problem_m_offsets, num_problems, expected_m_, shape_n, shape_k, stream, scales_a,
scales_b);
}
template <typename ElementA, typename ElementB, typename ElementD>
void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::strideBatchGemm(__nv_bfloat16* mat_d, int ld_d,
int stride_d, __nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b, int stride_b,
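Note the compatibility pattern above: the original nine-parameter moeGemm is kept as a thin wrapper that forwards to the new overload using the runner's stored expected_m_, so existing callers compile unchanged while new callers can pass a per-call hint. Illustrative call shapes (variable names are placeholders; the trailing scale pointers default to nullptr):

    // Legacy form: falls back to the expected_m_ configured with the workspace.
    runner.moeGemm(mat_d, mat_a, mat_b, problem_m_offsets, num_problems, shape_n, shape_k, stream);
    // New form: the caller supplies the expected per-expert M explicitly.
    runner.moeGemm(mat_d, mat_a, mat_b, problem_m_offsets, num_problems, expected_m, shape_n, shape_k, stream);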

View File

@@ -40,6 +40,11 @@ public:
cudaStream_t stream)
= 0;
virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
float const* scales_a = nullptr, float const* scales_b = nullptr)
= 0;
virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
float const* scales_b = nullptr)
@@ -95,6 +100,10 @@ public:
int ld_d, int shape_m, int shape_n, int shape_k, float const* scales_a, float const* scales_b,
cudaStream_t stream) override;
void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
float const* scales_a = nullptr, float const* scales_b = nullptr) override;
void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
float const* scales_b = nullptr) override;

View File

@@ -459,8 +459,8 @@ public:
virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params,
bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params,
@@ -474,11 +474,11 @@ public:
int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids)
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids)
= 0;
virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -489,10 +489,10 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
MOEParallelismConfig parallelism_config, bool const enable_alltoall,
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token,
float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale,
cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids)
= 0;
@@ -618,8 +618,8 @@ public:
void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts,
int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params,
bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params,
@@ -639,10 +639,11 @@ public:
ScaleBiasType const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids);
static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -654,11 +655,12 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids);
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token,
float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream,
MOEParallelismConfig parallelism_config, bool const enable_alltoall,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids);
// Overrides to allow us to forward on to the internal functions with the pointers using the correct type
void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -667,20 +669,20 @@ public:
int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids) override
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) override
{
auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
static_cast<T*>(output), intermediate_result, expert_first_token_offset, tma_ws_input_template,
static_cast<WeightType const*>(fc1_expert_weights), static_cast<ScaleBiasType const*>(fc1_expert_biases),
num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
min_latency_mode, num_active_experts_per, active_expert_global_ids);
fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert,
hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array,
bias_is_broadcast, stream, config, min_latency_mode, num_active_experts_per, active_expert_global_ids);
}
void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -691,10 +693,10 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
MOEParallelismConfig parallelism_config, bool const enable_alltoall,
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int const num_experts_per_node, int64_t const experts_per_token,
float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale,
cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids) override
{
@@ -705,9 +707,9 @@ public:
static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token,
alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config,
min_latency_mode, num_active_experts_per, active_expert_global_ids);
expected_tokens_per_expert, hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node,
experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, enable_alltoall,
config, min_latency_mode, num_active_experts_per, active_expert_global_ids);
}
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -844,8 +846,9 @@ private:
void* const intermediate_result, int64_t const* const expert_first_token_offset,
WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases,
float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream);
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params,
cudaStream_t stream);
static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
OutputType* const final_output, int64_t const* const expert_first_token_offset,
@@ -853,9 +856,10 @@ private:
float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream);
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
cudaStream_t stream);
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,

View File

@@ -2848,9 +2848,9 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, void* const gemm_output,
int64_t const* const expert_first_token_offset, WeightType const* const fc1_expert_weights,
ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows,
int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params,
cudaStream_t stream)
int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
int64_t const inter_size, int const num_experts_per_node, ActivationParams fc1_activation_type,
QuantParams& quant_params, cudaStream_t stream)
{
bool const is_gated_activation = isGatedActivation(fc1_activation_type);
@@ -2859,7 +2859,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
// NOTE: we assume gemm_runner.configureWorkspace has already been called.
gemm_runner.moeGemm(gemm_output, input, fc1_expert_weights, expert_first_token_offset, num_experts_per_node,
shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc1_scales_ptrs);
expected_tokens_per_expert, shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc1_scales_ptrs);
sync_check_cuda_error(stream);
constexpr bool bias_is_broadcast = true;
@@ -2879,16 +2879,16 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, ScaleBiasType, Ena
float const* const unpermuted_final_scales, int const* const unpermuted_row_to_permuted_row,
int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size,
int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream)
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream)
{
int shape_n = hidden_size;
int shape_k = inter_size;
// NOTE: we assume gemm_runner.configureWorkspace has already been called.
gemm_runner.moeGemm(gemm_output, input, fc2_expert_weights, expert_first_token_offset, num_experts_per_node,
shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc2_scales_ptrs);
expected_tokens_per_expert, shape_n, shape_k, stream, nullptr, quant_params.fp8_block_scaling.fc2_scales_ptrs);
sync_check_cuda_error(stream);
@@ -2948,18 +2948,20 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
ScaleBiasType const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows,
int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids)
int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
int64_t const inter_size, int const num_experts_per_node, ActivationParams fc1_activation_type,
float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids)
{
if (fp8_blockscale_gemm_runner)
{
TLLM_CHECK(!min_latency_mode);
Self::BlockScaleFC1(*fp8_blockscale_gemm_runner, input, output, intermediate_result, expert_first_token_offset,
fc1_expert_weights, fc1_expert_biases, fc2_fp8_quant, num_rows, expanded_num_rows, hidden_size, inter_size,
num_experts_per_node, fc1_activation_type, quant_params, stream);
fc1_expert_weights, fc1_expert_biases, fc2_fp8_quant, num_rows, expanded_num_rows,
expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, fc1_activation_type,
quant_params, stream);
return;
}
@@ -3131,11 +3133,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
QuantParams quant_params, float const* const unpermuted_final_scales, float const* const permuted_final_scales,
int const* const unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row,
int const* const token_selected_experts, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows,
int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int const num_experts_per_node, int64_t const k, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids)
int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids)
{
int64_t const* total_tokens_including_expert = expert_first_token_offset + 1;
@@ -3152,8 +3154,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
Self::BlockScaleFC2(*fp8_blockscale_gemm_runner, input, gemm_output, final_output, expert_first_token_offset,
fc2_expert_weights, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, k, parallelism_config, enable_alltoall,
quant_params, stream);
expected_tokens_per_expert, hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, k,
parallelism_config, enable_alltoall, quant_params, stream);
return;
}
@@ -3445,8 +3447,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void,
void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, void const* fc2_expert_weights_void,
void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
int64_t const unpadded_hidden_size, int64_t const inter_size, int const full_num_experts,
void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows,
int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, int const full_num_experts,
int const experts_per_token, char* workspace_ptr, void* final_output_void, int* unpermuted_row_to_permuted_row,
MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params,
bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params,
@@ -3597,6 +3599,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
int64_t const* num_valid_tokens_ptr = needs_num_valid ? expert_first_token_offset_ + num_experts_per_node : nullptr;
auto expanded_num_rows = num_rows * experts_per_token;
auto expected_tokens_per_expert = (num_valid_rows * experts_per_token + full_num_experts - 1) / full_num_experts;
if (min_latency_mode)
{
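The line computing expected_tokens_per_expert above is the heart of the fix: num_valid_rows counts only real tokens, while num_rows may include attention-DP padding. A worked example with illustrative numbers (not from the source): 96 valid tokens with top-8 routing across 256 experts gives ceil(96 * 8 / 256) = 3 expected tokens per expert; sizing from a batch padded to 512 rows would instead give ceil(512 * 8 / 256) = 16 and tune the grouped GEMM for far more work than it actually has.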
@@ -3622,9 +3625,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
fc1_result_, glu_inter_result_, expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights,
fc1_expert_biases, num_valid_tokens_ptr, fc1_int_scales, fc1_fp8_dequant,
use_wfp4afp8 ? fc2_wfp4afp8_quant_scale : fc2_fp8_quant, input_sf /*input fp4 scale or expanded fp4 scale*/,
fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_,
true, min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids);
fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert, hidden_size,
inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream,
*gemm1_config_, true, min_latency_params.num_active_experts_per_node,
min_latency_params.active_expert_global_ids);
sync_check_cuda_error(stream);
auto gemm2_input = applyPrequantScale(smoothed_act_, fc1_result_, quant_params.groupwise.fc2.act_scales,
@@ -3633,10 +3637,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
permuted_token_final_scales_, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row_,
token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size,
unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array_fc2_,
use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall, *gemm2_config_, true,
min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids);
token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, expected_tokens_per_expert,
hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token,
alpha_scale_ptr_array_fc2_, use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall,
*gemm2_config_, true, min_latency_params.num_active_experts_per_node,
min_latency_params.active_expert_global_ids);
sync_check_cuda_error(stream);
}
else
@@ -3722,9 +3727,9 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
Self::gemm1(moe_gemm_runner_, blockscale_gemm_runner, gemm1_input, fc1_result_, glu_inter_result_,
expert_first_token_offset_, gemm1_tma_ws_input, fc1_expert_weights, fc1_expert_biases, num_valid_tokens_ptr,
fc1_int_scales, fc1_fp8_dequant, use_wfp4afp8 ? fc2_wfp4afp8_quant_scale : fc2_fp8_quant,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_,
false, nullptr, nullptr);
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, num_rows, expanded_num_rows,
expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, fc1_activation_type,
alpha_scale_ptr_array_fc1_, !use_lora, stream, *gemm1_config_, false, nullptr, nullptr);
sync_check_cuda_error(stream);
if (use_lora)
@@ -3742,10 +3747,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
expert_first_token_offset_, gemm2_tma_ws_input, fc2_expert_weights, fc2_expert_biases, fc2_int_scales,
fc2_fp8_dequant, fc2_fp4_act_scale_, quant_params, token_topk_unpermuted_scales,
permuted_token_final_scales_, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row_,
token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size,
unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array_fc2_,
use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall, *gemm2_config_, false, nullptr,
nullptr);
token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows, expected_tokens_per_expert,
hidden_size, unpadded_hidden_size, inter_size, num_experts_per_node, experts_per_token,
alpha_scale_ptr_array_fc2_, use_lora, lora_fc2_result_, stream, parallelism_config, enable_alltoall,
*gemm2_config_, false, nullptr, nullptr);
sync_check_cuda_error(stream);
}
}
@@ -4673,6 +4678,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
fp4_act_scale_flat, //
mQuantParams, //
original_num_tokens, //
original_num_tokens, //
expanded_num_tokens, //
mExpertHiddenSize, //
mExpertInterSize, //
@@ -4708,6 +4714,7 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
token_selected_experts, //
expert_first_token_offset + mNumExpertsPerNode, //
original_num_tokens, //
original_num_tokens, //
expanded_num_tokens, //
mExpertHiddenSize, //
mExpertUnpaddedHiddenSize, //
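As in the test harness earlier in this diff, GemmProfilerBackend::runProfiler passes original_num_tokens for both row counts: profiler runs have no attention-DP padding, so the padded and valid counts coincide.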

View File

@@ -425,9 +425,9 @@ public:
virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows, int64_t const hidden_size,
int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
= 0;
@@ -439,11 +439,11 @@ public:
int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert)
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
= 0;
virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -454,11 +454,11 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert)
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
MOEParallelismConfig parallelism_config, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
= 0;
virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -573,9 +573,9 @@ public:
void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
ActivationParams fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
QuantParams quant_params, int64_t const num_rows, int64_t const num_valid_rows, int64_t const hidden_size,
int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool use_lora,
LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
@@ -593,10 +593,11 @@ public:
ScaleBiasType const* const fc1_int_scales, float const* const fc1_fp8_dequant, float const* const fc2_fp8_quant,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,
@@ -608,10 +609,11 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
cudaStream_t stream, MOEParallelismConfig parallelism_config, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert);
// Overrides to allow us to forward on to the internal functions with the pointers using the correct type
void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -620,20 +622,21 @@ public:
int64_t const* const num_valid_tokens_ptr, void const* const fc1_int_scales, float const* const fc1_fp8_dequant,
float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params,
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array,
bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert) override
int64_t const num_rows, int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
{
auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),
static_cast<T*>(output), intermediate_result, expert_first_token_offset, tma_ws_input_template,
static_cast<WeightType const*>(fc1_expert_weights), static_cast<ScaleBiasType const*>(fc1_expert_biases),
num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, expected_tokens_per_expert,
hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array,
bias_is_broadcast, stream, config, min_latency_mode, num_active_experts_per, active_expert_global_ids,
start_expert);
}
void gemm2(void const* const input, void* const gemm_output, void* const final_output,
@@ -644,11 +647,11 @@ public:
float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* permuted_row_to_unpermuted_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
int* active_expert_global_ids, int start_expert) override
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
MOEParallelismConfig parallelism_config, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
{
auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
@@ -657,9 +660,9 @@ public:
static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
permuted_row_to_unpermuted_row, expert_for_source_row, num_valid_tokens_ptr, num_rows, expanded_num_rows,
hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
stream, parallelism_config, config, min_latency_mode, num_active_experts_per, active_expert_global_ids,
start_expert);
expected_tokens_per_expert, hidden_size, inter_size, num_experts_per_node, experts_per_token,
alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, config, min_latency_mode,
num_active_experts_per, active_expert_global_ids, start_expert);
}
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -790,16 +793,17 @@ private:
void* const intermediate_result, int64_t const* const expert_first_token_offset,
WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases,
float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows,
int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream);
int64_t const expected_tokens_per_expert, int64_t const hidden_size, int64_t const inter_size,
int const num_experts_per_node, ActivationParams fc1_activation_type, QuantParams& quant_params,
cudaStream_t stream);
static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
OutputType* const final_output, int64_t const* const expert_first_token_offset,
WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
int const* const expert_for_source_row, int64_t const* const num_valid_tokens_ptr, int64_t const num_rows,
int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const unpadded_hidden_size,
int64_t const inter_size, int const num_experts_per_node, int64_t const k,
int64_t const expanded_num_rows, int64_t const expected_tokens_per_expert, int64_t const hidden_size,
int64_t const unpadded_hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
MOEParallelismConfig parallelism_config, QuantParams& quant_params, cudaStream_t stream);
T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,

View File

@@ -964,7 +964,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr,
ActivationParams(mActivationType), inputs[getExpertWeights2Index()],
hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, mExpertHiddenSize,
hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, num_tokens, mExpertHiddenSize,
mExpertHiddenSize /*TRT does not support padding, safe to assume padded/unpadded hidden sizes are the same*/,
mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace),
// Outputs
@@ -977,7 +977,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr,
ActivationParams(mActivationType), inputs[getExpertWeights2Index()],
hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, mExpertHiddenSize,
hasBias() ? inputs[getExpertBias2Index()] : nullptr, quant_params, num_tokens, num_tokens, mExpertHiddenSize,
mExpertInterSize, mNumExperts, mExpertsPerToken, static_cast<char*>(workspace.workspace),
// Outputs
outputs[getOutputTensorIndex()], static_cast<int*>(workspace.src_to_dest_map), mParallelismConfig, hasLora(),

View File

@@ -230,6 +230,7 @@ torch::Tensor fp8_block_scaling_moe_gemm_hopper(torch::Tensor const& mat1, torch
auto const num_problems = mat2.sizes()[0];
auto const n = mat2.sizes()[1];
auto const k = mat2.sizes()[2];
auto const expected_m = (m_total + num_problems - 1) / num_problems;
TORCH_CHECK(k % 16 == 0, "K must be a multiple of 16, (K=", k, ")");
TORCH_CHECK(n % 16 == 0, "N must be a multiple of 16, (N=", n, ")");
@@ -247,7 +248,8 @@ torch::Tensor fp8_block_scaling_moe_gemm_hopper(torch::Tensor const& mat1, torch
void* workspace_ptr = workspace.data_ptr();
gemm_runner->configureWorkspace(static_cast<char*>(workspace_ptr));
gemm_runner->moeGemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(),
static_cast<int64_t*>(token_offset.data_ptr()), num_problems, n, k, stream, mat1ScalePtr, mat2ScalePtr);
static_cast<int64_t*>(token_offset.data_ptr()), num_problems, expected_m, n, k, stream, mat1ScalePtr,
mat2ScalePtr);
return out;
}
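This standalone grouped-GEMM binding has no routing information, so the best available hint is the average M per problem, again by ceil division: expected_m = (m_total + num_problems - 1) / num_problems.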

View File

@@ -259,7 +259,7 @@ public:
torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
torch::optional<int64_t> const& unpadded_hidden_size)
torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens)
{
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
@@ -428,10 +428,11 @@ public:
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, unpadded_hidden_size_val, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace.data_ptr()),
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size,
unpadded_hidden_size_val, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
#else
mKernelRunner->runMoe(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
@@ -442,7 +443,8 @@ public:
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size, inter_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -461,7 +463,7 @@ public:
torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
torch::optional<int64_t> const& unpadded_hidden_size)
torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens)
{
std::lock_guard<std::mutex> lock(mMutex);
@@ -588,10 +590,11 @@ public:
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, unpadded_hidden_size_val, inter_size, num_experts_total,
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace.data_ptr()),
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size,
unpadded_hidden_size_val, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
#else
mKernelRunner->runMoe(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
@@ -602,7 +605,8 @@ public:
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_params,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
num_rows, num_valid_tokens.has_value() ? num_valid_tokens.value() : num_rows, hidden_size, inter_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
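Both runMoe bindings now accept an optional num_valid_tokens and fall back to num_rows when it is absent, so callers without attention-DP padding keep the old behavior.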

View File

@@ -1279,16 +1279,16 @@ protected:
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr,
ActivationParams(mActType, mSwigluAlpha, mSwigluBeta, mSwigluLimit), weight2_ptr, bias2_ptr, quant_params,
mTotalTokens, mHiddenSize, mUnpaddedHiddenSize > 0 ? mUnpaddedHiddenSize : mHiddenSize,
mTotalTokens, mTotalTokens, mHiddenSize, mUnpaddedHiddenSize > 0 ? mUnpaddedHiddenSize : mHiddenSize,
mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
parallelism_config, enable_alltoall, mUseLora, lora_params, useFp8BlockScales, minLatencyMode,
min_latency_params, stream);
#else
mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr,
ActivationParams(mActType, mSwigluAlpha, mSwigluBeta, mSwigluLimit), weight2_ptr, bias2_ptr, quant_params,
mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace,
mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params, useFp8BlockScales,
minLatencyMode, min_latency_params, stream);
mTotalTokens, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK,
mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params,
useFp8BlockScales, minLatencyMode, min_latency_params, stream);
#endif
check_cuda_error(cudaStreamSynchronize(stream));

View File

@@ -240,6 +240,7 @@ def fused_moe(
min_latency_mode,
[gemm_tactic_1, gemm_tactic_2],
unpadded_hidden_size,
tuner_num_tokens,
)
return output if min_latency_mode else [output]

View File

@@ -217,6 +217,7 @@ class CutlassMoEOp(MoEOp):
min_latency_mode,
self.gemm_tactics,
unpadded_hidden_size,
tuner_num_tokens,
)
# Return output based on latency mode
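On the Python side, both call sites forward tuner_num_tokens as the new trailing argument after unpadded_hidden_size, which the C++ runner above binds as the optional num_valid_tokens.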