Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00

Cherry-pick moe sort (and all its dependencies) (#6127)

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: Li Min <11663212+limin2021@users.noreply.github.com>
Co-authored-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
Co-authored-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>

This commit is contained in:
parent 18c0333e96
commit ab6fb9f05d
@@ -581,7 +581,7 @@ public:
 
         auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
         mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
-            mExpertFP4WeightSf2, mExpertFP4GlobalScale2);
+            mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false);
     }
 
     mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);

@@ -87,32 +87,6 @@ struct LoraParams
 
 namespace cutlass_kernels
 {
-static inline size_t pad_to_multiple_of_16(size_t const& input)
-{
-    static constexpr int ALIGNMENT = 16;
-    return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
-}
-
-class CubKeyValueSorter
-{
-public:
-    CubKeyValueSorter();
-
-    CubKeyValueSorter(int const num_experts_per_node);
-
-    void updateNumExperts(int const num_experts_per_node);
-
-    static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);
-
-    void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
-        int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);
-
-private:
-    static int expertsToBits(int experts);
-    int num_experts_;
-    int num_bits_;
-};
-
 /**
  * \brief Describes what parallelism mode the MoE is using
  *

@@ -210,8 +184,9 @@ struct QuantParams
     // FP8 quantization params
     struct
     {
+        bool fc2_use_per_expert_act_scale = false;
         float const* dequant_fc1 = nullptr; // (num_experts_per_node, )
-        float const* quant_fc2 = nullptr; // (1, )
+        float const* quant_fc2 = nullptr; // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale
         float const* dequant_fc2 = nullptr; // (num_experts_per_node, )
         float const* quant_final = nullptr; // (1, )
         float const* dequant_input = nullptr; // (1, )

@@ -223,10 +198,12 @@ struct QuantParams
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
                 = nullptr; // (experts, n, k / 32)
             float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;

@@ -238,10 +215,13 @@ struct QuantParams
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale
                 = nullptr; // (experts, n, k / 16)
             float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;

@@ -287,10 +267,11 @@ struct QuantParams
     }
 
     static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2,
-        float const* quant_final = nullptr, float const* dequant_input = nullptr)
+        float const* quant_final = nullptr, float const* dequant_input = nullptr,
+        bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
+        qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
         return qp;
     }
 

@@ -299,12 +280,14 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp8_mxfp4.fc1
+            = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp8_mxfp4.fc2
+            = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
 

@@ -313,12 +296,12 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
 
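For orientation, a minimal sketch (not taken from this commit; the scale-buffer names are placeholders) of how a caller opts in to the new per-expert activation scales through the extended factories above:

    // Only the trailing boolean flags are new; they default to false, so existing call sites
    // (such as the test fixture above, which appends `false, false`) keep their old behaviour.
    QuantParams qp_fp8 = QuantParams::FP8(dequant_fc1_ptr, quant_fc2_ptr, dequant_fc2_ptr,
        /*quant_final=*/nullptr, /*dequant_input=*/nullptr,
        /*fc2_use_per_expert_act_scale=*/true);

    QuantParams qp_fp4 = QuantParams::FP4(fc1_act_scale_ptr, fc1_weight_sf_ptr, fc1_global_scale_ptr,
        fc2_act_scale_ptr, fc2_weight_sf_ptr, fc2_global_scale_ptr,
        /*fc1_use_per_expert_act_scale=*/false, /*fc2_use_per_expert_act_scale=*/true);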
@@ -388,9 +371,9 @@ public:
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
-        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
+        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
+        MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
         = 0;
 
     // Aliases for profiling the gemms

@@ -404,7 +387,7 @@ public:
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert)
+        int* active_expert_global_ids)
         = 0;
 
     virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output,

@@ -412,14 +395,14 @@ public:
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids, int start_expert)
+        int* num_active_experts_per, int* active_expert_global_ids)
         = 0;
 
     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>

@@ -535,9 +518,9 @@ public:
         ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
         QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
-        int* expanded_source_row_to_expanded_dest_row, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
-        bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
+        int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
+        MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
 
     // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
     static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,

@@ -556,7 +539,7 @@ public:
         int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size,
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config,
-        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids, int start_expert);
+        bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids);
 
     static void gemm2(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,
         DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output,

@@ -565,14 +548,14 @@ public:
         ScaleBiasType const* const fc2_expert_biases, ScaleBiasType const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert);
+        int* active_expert_global_ids);
 
     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,

@@ -585,7 +568,7 @@ public:
         int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array,
         bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
         cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids, int start_expert) override
+        int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input),

@@ -594,7 +577,7 @@ public:
             num_valid_tokens_ptr, static_cast<ScaleBiasType const*>(fc1_int_scales), fc1_fp8_dequant, fc2_fp8_quant,
             fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size,
             num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, config,
-            min_latency_mode, num_active_experts_per, active_expert_global_ids, start_expert);
+            min_latency_mode, num_active_experts_per, active_expert_global_ids);
     }
 
     void gemm2(void const* const input, void* const gemm_output, void* const final_output,

@@ -602,25 +585,25 @@ public:
         void const* const fc2_expert_weights, void const* const fc2_expert_biases, void const* const fc2_int_scales,
         float const* const fc2_fp8_dequant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat,
         QuantParams quant_params, float const* const token_topk_unpermuted_scales,
-        float const* const token_topk_permuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
         int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
         int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
         bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids, int start_expert) override
+        int* num_active_experts_per, int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
             static_cast<OutputType*>(final_output), expert_first_token_offset, tma_ws_input_template,
             static_cast<WeightType const*>(fc2_expert_weights), static_cast<ScaleBiasType const*>(fc2_expert_biases),
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
-            token_topk_unpermuted_scales, token_topk_permuted_scales, expanded_source_row_to_expanded_dest_row,
-            expanded_dest_row_to_expanded_source_row, expert_for_source_row, num_valid_tokens_ptr, num_rows,
-            expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
-            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
-            num_active_experts_per, active_expert_global_ids, start_expert);
+            token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
+            permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
+            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
+            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
+            active_expert_global_ids);
     }
 
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override

@@ -754,10 +737,10 @@ private:
     static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output,
         OutputType* const final_output, int64_t const* const expert_first_token_offset,
         WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases,
-        float const* const token_topk_unpermuted_scales, int const* const expanded_source_row_to_expanded_dest_row,
-        int const* const expanded_dest_row_to_expanded_source_row, int const* const expert_for_source_row,
+        float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
+        int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int64_t const k,
+        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
         MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
         cudaStream_t stream);
 

@@ -765,7 +748,6 @@ private:
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
         cudaStream_t stream);
 
-    CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;
     std::unique_ptr<DeepSeekBlockScaleGemmRunner> blockscale_gemm_runner_;
 

@@ -773,11 +755,11 @@ private:
     std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config_;
 
     // Pointers
-    int* unpermuted_token_selected_experts_{};
-    int* unpermuted_source_token_ids_{};
-    int* permuted_source_token_ids_{};
+    int* permuted_row_to_unpermuted_row_{};
     int* permuted_token_selected_experts_{};
-    char* sorter_ws_{};
+    int* blocked_expert_counts_{};
+    int* blocked_expert_counts_cumsum_{};
+    int* blocked_row_to_unpermuted_row_{};
     T* permuted_data_{};
     float* permuted_token_final_scales_{};
 

@@ -850,7 +832,6 @@ public:
         mParallelismConfig = parallelism_config;
         mEnableAlltoall = enable_alltoall;
         mSM = common::getSMVersion();
-        mSorter.updateNumExperts(mNumExpertsPerNode);
 
         mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
         if (dtype == nvinfer1::DataType::kFP8

@@ -874,7 +855,6 @@ public:
         cudaStream_t const& stream);
 
     CutlassMoeFCRunnerInterface* mInterface;
-    CubKeyValueSorter mSorter;
 
     GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
     std::vector<Config> mAllTacticsSaved;
File diff suppressed because it is too large
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c01175cbc8e003e8288e30ad2dc88c2c819147f4d435a5121460533141b04719
-size 64321452
+oid sha256:6d12357919fe6c63749a81e124afd60453153489a3f50cb44b41671d9b55f947
+size 64338696

@@ -1,2 +1,2 @@
-a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a
-commit c767347ff934578193ee4bad58ba3b9398046245
+ad34c0f31247c880d60e2c8198093e8373cf0e1d3e8badee0424bfa607d6cd8e libtensorrt_llm_internal_cutlass_kernels_static.a
+commit bac309ac608d35d7d0144e594bf3e5fa8cfca796
@@ -159,8 +159,9 @@ struct QuantParams
     // FP8 quantization params
     struct
     {
+        bool fc2_use_per_expert_act_scale = false;
        float const* dequant_fc1 = nullptr; // (num_experts_per_node, )
-        float const* quant_fc2 = nullptr; // (1, )
+        float const* quant_fc2 = nullptr; // (1, ) or (num_experts_per_node, ) based on fc2_use_per_expert_act_scale
         float const* dequant_fc2 = nullptr; // (num_experts_per_node, )
         float const* quant_final = nullptr; // (1, )
         float const* dequant_input = nullptr; // (1, )

@@ -172,10 +173,12 @@ struct QuantParams
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
                 = nullptr; // (experts, n, k / 32)
             float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;

@@ -187,10 +190,12 @@ struct QuantParams
     {
         struct GemmInputs
         {
-            float const* act_global_scale = nullptr; // (1, )
+            bool use_per_expert_act_scale = false;
+            float const* act_global_scale
+                = nullptr; // (1, ) or (num_experts_per_node, ) based on use_per_expert_act_scale
             TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* weight_block_scale
                 = nullptr; // (experts, n, k / 16)
             float const* global_scale = nullptr; // (num_experts_per_node, )
         };
 
         GemmInputs fc1;

@@ -236,10 +241,11 @@ struct QuantParams
     }
 
     static QuantParams FP8(float const* dequant_fc1, float const* quant_fc2, float const* dequant_fc2,
-        float const* quant_final = nullptr, float const* dequant_input = nullptr)
+        float const* quant_final = nullptr, float const* dequant_input = nullptr,
+        bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8 = {dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
+        qp.fp8 = {fc2_use_per_expert_act_scale, dequant_fc1, quant_fc2, dequant_fc2, quant_final, dequant_input};
         return qp;
     }
 

@@ -248,12 +254,14 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp8_mxfp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp8_mxfp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp8_mxfp4.fc1
+            = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp8_mxfp4.fc2
+            = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
 

@@ -262,12 +270,12 @@ struct QuantParams
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale,
-        float const* fc2_global_scale //
-    )
+        float const* fc2_global_scale, //
+        bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
     {
         QuantParams qp;
-        qp.fp4.fc1 = {fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
-        qp.fp4.fc2 = {fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
+        qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
+        qp.fp4.fc2 = {fc2_use_per_expert_act_scale, fc2_act_global_scale, fc2_weight_block_scale, fc2_global_scale};
         return qp;
     }
 

@@ -760,8 +768,8 @@ private:
         QuantParams& quant_params, cudaStream_t stream);
 
     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
-        int const* permuted_token_selected_experts, int64_t const* num_valid_tokens_ptr,
-        int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream);
+        int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
+        cudaStream_t stream);
 
     CubKeyValueSorter sorter_;
     MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType> moe_gemm_runner_;

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a139d8316f640da7c2d10cf5461cb0d0d9462d97f00748467ca4202c896a6187
-size 63833516
+oid sha256:53b6f54a21bd547c0da17e3723b7822d4ee16b66b66a545948c0cbee5760bf65
+size 63835444

@@ -1,2 +1,2 @@
-e7130e36217c1df0d281788fc87764945d9c308bef11ad61b3b1a49c7d41c8af libtensorrt_llm_internal_cutlass_kernels_static.a
-commit c767347ff934578193ee4bad58ba3b9398046245
+21c59ede16aa448b6135327bd0f95e72a6e614f219935b8f67fe635b3cb4b38b libtensorrt_llm_internal_cutlass_kernels_static.a
+commit bac309ac608d35d7d0144e594bf3e5fa8cfca796
cpp/tensorrt_llm/kernels/moeUtilOp.cu (new file, 949 lines)
@@ -0,0 +1,949 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_kernels/include/moe_kernels.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/moeUtilOp.h"
#include "tensorrt_llm/kernels/quantization.cuh"

#include <cuda_fp16.h>
#include <float.h>

#include <climits> // For INT_MAX
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda/std/limits> // For numeric_limits
#include <math.h>

#include <cutlass/array.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#else
#include "3rdparty/cub/cub.cuh"
#include "3rdparty/cub/device/device_radix_sort.cuh"
#include "3rdparty/cub/util_type.cuh"
#endif

namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;

namespace tensorrt_llm::kernels
{

// ========================== CUB Sorting things ====================================
CubKeyValueSorter::CubKeyValueSorter()
    : num_experts_(0)
    , num_bits_(sizeof(int) * 8)
{
}

int CubKeyValueSorter::expertsToBits(int num_experts)
{
    // Max value we represent is V = num_experts + (num_experts - 1) = 2 * num_experts - 1
    // The maximum number of bits is therefore floor(log2(V)) + 1
    return static_cast<int>(log2(2 * num_experts - 1)) + 1;
}

CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
    : num_experts_(num_experts)
    , num_bits_(expertsToBits(num_experts))
{
}

void CubKeyValueSorter::updateNumExperts(int const num_experts)
{
    num_experts_ = num_experts;
    num_bits_ = expertsToBits(num_experts);
}

size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts)
{
    int num_bits = expertsToBits(num_experts);
    size_t required_storage = 0;
    int* null_int = nullptr;
    cub::DeviceRadixSort::SortPairs(
        nullptr, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits);

    // TODO: fix DeviceRadixSort
    //   when num_key_value_pairs, num_experts, num_bits, required_storage = 64, 4, 3, 0
    //   The required_storage seems to vary between 0 and 1 for the same inputs
    if (required_storage == 0)
    {
        required_storage = 1;
    }
    return required_storage;
}

void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out,
    int const* values_in, int* values_out, size_t const num_key_value_pairs, cudaStream_t stream)
{
    size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
    size_t actual_ws_size = workspace_size;

    TLLM_CHECK_WITH_INFO(expected_ws_size <= workspace_size,
        "[CubKeyValueSorter::run] The allocated workspace is too small to run this problem.");
    cub::DeviceRadixSort::SortPairs(
        workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, num_key_value_pairs, 0, num_bits_, stream);
}

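As a standalone illustration (not part of the new file), the bit-width rule in CubKeyValueSorter::expertsToBits above can be checked on the host:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // Largest radix-sort key is 2 * num_experts - 1, so floor(log2(V)) + 1 bits suffice,
        // letting CUB skip the upper bits of a full 32-bit sort.
        for (int num_experts : {4, 64, 256})
        {
            int num_bits = static_cast<int>(std::log2(2 * num_experts - 1)) + 1;
            std::printf("num_experts=%d  max key=%d  num_bits=%d\n", num_experts, 2 * num_experts - 1, num_bits);
        }
        return 0;
    }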
// TODO: These kernel implementations are duplicated in moe_kernels.cu. They will be refactored later (tracked by
// https://jirasw.nvidia.com/browse/TRTLLM-708)
template <int BLOCK_SIZE, int EXPERTS_PER_TOKEN, int LOG2_NUM_EXPERTS>
__global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_selected_experts,
    int* const unpermuted_token_selected_experts, int* const permuted_source_token_ids,
    int64_t* const expert_first_token_offset, int64_t const num_tokens, int const experts_per_token,
    int const start_expert, int const end_expert, int const num_experts_per_node)
{
    // Only using block wise collective so we can only have one block
    assert(gridDim.x == 1);

    assert(start_expert <= end_expert);
    assert(num_experts_per_node == (end_expert - start_expert));
    assert(end_expert <= num_experts_per_node);
    assert(num_experts_per_node <= (1 << LOG2_NUM_EXPERTS));

    int const token = blockIdx.x * BLOCK_SIZE + threadIdx.x;

    bool is_valid_token = token < num_tokens;

    // This is the masked expert id for this token
    int local_token_selected_experts[EXPERTS_PER_TOKEN];
    // This is the final permuted rank of this token (ranked by selected expert)
    int local_token_permuted_indices[EXPERTS_PER_TOKEN];

    // Wait PDL before reading token_selected_experts
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif

    // build expert map
    // we need to populate expert ids for all threads, even if there are
    // fewer tokens
#pragma unroll
    for (int i = 0; i < EXPERTS_PER_TOKEN; i++)
    {
        int const expert
            = is_valid_token ? token_selected_experts[token * EXPERTS_PER_TOKEN + i] : num_experts_per_node;

        // If the token is not valid, set the expert id to num_experts_per_node + 1
        // If expert is not in the current node, set it to num_experts_per_node
        // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node)
        bool is_valid_expert = expert >= start_expert && expert < end_expert;
        local_token_selected_experts[i] = !is_valid_token ? num_experts_per_node + 1
            : is_valid_expert                             ? (expert - start_expert)
                                                          : num_experts_per_node;
    }

    // TODO: decompose cub's sort to expose the bucket starts, and just return
    // that to elide the binary search

    // sort the expert map
    using BlockRadixRank = cub::BlockRadixRank<BLOCK_SIZE, LOG2_NUM_EXPERTS, false>;
    extern __shared__ unsigned char temp_storage[];
    auto& sort_temp = *reinterpret_cast<typename BlockRadixRank::TempStorage*>(temp_storage);

    // Sanity check that the number of bins do correspond to the number of experts
    static_assert(BlockRadixRank::BINS_TRACKED_PER_THREAD * BLOCK_SIZE >= (1 << LOG2_NUM_EXPERTS));
    assert(BlockRadixRank::BINS_TRACKED_PER_THREAD * BLOCK_SIZE >= num_experts_per_node);

    int local_expert_first_token_offset[BlockRadixRank::BINS_TRACKED_PER_THREAD];

    cub::BFEDigitExtractor<int> extractor(0, LOG2_NUM_EXPERTS);
    BlockRadixRank(sort_temp).RankKeys(
        local_token_selected_experts, local_token_permuted_indices, extractor, local_expert_first_token_offset);

    // We are done with compute, launch the dependent kernels while the stores are in flight
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif

    // write to shared memory and global memory
    if (is_valid_token)
    {
#pragma unroll
        for (int i = 0; i < EXPERTS_PER_TOKEN; i++)
        {
            unpermuted_token_selected_experts[token * EXPERTS_PER_TOKEN + i] = local_token_selected_experts[i];
            permuted_source_token_ids[local_token_permuted_indices[i]] = i * num_tokens + token;
        }
    }

#pragma unroll
    for (int expert_id = 0; expert_id < BlockRadixRank::BINS_TRACKED_PER_THREAD; expert_id++)
    {
        int out_expert_id = expert_id + token * BlockRadixRank::BINS_TRACKED_PER_THREAD;
        if (out_expert_id < num_experts_per_node + 1)
        {
            expert_first_token_offset[out_expert_id] = local_expert_first_token_offset[expert_id];
        }
    }
}

template <int BLOCK_SIZE, int EXPERTS_PER_TOKEN, int LOG2_NUM_EXPERTS>
bool fusedBuildExpertMapsSortFirstTokenDispatch(int const* token_selected_experts,
    int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
    int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert,
    int const end_expert, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(num_experts_per_node == (end_expert - start_expert),
        "num_experts_per_node must be equal to end_expert - start_expert");
    int const threads = BLOCK_SIZE;
    int const blocks = (num_tokens + threads - 1) / threads;
    TLLM_CHECK_WITH_INFO(blocks == 1, "Current implementation requires single block");

    using BlockRadixRank = cub::BlockRadixRank<BLOCK_SIZE, LOG2_NUM_EXPERTS, false>;
    size_t shared_size = sizeof(typename BlockRadixRank::TempStorage);

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = shared_size;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;

    auto kernel = &fusedBuildExpertMapsSortFirstTokenKernel<BLOCK_SIZE, EXPERTS_PER_TOKEN, LOG2_NUM_EXPERTS>;

    int device = 0;
    int max_smem_per_block = 0;
    check_cuda_error(cudaGetDevice(&device));
    check_cuda_error(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
    if (shared_size >= static_cast<size_t>(max_smem_per_block))
    {
        // This should mean that
        // cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
        // wouldn't work.
        return false;
    }

    check_cuda_error(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_size));
    check_cuda_error(cudaLaunchKernelEx(&config, kernel, token_selected_experts, unpermuted_token_selected_experts,
        permuted_source_token_ids, expert_first_token_offset, num_tokens, experts_per_token, start_expert, end_expert,
        num_experts_per_node));

    return true;
}

template <int EXPERTS_PER_TOKEN, int LOG2_NUM_EXPERTS>
bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_experts,
    int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
    int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert,
    int const end_expert, cudaStream_t stream)
{
    int const block_size = num_tokens;
    if (num_tokens > 256)
    {
        TLLM_LOG_TRACE(
            "Number of tokens %d is greater than 256, which is not supported for fused moe prologues", num_tokens);
        return false;
    }

    auto func = &fusedBuildExpertMapsSortFirstTokenDispatch<32, EXPERTS_PER_TOKEN, LOG2_NUM_EXPERTS>;
    if (block_size > 32 && block_size <= 64)
    {
        func = &fusedBuildExpertMapsSortFirstTokenDispatch<64, EXPERTS_PER_TOKEN, LOG2_NUM_EXPERTS>;
    }
    else if (block_size > 64 && block_size <= 128)
    {
        func = &fusedBuildExpertMapsSortFirstTokenDispatch<128, EXPERTS_PER_TOKEN, LOG2_NUM_EXPERTS>;
    }
    else if (block_size > 128 && block_size <= 256)
    {
        func = &fusedBuildExpertMapsSortFirstTokenDispatch<256, EXPERTS_PER_TOKEN, LOG2_NUM_EXPERTS>;
    }

    return func(token_selected_experts, unpermuted_token_selected_experts, permuted_source_token_ids,
        expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token, start_expert, end_expert,
        stream);
}

template <int LOG2_NUM_EXPERTS>
bool fusedBuildExpertMapsSortFirstTokenBlockSize(int const* token_selected_experts,
    int* unpermuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
    int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert,
    int const end_expert, cudaStream_t stream)
{
    auto func = &fusedBuildExpertMapsSortFirstTokenBlockSize<1, LOG2_NUM_EXPERTS>;
    switch (experts_per_token)
    {
    case 1:
    {
        func = &fusedBuildExpertMapsSortFirstTokenBlockSize<1, LOG2_NUM_EXPERTS>;
        break;
    }
    case 2:
    {
        func = &fusedBuildExpertMapsSortFirstTokenBlockSize<2, LOG2_NUM_EXPERTS>;
        break;
    }
    case 4:
    {
        func = &fusedBuildExpertMapsSortFirstTokenBlockSize<4, LOG2_NUM_EXPERTS>;
        break;
    }
    case 6:
    {
        func = &fusedBuildExpertMapsSortFirstTokenBlockSize<6, LOG2_NUM_EXPERTS>;
        break;
    }
    case 8:
    {
        func = &fusedBuildExpertMapsSortFirstTokenBlockSize<8, LOG2_NUM_EXPERTS>;
        break;
    }
    default:
    {
        TLLM_LOG_TRACE("Top-K value %d does not have supported fused moe prologues", experts_per_token);
        return false;
    }
    }
    return func(token_selected_experts, unpermuted_token_selected_experts, permuted_source_token_ids,
        expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token, start_expert, end_expert,
        stream);
}

bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* unpermuted_token_selected_experts,
    int* permuted_source_token_ids, int64_t* expert_first_token_offset, int64_t const num_tokens,
    int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert,
    cudaStream_t stream)
{
    // We need enough bits to represent [0, num_experts_per_node+1] (inclusive) i.e. num_experts_per_node + 2 values
    // This is floor(log2(num_experts_per_node+1)) + 1
    int expert_log = static_cast<int>(log2(num_experts_per_node + 1)) + 1;
    if (expert_log <= 9)
    {
        auto funcs = std::array{&fusedBuildExpertMapsSortFirstTokenBlockSize<1>,
            &fusedBuildExpertMapsSortFirstTokenBlockSize<2>, &fusedBuildExpertMapsSortFirstTokenBlockSize<3>,
            &fusedBuildExpertMapsSortFirstTokenBlockSize<4>, &fusedBuildExpertMapsSortFirstTokenBlockSize<5>,
            &fusedBuildExpertMapsSortFirstTokenBlockSize<6>, &fusedBuildExpertMapsSortFirstTokenBlockSize<7>,
            &fusedBuildExpertMapsSortFirstTokenBlockSize<8>, &fusedBuildExpertMapsSortFirstTokenBlockSize<9>};

        return funcs[expert_log - 1](token_selected_experts, unpermuted_token_selected_experts,
            permuted_source_token_ids, expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token,
            start_expert, end_expert, stream);
    }
    TLLM_LOG_TRACE("Experts per node %d does not have supported fused moe prologues", num_experts_per_node);
    return false;
}

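A host-side sketch (illustrative only, not part of the new file) of the dispatch arithmetic used by fusedBuildExpertMapsSortFirstToken above: the radix-rank keys span [0, num_experts_per_node + 1], so the fused path needs floor(log2(n + 1)) + 1 key bits and is taken only when that count is at most 9 and the batch fits in a single block of up to 256 tokens:

    #include <cmath>
    #include <cstdio>

    int main()
    {
        for (int n : {8, 64, 256, 1024}) // n = num_experts_per_node; values chosen for illustration
        {
            int expert_log = static_cast<int>(std::log2(n + 1)) + 1;
            std::printf("num_experts_per_node=%4d  LOG2_NUM_EXPERTS=%2d  fused path %s\n", n, expert_log,
                expert_log <= 9 ? "possible" : "not taken");
        }
        return 0;
    }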
// ============================== Infer GEMM sizes =================================
// TODO Could linear search be better for small # experts
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, int64_t const arr_length, T const target)
{
    int64_t low = 0, high = arr_length - 1, target_location = -1;
    while (low <= high)
    {
        int64_t mid = (low + high) / 2;

        if (sorted_indices[mid] >= target)
        {
            high = mid - 1;
        }
        else
        {
            low = mid + 1;
            target_location = mid;
        }
    }
    return target_location + 1;
}

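For intuition (not part of the new file): the device routine above is the classic count of elements strictly less than the target in a sorted array, i.e. the index std::lower_bound would return on the host:

    #include <algorithm>
    #include <cassert>

    int main()
    {
        int const sorted_experts[] = {0, 0, 1, 1, 1, 3};
        int const* end = sorted_experts + 6;
        // Five elements are strictly less than 2; two are strictly less than 1.
        assert(std::lower_bound(sorted_experts, end, 2) - sorted_experts == 5);
        assert(std::lower_bound(sorted_experts, end, 1) - sorted_experts == 2);
        return 0;
    }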
|
// Calculates the start offset of the tokens for a given expert. The last element is the total number of valid tokens
|
||||||
|
__global__ void computeExpertFirstTokenOffsetKernel(int const* sorted_experts, int64_t const sorted_experts_len,
|
||||||
|
int64_t const num_experts_per_node, int64_t* expert_first_token_offset)
|
||||||
|
{
|
||||||
|
// First, compute the global tid. We only need 1 thread per expert.
|
||||||
|
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
// Note that expert goes [0, num_experts] (inclusive) because we want a count for the total number of active tokens
|
||||||
|
// at the end of the scan.
|
||||||
|
if (expert >= num_experts_per_node + 1)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
|
asm volatile("griddepcontrol.wait;");
|
||||||
|
#endif
|
||||||
|
expert_first_token_offset[expert] = findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
|
||||||
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
|
asm volatile("griddepcontrol.launch_dependents;");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void computeExpertFirstTokenOffset(int const* sorted_indices, int const total_indices, int const num_experts_per_node,
|
||||||
|
int64_t* expert_first_token_offset, cudaStream_t stream)
|
||||||
|
{
|
||||||
|
int const num_entries = num_experts_per_node + 1;
|
||||||
|
int const threads = std::min(1024, num_entries);
|
||||||
|
int const blocks = (num_entries + threads - 1) / threads;
|
||||||
|
|
||||||
|
cudaLaunchConfig_t config;
|
||||||
|
config.gridDim = blocks;
|
||||||
|
config.blockDim = threads;
|
||||||
|
config.dynamicSmemBytes = 0;
|
||||||
|
config.stream = stream;
|
||||||
|
cudaLaunchAttribute attrs[1];
|
||||||
|
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||||
|
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||||
|
config.numAttrs = 1;
|
||||||
|
config.attrs = attrs;
|
||||||
|
cudaLaunchKernelEx(&config, computeExpertFirstTokenOffsetKernel, sorted_indices, total_indices,
|
||||||
|
num_experts_per_node, expert_first_token_offset);
|
||||||
|
}
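A small host-side illustration of the resulting offsets (editorial sketch only):

#include <cstdint>
#include <cstdio>
#include <vector>

// Reference computation of expert_first_token_offset: entry e is the number of sorted
// expert ids that are < e, so the last entry is the total number of valid tokens.
int main()
{
    std::vector<int> sorted_experts{0, 0, 1, 1, 1, 3}; // already sorted by expert id
    int const num_experts_per_node = 4;
    std::vector<int64_t> offset(num_experts_per_node + 1, 0);
    for (int e = 0; e <= num_experts_per_node; ++e)
    {
        for (int id : sorted_experts)
        {
            offset[e] += (id < e) ? 1 : 0;
        }
    }
    // Prints 0 2 5 5 6: expert 0 owns rows [0,2), expert 1 owns [2,5), expert 2 is empty, expert 3 owns [5,6).
    for (auto v : offset)
    {
        std::printf("%lld ", static_cast<long long>(v));
    }
    std::printf("\n");
    return 0;
}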

template <class T>
using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;

// Function to safely offset a pointer that may contain sub-byte types (FP4/INT4)
template <class T>
__host__ __device__ constexpr T* safe_inc_ptr(T* ptr, size_t offset)
{
constexpr int adjustment = (sizeof_bits<T>::value < 8) ? (8 / sizeof_bits<T>::value) : 1;
assert(offset % adjustment == 0 && "Attempt to offset index to sub-byte");
return ptr + offset / adjustment;
}

__host__ __device__ constexpr int64_t getOffsetActivationSF(int64_t expert_id, int64_t token_offset, int64_t gemm_k,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type)
{
auto function = [=](int64_t min_alignment, int64_t block_size)
{
// This formulation ensures that sf_offset[i + 1] - sf_offset[i] >= token_offset[i + 1] - token_offset[i].
int64_t sf_offset = (token_offset + expert_id * (min_alignment - 1)) / min_alignment * min_alignment;
assert(gemm_k % block_size == 0);
return sf_offset * gemm_k / block_size;
};
switch (scaling_type)
{
case cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX:
return function(cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentMXFPX,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize);
case cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4:
return function(cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentNVFP4,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize);
case cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE:
return 0; // No scaling factors, no offset
}

assert(false && "Unrecognized scaling type");
return 0;
}
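The rounding in getOffsetActivationSF is easiest to see numerically. The sketch below uses made-up alignment and block-size values; the real constants come from TmaWarpSpecializedGroupedGemmInput and may differ:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Mirrors the lambda in getOffsetActivationSF: pad each expert's starting row up so that every
// expert's scaling-factor region starts on a min_alignment row boundary.
int64_t sfOffset(int64_t expert_id, int64_t token_offset, int64_t gemm_k, int64_t min_alignment, int64_t block_size)
{
    int64_t sf_offset = (token_offset + expert_id * (min_alignment - 1)) / min_alignment * min_alignment;
    assert(gemm_k % block_size == 0);
    return sf_offset * gemm_k / block_size;
}

int main()
{
    // Illustrative values only: 8-row alignment, 16-wide scaling blocks, gemm_k = 64.
    // Expert 0 starts at token 0, expert 1 at token 3, expert 2 at token 10.
    // Padded starting rows become 0, 8 and 24, so row gaps (8, 16) cover the token gaps (3, 7).
    std::printf("%lld %lld %lld\n",
        (long long) sfOffset(0, 0, 64, 8, 16),   // 0
        (long long) sfOffset(1, 3, 64, 8, 16),   // 8 * 64 / 16 = 32
        (long long) sfOffset(2, 10, 64, 8, 16)); // 24 * 64 / 16 = 96
    return 0;
}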

constexpr static int NVFP4_VEC_SIZE = 16;

template <class GemmOutputType, class ComputeElem>
__device__ uint32_t quantizePackedFP4Value(ComputeElem& post_act_val, float global_scale_val,
int64_t num_tokens_before_expert, int64_t expert_id, int64_t token_id, int64_t elem_idx, int64_t num_cols,
int64_t max_tokens_per_expert, cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type)
{
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = NVFP4_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
// Quantize the input to FP4
static_assert(std::is_same_v<GemmOutputType, __nv_bfloat16> || std::is_same_v<GemmOutputType, half>);
static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD);
PackedVec<GemmOutputType> packed_vec{};
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++)
{
packed_vec.elts[i].x = static_cast<GemmOutputType>(post_act_val[i * 2 + 0]);
packed_vec.elts[i].y = static_cast<GemmOutputType>(post_act_val[i * 2 + 1]);
}

// We need to offset into the scaling factors for just this expert
auto act_sf_expert
= act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols, scaling_type);

// Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF, NVFP4_VEC_SIZE>(std::nullopt /* batchIdx */, token_id - num_tokens_before_expert,
elem_idx, std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);

// Do the conversion and set the output and scaling factor
auto func = (scaling_type == cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4)
? &cvt_warp_fp16_to_fp4<GemmOutputType, NVFP4_VEC_SIZE, false>
: &cvt_warp_fp16_to_fp4<GemmOutputType, NVFP4_VEC_SIZE, true>;
auto res = func(packed_vec, global_scale_val, sf_out);
return res;
}

__device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id,
int64_t elem_idx, int64_t num_cols, int64_t max_tokens_per_expert,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf)
{
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = NVFP4_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;

// We need to offset into the scaling factors for just this expert
auto act_sf_expert = act_sf_flat
+ getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4);

// Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF, NVFP4_VEC_SIZE>(std::nullopt /* batchIdx */, token_id - num_tokens_before_expert,
elem_idx, std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);
if (sf_out)
{
auto const sf_in
= cvt_quant_to_fp4_get_sf_out_offset<cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF, NVFP4_VEC_SIZE>(std::nullopt /* batchIdx */, source_token_id, elem_idx,
std::nullopt /* numRows */, num_cols,
const_cast<cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
FP4QuantizationSFLayout::SWIZZLED);
*sf_out = *sf_in;
}
}

void generateTokenPermutation(int const* unpermuted_token_selected_experts, int const* unpermuted_source_token_ids,
int* permuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
int64_t num_rows, int64_t num_experts_per_node, int64_t k, CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream)
{
int64_t const expanded_num_rows = k * num_rows;
sorter.updateNumExperts(num_experts_per_node);
size_t const sorter_ws_size_bytes
= pad_to_multiple_of_16(sorter.getWorkspaceSize(expanded_num_rows, num_experts_per_node));
sorter.run((void*) sorter_ws, sorter_ws_size_bytes, unpermuted_token_selected_experts,
permuted_token_selected_experts, unpermuted_source_token_ids, permuted_source_token_ids, expanded_num_rows,
stream);

sync_check_cuda_error(stream);

// Upper bound on number of expanded rows
computeExpertFirstTokenOffset(
permuted_token_selected_experts, expanded_num_rows, num_experts_per_node, expert_first_token_offset, stream);
}
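generateTokenPermutation relies on CUB's device radix sort; the CPU sketch below only illustrates the resulting ordering (a stable sort by expert id, with i * num_tokens + token as the payload) and is not the actual implementation:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main()
{
    int const num_tokens = 3, k = 2;
    // token 0 -> experts {1, 0}, token 1 -> experts {0, 2}, token 2 -> experts {1, 2}
    std::vector<int> token_selected_experts{1, 0, 0, 2, 1, 2};

    std::vector<std::pair<int, int>> expanded; // (expert, unpermuted source id)
    for (int token = 0; token < num_tokens; ++token)
        for (int i = 0; i < k; ++i)
            expanded.push_back({token_selected_experts[token * k + i], i * num_tokens + token});

    std::stable_sort(expanded.begin(), expanded.end(), [](auto a, auto b) { return a.first < b.first; });

    for (auto [expert, src] : expanded)
        std::printf("expert %d <- source row %d\n", expert, src);
    // Expert 0 gets source rows 3 and 1, expert 1 gets 0 and 2, expert 2 gets 4 and 5.
    return 0;
}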

/**
 * Takes the input maps and prepares the expanded maps for the sort step
 * @param unpermuted_token_selected_experts: Buffer of transformed expert ids masked for the current node, used as the
 * keys for the sort
 * @param unpermuted_source_token_ids: Buffer of unpermuted token ids that will be used to identify the source row for
 * each expanded token, used as the values for the sort
 */
__global__ void buildExpertMapsKernel(int const* token_selected_experts, int* unpermuted_token_selected_experts,
int* unpermuted_source_token_ids, int64_t const num_tokens, int const experts_per_token, int const start_expert,
int const end_expert, int const num_experts_per_node)
{
int const token = blockIdx.x * blockDim.x + threadIdx.x;
if (token >= num_tokens)
{
return;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif

for (int i = 0; i < experts_per_token; i++)
{
int const expert = token_selected_experts[token * experts_per_token + i];
// If expert is not in the current node, set it to num_experts_per_node
// If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node)
bool is_valid_expert = expert >= start_expert && expert < end_expert;
unpermuted_token_selected_experts[token * experts_per_token + i]
= is_valid_expert ? (expert - start_expert) : num_experts_per_node;
unpermuted_source_token_ids[token * experts_per_token + i] = i * num_tokens + token;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_selected_experts,
int* unpermuted_source_token_ids, int64_t const num_tokens, int const num_experts_per_node,
int const experts_per_token, int const start_expert, int const end_expert, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(num_experts_per_node == (end_expert - start_expert),
"num_experts_per_node must be equal to end_expert - start_expert");
int const threads = std::min(int64_t(1024), num_tokens);
int const blocks = (num_tokens + threads - 1) / threads;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, buildExpertMapsKernel, token_selected_experts, unpermuted_token_selected_experts,
unpermuted_source_token_ids, num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node);
}

// ========================== Permutation things =======================================
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input)
{
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++)
{
u[i] = static_cast<Type>(input[i]);
}
return u;
}

// Duplicates and permutes rows for MoE. In addition, reverses the permutation map to help with finalizing routing.

// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to
// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate
// experts in the end.

// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0,
// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input
// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus
// of the expanded index, as illustrated in the sketch below.

constexpr static int EXPAND_THREADS_PER_BLOCK = 256;
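A minimal check of the index arithmetic described above (illustrative only):

#include <cassert>

// expanded_source_row = k_rank * num_rows + source_row, so the source row is recovered with a
// modulus and the k slot with an integer division.
int main()
{
    int const num_rows = 4, k = 2;
    for (int k_rank = 0; k_rank < k; ++k_rank)
    {
        for (int source_row = 0; source_row < num_rows; ++source_row)
        {
            int expanded_source_row = k_rank * num_rows + source_row;
            assert(expanded_source_row % num_rows == source_row);
            assert(expanded_source_row / num_rows == k_rank);
        }
    }
    return 0;
}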

template <class InputActivationsType, class ExpandedActivationsType, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols, int64_t k,
float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t num_experts_per_node)
{
#ifdef ENABLE_FP4
constexpr bool is_fp4 = std::is_same_v<ExpandedActivationsType, __nv_fp4_e2m1>;
constexpr bool is_fp4_input = is_fp4 && std::is_same_v<InputActivationsType, __nv_fp4_e2m1>;
constexpr bool need_fp4_quant = is_fp4 && !std::is_same_v<InputActivationsType, __nv_fp4_e2m1>;
#else
constexpr bool is_fp4 = false;
constexpr bool is_fp4_input = false;
constexpr bool need_fp4_quant = false;
#endif

static_assert(need_fp4_quant || std::is_same_v<InputActivationsType, ExpandedActivationsType>,
"Only FP4 quantization supports outputting a different format as part of the expansion");

// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the
// reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
int64_t const expanded_dest_row = blockIdx.x;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
int64_t const expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
if (threadIdx.x == 0)
{
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] = static_cast<int>(expanded_dest_row);
}

if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows)
{
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD
= is_fp4 ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits<InputActivationsType>::value);
constexpr int64_t ELEM_PER_BYTE = is_fp4_input ? 2 : 1;
using DataElem
= std::conditional_t<is_fp4_input, uint32_t, cutlass::Array<InputActivationsType, ELEM_PER_THREAD>>;
using OutputElem = std::conditional_t<is_fp4, uint32_t, DataElem>;

// Duplicate and permute rows
int64_t const source_k_rank = expanded_source_row / num_rows;
int64_t const source_row = expanded_source_row % num_rows;

auto const* source_row_ptr
= reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols / ELEM_PER_BYTE);
// Cast first to handle when this is FP4
auto* dest_row_ptr
= reinterpret_cast<OutputElem*>(permuted_output) + expanded_dest_row * cols / ELEM_PER_THREAD;

int64_t const start_offset = threadIdx.x;
int64_t const stride = EXPAND_THREADS_PER_BLOCK;
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
assert(cols % ELEM_PER_THREAD == 0);

if constexpr (is_fp4)
{
int64_t expert = findTotalEltsLessThanTarget(
expert_first_token_offset, num_experts_per_node, (int64_t) expanded_dest_row + 1)
- 1;
float global_scale_val = fc1_act_global_scale ? *fc1_act_global_scale : 1.0f;
int64_t num_tokens_before_expert = expert_first_token_offset[expert];

for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
auto in_vec = source_row_ptr[elem_index];
if constexpr (need_fp4_quant)
{
// auto res = quantizePackedFP4Value<InputActivationsType, DataElem>(in_vec, global_scale_val,
// num_tokens_before_expert, expert, expanded_dest_row, elem_index, cols, num_rows,
// fc1_act_sf_flat);
auto res = quantizePackedFP4Value<InputActivationsType, DataElem>(in_vec, global_scale_val,
num_tokens_before_expert, expert, expanded_dest_row, elem_index, cols, num_rows,
fc1_act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4);
dest_row_ptr[elem_index] = res;
}
else
{
writeSF(num_tokens_before_expert, expert, source_row, expanded_dest_row, elem_index, cols, num_rows,
fc1_act_sf_flat, input_sf);
dest_row_ptr[elem_index] = in_vec;
}
}
}
else
{
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}

if (permuted_scales && threadIdx.x == 0)
{
int64_t const source_k_idx = source_row * k + source_k_rank;
permuted_scales[expanded_dest_row] = unpermuted_scales ? unpermuted_scales[source_k_idx] : 1.0f;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <class InputActivationsType, class ExpandedActivationsType>
void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int const num_experts_per_node, float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream)
{
if (fc1_act_sf_flat)
{
assert(false && "Not supported, we need to keep the same as moe_kernels.cu in the future (TODO).");
}

int64_t const blocks = num_rows * k;
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
auto func = (num_valid_tokens_ptr != nullptr)
? expandInputRowsKernel<InputActivationsType, ExpandedActivationsType, true>
: expandInputRowsKernel<InputActivationsType, ExpandedActivationsType, false>;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows,
num_valid_tokens_ptr, cols, k, fc1_act_global_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf,
num_experts_per_node);
}

#define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \
template void expandInputRowsKernelLauncher<InputActivationsType, ExpandedActivationsType>( \
InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \
float const* unpermuted_scales, float* permuted_scales, int const* expanded_dest_row_to_expanded_source_row, \
int* expanded_source_row_to_expanded_dest_row, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, \
int64_t const cols, int const k, int const num_experts_per_node, float const* fc1_act_global_scale, \
int64_t* expert_first_token_offset, \
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream);

INSTANTIATE_EXPAND_INPUT_ROWS(half, half);
INSTANTIATE_EXPAND_INPUT_ROWS(float, float);
#ifdef ENABLE_BF16
INSTANTIATE_EXPAND_INPUT_ROWS(__nv_bfloat16, __nv_bfloat16);
#endif
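A CPU reference of what the expansion produces, assuming an already-computed dest-to-source permutation; the buffer names are hypothetical and the real path runs entirely on the GPU:

#include <cstdio>
#include <vector>

// Each permuted row is a copy of the source row recovered by expanded_source_row % num_rows,
// and the reverse map lets the finalize kernel find all k copies of a token without atomics.
int main()
{
    int const num_rows = 3, k = 2, cols = 4;
    std::vector<float> input(num_rows * cols);
    for (int i = 0; i < num_rows * cols; ++i)
        input[i] = static_cast<float>(i);

    // Example permutation produced by the sort: dest row -> expanded source row
    std::vector<int> dest_to_source{3, 1, 0, 4, 2, 5};
    std::vector<int> source_to_dest(num_rows * k);
    std::vector<float> permuted(num_rows * k * cols);

    for (int dest = 0; dest < num_rows * k; ++dest)
    {
        int const source = dest_to_source[dest];
        source_to_dest[source] = dest;
        int const source_row = source % num_rows;
        for (int c = 0; c < cols; ++c)
            permuted[dest * cols + c] = input[source_row * cols + c];
    }
    // Prints "dest of source 0: 2, source row of dest 0: 0"
    std::printf("dest of source 0: %d, source row of dest 0: %d\n", source_to_dest[0], dest_to_source[0] % num_rows);
    return 0;
}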

enum class ScaleMode : int
{
NO_SCALE = 0,
DEFAULT = 1,
};

constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;

template <class T>
using sizeof_bits = cutlass::sizeof_bits<typename cutlass_kernels::TllmToCutlassTypeAdapter<std::remove_cv_t<T>>::type>;

// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection.
template <typename OutputType, class GemmOutputType, class ScaleBiasType, ScaleMode SCALE_MODE, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted_rows,
OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* scales,
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const orig_cols,
int64_t const experts_per_token, int64_t const* num_valid_ptr)
{
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;

// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t FINALIZE_ELEM_PER_THREAD
= 128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);

int64_t const start_offset = threadIdx.x;
int64_t const stride = FINALIZE_THREADS_PER_BLOCK;
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;

using BiasElem = cutlass::Array<ScaleBiasType, FINALIZE_ELEM_PER_THREAD>;
using InputElem = cutlass::Array<GemmOutputType, FINALIZE_ELEM_PER_THREAD>;
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
auto const* bias_v = reinterpret_cast<BiasElem const*>(bias);
auto const* expanded_permuted_rows_v = reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
int64_t const num_valid = *num_valid_ptr;

#pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride)
{
bool has_valid = false;
ComputeElem thread_output;
thread_output.fill(0);
for (int k_idx = 0; k_idx < experts_per_token; ++k_idx)
{
int64_t const expanded_original_row = original_row + k_idx * num_rows;
int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];

int64_t const k_offset = original_row * experts_per_token + k_idx;
float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];

// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid)
{
continue;
}

auto const* expanded_permuted_rows_row_ptr
= expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

int64_t const expert_idx = expert_for_source_row[k_offset];

auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col;
ComputeElem bias_value;
if (bias)
{
bias_value = arrayConvert<BiasElem, ComputeElem>(bias_ptr[elem_index]);
}
else
{
bias_value.fill(0);
}

ComputeElem expert_result
= arrayConvert<InputElem, ComputeElem>(expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result + bias_value);
has_valid = true;
}

OutputElem output_elem = arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
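A CPU reference of the k-way reduction performed by the kernel above, with the bias term omitted for brevity (illustrative sketch, not the actual implementation):

#include <cstdio>
#include <vector>

// For each original token, gather its k expert outputs through the source->dest map and
// accumulate scale * expert_output into the unpermuted result row.
int main()
{
    int const num_rows = 2, k = 2, cols = 2;
    std::vector<float> gemm_out{1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f}; // permuted rows, one per (token, k) pair
    std::vector<int> source_to_dest{0, 2, 1, 3};                         // expanded source row -> permuted row
    std::vector<float> scales{0.75f, 0.25f, 0.5f, 0.5f};
    std::vector<float> out(num_rows * cols, 0.f);

    for (int row = 0; row < num_rows; ++row)
    {
        for (int k_idx = 0; k_idx < k; ++k_idx)
        {
            int const permuted_row = source_to_dest[row + k_idx * num_rows];
            float const scale = scales[row * k + k_idx];
            for (int c = 0; c < cols; ++c)
                out[row * cols + c] += scale * gemm_out[permuted_row * cols + c];
        }
    }
    std::printf("%.2f %.2f\n", out[0], out[2]); // prints 1.25 3.50
    return 0;
}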

template <class OutputType, class GemmOutputType, class ScaleBiasType>
void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales,
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const experts_per_token, int64_t const* num_valid_ptr,
cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream)
{
int64_t const blocks = num_rows;
int64_t const threads = FINALIZE_THREADS_PER_BLOCK;

// Only add bias on rank 0 for tensor parallelism
bool const is_rank_0 = parallelism_config.tp_rank == 0;
ScaleBiasType const* bias_ptr = is_rank_0 ? bias : nullptr;

bool const check_skipped = num_valid_ptr != nullptr;

ScaleMode scale_mode = final_scales ? ScaleMode::DEFAULT : ScaleMode::NO_SCALE;

using FuncPtr
= decltype(&finalizeMoeRoutingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT, false>);
FuncPtr func_map[2][3] = {
{
&finalizeMoeRoutingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE, false>,
&finalizeMoeRoutingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT, false>,
},
{
&finalizeMoeRoutingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE, true>,
&finalizeMoeRoutingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT, true>,
},
};
auto* const func = func_map[check_skipped][int(scale_mode)];

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, experts_per_token, num_valid_ptr);
}

#define INSTANTIATE_FINALIZE_MOE_ROUTING(OutputT, GemmOutputT, ScaleBiasT) \
template void finalizeMoeRoutingKernelLauncher<OutputT, GemmOutputT, ScaleBiasT>( \
GemmOutputT const* expanded_permuted_rows, OutputT* reduced_unpermuted_output, ScaleBiasT const* bias, \
float const* final_scales, int const* expanded_source_row_to_expanded_dest_row, \
int const* expert_for_source_row, int64_t const num_rows, int64_t const cols, int64_t const experts_per_token, \
int64_t const* num_valid_ptr, cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream);

INSTANTIATE_FINALIZE_MOE_ROUTING(half, half, half);
INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float);
#ifdef ENABLE_BF16
INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16);
#endif

} // namespace tensorrt_llm::kernels
cpp/tensorrt_llm/kernels/moeUtilOp.h (new file, 82 lines)
@ -0,0 +1,82 @@
/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass_kernels/include/moe_kernels.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace tensorrt_llm::kernels
{
static inline size_t pad_to_multiple_of_16(size_t const& input)
{
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}

class CubKeyValueSorter
{
public:
CubKeyValueSorter();

CubKeyValueSorter(int const num_experts_per_node);

void updateNumExperts(int const num_experts_per_node);

static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node);

void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in,
int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);

private:
static int expertsToBits(int experts);
int num_experts_;
int num_bits_;
};

bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int* unpermuted_token_selected_experts,
int* permuted_source_token_ids, int64_t* expert_first_token_offset, int64_t const num_tokens,
int const num_experts_per_node, int const experts_per_token, int const start_expert, int const end_expert,
cudaStream_t stream);

void buildExpertMaps(int const* token_selected_experts, int* unpermuted_token_selected_experts,
int* unpermuted_source_token_ids, int64_t const num_tokens, int const num_experts_per_node,
int const experts_per_token, int const start_expert, int const end_expert, cudaStream_t stream);

void generateTokenPermutation(int const* unpermuted_token_selected_experts, int const* unpermuted_source_token_ids,
int* permuted_token_selected_experts, int* permuted_source_token_ids, int64_t* expert_first_token_offset,
int64_t num_rows, int64_t num_experts_per_node, int64_t k, CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream);

template <class InputActivationsType, class ExpandedActivationsType>
void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int const num_experts_per_node, float const* fc1_act_global_scale, int64_t* expert_first_token_offset,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream);

template <class OutputType, class GemmOutputType, class ScaleBiasType>
void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales,
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const experts_per_token, int64_t const* num_valid_ptr,
cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream);

} // namespace tensorrt_llm::kernels
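The header centralizes the sorter and the 16-byte padding helper. A standalone sketch of the intended workspace-carving pattern follows; the sizes are hypothetical (the real ones come from CubKeyValueSorter::getWorkspaceSize and the token counts):

#include <cstddef>
#include <cstdio>

// Sub-allocations are padded to 16 bytes (mirroring pad_to_multiple_of_16) so that each region
// carved out of a single workspace buffer stays aligned.
static size_t padTo16(size_t n)
{
    return 16 * ((n + 16 - 1) / 16);
}

int main()
{
    size_t const sorter_ws = 1000;                    // hypothetical sorter workspace size in bytes
    size_t const permute_maps = 6 * 4 * sizeof(int);  // hypothetical k * num_tokens id buffers
    size_t const total = padTo16(sorter_ws) + padTo16(permute_maps);
    std::printf("padded sorter ws: %zu, total workspace: %zu bytes\n", padTo16(sorter_ws), total);
    return 0;
}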
@ -275,7 +275,7 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const nu
// FP4 Quantization

constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
// constexpr int CVT_FP4_SF_VEC_SIZE = 16;
constexpr int CVT_FP4_SF_VEC_SIZE = 16;
constexpr int CVT_FP4_THREADS_PER_WARP = 32;
constexpr int CVT_FP8_TO_FP4_ELTS_PER_THREAD = 16;
@ -65,6 +65,7 @@ add_library(
  logitsBitmaskOp.cpp
  mambaConv1dOp.cpp
  moeOp.cpp
  moeUtilOp.cpp
  moeCommOp.cpp
  moeLoadBalanceOp.cpp
  fp8BlockScaleMoe.cpp
@ -52,7 +52,7 @@ using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend
class FusedMoeRunner : public torch::CustomClassHolder
{
public:
template <typename Type, bool NeedQuant = false>
template <typename TypeAct, typename TypeWeight, bool NeedQuant = false>
std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(c10::ScalarType output_type)
{
switch (output_type)
@ -66,21 +66,22 @@ public:
case c10::ScalarType::Half:
if constexpr (NeedQuant)
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, half>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half>>();
}
else
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, Type>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct>>();
}
#ifdef ENABLE_BF16
case c10::ScalarType::BFloat16:
if constexpr (NeedQuant)
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, __nv_bfloat16>>();
return std::make_unique<
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16>>();
}
else
{
return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, Type>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, TypeAct>>();
}
#endif
default:
@ -121,10 +122,16 @@ public:
#ifdef ENABLE_FP8
if (isFp8Quant())
{
mKernelRunner = switch_output_type<__nv_fp8_e4m3>(mOutputDtype);
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype);
}
#endif
#ifdef ENABLE_FP4
if (isWFp4AFp8Quant())
{
mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
}

if (isNvfp4Quant())
{
mInnerDimMultiplier = 16;
@ -134,9 +141,9 @@ public:
#ifdef ENABLE_BF16
case c10::ScalarType::BFloat16:
#endif
mKernelRunner = switch_output_type<__nv_fp4_e2m1, true>(mOutputDtype);
mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, true>(mOutputDtype);
break;
default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, false>(mOutputDtype);
default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
}
}
#endif
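The change above splits the runner's single Type parameter into TypeAct and TypeWeight. A standalone sketch of why that matters, using stand-in types rather than the real TensorRT-LLM classes:

#include <cstdio>
#include <typeinfo>

// With separate activation and weight type parameters, mixed-precision combinations (e.g. FP8
// activations with FP4 weights) can be instantiated alongside the homogeneous ones.
struct Fp8 {};
struct Fp4 {};

template <typename TypeAct, typename TypeWeight>
struct RunnerSketch
{
    void describe() const
    {
        std::printf("act=%s weight=%s\n", typeid(TypeAct).name(), typeid(TypeWeight).name());
    }
};

int main()
{
    RunnerSketch<Fp8, Fp8>{}.describe(); // homogeneous FP8 path
    RunnerSketch<Fp8, Fp4>{}.describe(); // FP4-weight / FP8-activation path enabled by the split
    return 0;
}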
@ -204,11 +211,13 @@ public:
void operator=(FusedMoeRunner const&) = delete;

torch::Tensor runMoe(torch::Tensor const& input, torch::Tensor const& token_selected_experts,
torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& fc1_expert_weights,
torch::optional<torch::Tensor> const& token_final_scales, torch::Tensor const& fc1_expert_weights,
torch::Tensor const& fc2_expert_weights, torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales,
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
torch::optional<torch::Tensor> input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
torch::optional<torch::Tensor> const& fc2_expert_biases,
int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall,
torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> profile_ids)
torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
{
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
@ -230,6 +239,23 @@ public:

TORCH_CHECK(fc1_expert_weights.dim() == 3, "fc1_expert_weights must be 3D.");
TORCH_CHECK(fc2_expert_weights.dim() == 3, "fc2_expert_weights must be 3D.");

if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value())
{
CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype);
CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype);
TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D.");
TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D.");
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0],
"fc1_expert_weights and fc1_expert_biases must have the same number of experts.");
TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0],
"fc2_expert_weights and fc2_expert_biases must have the same number of experts.");
TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1],
"fc1_expert_biases should match fc1_expert_weights output shape.");
TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1],
"fc2_expert_biases should match fc2_expert_weights output shape.");
}

TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
"input and token_selected_experts must have the same num tokens.");
if (token_final_scales)
@ -275,8 +301,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
fc1_expert_weights.const_data_ptr(),
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@ -286,8 +315,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
fc1_expert_weights.const_data_ptr(),
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@ -297,12 +329,13 @@ public:
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> runMoeMinLantency(torch::Tensor const& input,
torch::Tensor const& token_selected_experts, torch::optional<torch::Tensor> token_final_scales,
torch::Tensor const& token_selected_experts, torch::optional<torch::Tensor> const& token_final_scales,
torch::Tensor const& fc1_expert_weights, torch::Tensor const& fc2_expert_weights,
torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales, torch::optional<torch::Tensor> input_sf,
torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
torch::optional<c10::ArrayRef<int64_t>> profile_ids)
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
{
std::lock_guard<std::mutex> lock(mMutex);

@ -323,6 +356,23 @@ public:

TORCH_CHECK(fc1_expert_weights.dim() == 3, "fc1_expert_weights must be 3D.");
TORCH_CHECK(fc2_expert_weights.dim() == 3, "fc2_expert_weights must be 3D.");

if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value())
{
CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype);
CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype);
TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D.");
TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D.");
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0],
"fc1_expert_weights and fc1_expert_biases must have the same number of experts.");
TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0],
"fc2_expert_weights and fc2_expert_biases must have the same number of experts.");
TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1],
"fc1_expert_biases should match fc1_expert_weights output shape.");
TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1],
"fc2_expert_biases should match fc2_expert_weights output shape.");
}

TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
"input and token_selected_experts must have the same num tokens.");
if (token_final_scales)
@ -378,8 +428,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
fc1_expert_weights.const_data_ptr(),
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@ -389,8 +442,11 @@ public:
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
fc1_expert_weights.const_data_ptr(),
quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
fc2_expert_weights.const_data_ptr(),
fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@ -406,10 +462,11 @@ public:
}

void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
torch::Tensor const& fc2_expert_weights, int64_t const top_k, int64_t const tp_size, int64_t const tp_rank,
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx, int64_t const profile_id,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
bool const do_preparation)
int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
int64_t const profile_id, bool const do_preparation)
{
std::lock_guard<std::mutex> lock(mMutex);

@ -447,7 +504,7 @@ public:
static_cast<int>(tp_rank), static_cast<int>(ep_size), static_cast<int>(ep_rank),
static_cast<int>(cluster_size), static_cast<int>(cluster_rank));

bool const USE_BIAS = false;
bool const USE_BIAS = fc1_expert_biases.has_value() || fc2_expert_biases.has_value();
bool const USE_LORA = false;
auto activation_dtype = mUseW4A8GroupScaling ? at::ScalarType::Float8_e4m3fn : mActivationDtype;
activation_dtype = isNvfp4Quant() ? at::ScalarType::Long : activation_dtype;
@ -575,22 +632,80 @@ private:
auto const fc2_dequant = quant_scales.value()[2];
auto const fc1_input_dequant = quant_scales.value()[3];

+ // Check types
CHECK_INPUT(fc1_dequant, c10::ScalarType::Float);
CHECK_INPUT(fc2_quant, c10::ScalarType::Float);
CHECK_INPUT(fc2_dequant, c10::ScalarType::Float);
CHECK_INPUT(fc1_input_dequant, c10::ScalarType::Float);
+ // Check ranks
TORCH_CHECK(fc1_dequant.dim() == 1, "fc1 dequant must be 1D");
- TORCH_CHECK(fc2_quant.dim() == 0, "fc2 quant must be a scalar tensor");
+ TORCH_CHECK(fc2_quant.dim() == 0 || fc2_quant.dim() == 1, "fc2 quant must be a scalar or 1-D tensor");
TORCH_CHECK(fc2_dequant.dim() == 1, "fc2 quant must be 1D");
TORCH_CHECK(fc1_input_dequant.dim() == 0, "fc1 input dequant must be a scalar tensor");
+ // Check shapes
TORCH_CHECK(
fc1_dequant.sizes()[0] == num_experts_on_rank, "fc1 dequant size must be (num_experts_on_rank,)");
+ TORCH_CHECK(fc2_quant.dim() == 0 || fc2_quant.sizes()[0] == num_experts_on_rank,
+ "fc2 quant must be scalar or (num_experts_on_rank,)");
TORCH_CHECK(
fc2_dequant.sizes()[0] == num_experts_on_rank, "fc2 dequant size must be (num_experts_on_rank,)");

return kernels::QuantParams::FP8(static_cast<float const*>(fc1_dequant.data_ptr()),
static_cast<float const*>(fc2_quant.data_ptr()), static_cast<float const*>(fc2_dequant.data_ptr()),
- /* fp8 output quant scale */ nullptr, static_cast<float const*>(fc1_input_dequant.data_ptr()));
+ /* fp8 output quant scale */ nullptr, static_cast<float const*>(fc1_input_dequant.data_ptr()),
+ fc2_quant.dim() == 1);
+ }
+ else if (isWFp4AFp8Quant())
+ {
+ TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for WFP4AFP8 quantization");
+ TORCH_CHECK(quant_scales.value().size() == 5, "Expecting 5 quant scales for WFP4AFP8 quantization");

+ auto const fc1_weight_block = quant_scales.value()[0];
+ auto const fc1_global = quant_scales.value()[1];
+ auto const fc2_act_global = quant_scales.value()[2];
+ auto const fc2_weight_block = quant_scales.value()[3];
+ auto const fc2_global = quant_scales.value()[4];

+ // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
+ constexpr int FP8_PER_INT32 = 4;
+ // Check types
+ CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int);
+ CHECK_INPUT(fc1_global, c10::ScalarType::Float);
+ CHECK_INPUT(fc2_act_global, c10::ScalarType::Float);
+ CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int);
+ CHECK_INPUT(fc2_global, c10::ScalarType::Float);
+ // Check ranks
+ TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be #D");
+ TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D");
+ TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.dim() == 1,
+ "fc2 act global must be a scalar or 1-D tensor");
+ TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
+ TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
+ // Check shapes
+ TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
+ && fc1_weight_block.sizes()[1] == inter_size * 2
+ && fc1_weight_block.sizes()[2] * FP8_PER_INT32
+ * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
+ == hidden_size,
+ "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+ "block_scale_vector_size)");
+ TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
+ TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
+ "fc2 act global must be scalar or (num_experts_on_rank,)");
+ TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size
+ && fc2_weight_block.sizes()[2] * FP8_PER_INT32
+ * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
+ == inter_size,
+ "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // "
+ "block_scale_vector_size)");
+ TORCH_CHECK(fc2_global.sizes()[0] == num_experts_on_rank, "fc2 global size must be (num_experts_on_rank,)");

+ return kernels::QuantParams::FP8MXFP4(nullptr,
+ static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()),
+ static_cast<float const*>(fc1_global.data_ptr()), static_cast<float const*>(fc2_act_global.data_ptr()),
+ static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()),
+ static_cast<float const*>(fc2_global.data_ptr()), false, fc2_act_global.dim() == 1);
}
else if (isNvfp4Quant())
{
@ -606,18 +721,25 @@ private:
// The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
constexpr int FP8_PER_INT32 = 4;
+ // Check types
CHECK_INPUT(fc1_act_global, c10::ScalarType::Float);
CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int);
CHECK_INPUT(fc1_global, c10::ScalarType::Float);
CHECK_INPUT(fc2_act_global, c10::ScalarType::Float);
CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int);
CHECK_INPUT(fc2_global, c10::ScalarType::Float);
- TORCH_CHECK(fc1_act_global.dim() == 0, "fc1 act global must be a scalar tensor");
+ // Check ranks
+ TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.dim() == 1,
+ "fc1 act global must be a scalar or 1-D tensor");
TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be #D");
TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D");
- TORCH_CHECK(fc2_act_global.dim() == 0, "fc2 act global must be a scalar tensor");
+ TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.dim() == 1,
+ "fc2 act global must be a scalar or 1-D tensor");
TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
+ // Check shapes
+ TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.sizes()[0] == num_experts_on_rank,
+ "fc1 act global must be scalar or (num_experts_on_rank,)");
TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
&& fc1_weight_block.sizes()[1] == inter_size * 2
&& fc1_weight_block.sizes()[2] * FP8_PER_INT32
@ -626,6 +748,8 @@ private:
"fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
"block_scale_vector_size)");
TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
+ TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank,
+ "fc2 act global must be scalar or (num_experts_on_rank,)");
TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size
&& fc2_weight_block.sizes()[2] * FP8_PER_INT32
* TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
@ -638,7 +762,7 @@ private:
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()),
static_cast<float const*>(fc1_global.data_ptr()), static_cast<float const*>(fc2_act_global.data_ptr()),
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()),
- static_cast<float const*>(fc2_global.data_ptr()));
+ static_cast<float const*>(fc2_global.data_ptr()), fc1_act_global.dim() == 1, fc2_act_global.dim() == 1);
}
else if (mUseDeepSeekFP8BlockScaling)
{
@ -683,7 +807,8 @@ private:

bool isNvfp4Quant() const
{
- return mWeightDtype == c10::ScalarType::Long;
+ return mWeightDtype == c10::ScalarType::Long
+ && mActivationDtype != c10::ScalarType::Float8_e4m3fn; // FP8 activation does not use FP4
}

bool isInt4Quant() const
@ -695,6 +820,11 @@ private:
{
return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
}

+ bool isWFp4AFp8Quant() const
+ {
+ return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long;
+ }
};

} // namespace torch_ext
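Note on the FP8 path above: fc2's activation quant scale may now be either a scalar or a per-expert 1-D tensor. A minimal sketch of a quant_scales list that would satisfy those checks follows; the sizes and placeholder values are illustrative assumptions, not taken from this patch, and a CUDA build of PyTorch is assumed.

# Illustrative FP8 quant_scales layout for the checks above (assumed shapes and values).
import torch

num_experts_on_rank = 8  # assumed for the sketch
fc1_dequant = torch.rand(num_experts_on_rank, dtype=torch.float32, device="cuda")
fc2_quant = torch.rand(num_experts_on_rank, dtype=torch.float32, device="cuda")  # per-expert; a 0-D scalar tensor is also accepted
fc2_dequant = torch.rand(num_experts_on_rank, dtype=torch.float32, device="cuda")
fc1_input_dequant = torch.tensor(1.0, dtype=torch.float32, device="cuda")
quant_scales = [fc1_dequant, fc2_quant, fc2_dequant, fc1_input_dequant]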
cpp/tensorrt_llm/thop/moeUtilOp.cpp (new file, 449 lines)
@ -0,0 +1,449 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/kernels/moeUtilOp.h"
#include "moe_gemm_kernels.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/native/cuda/Resize.h>

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace common = tensorrt_llm::common;
namespace kernels = tensorrt_llm::kernels;
namespace cutlass_kernels = tensorrt_llm::kernels::cutlass_kernels;

namespace torch_ext
{

// input_activations: [num_tokens, hidden_size]
// input: token_topk_unpermuted_scales, [num_tokens, k]
// output: permuted_data_, [num_token * k, hidden_size]
// output: permuted_token_final_scales_, [num_tokens, k]
template <typename T>
void runPermute(void const* input_activations_void, void const* input_sf_void, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void,
tensorrt_llm::ActivationType fc1_activation_type, void const* fc2_expert_weights_void,
void const* fc2_expert_biases_void, cutlass_kernels::QuantParams quant_params, int64_t const num_rows,
int64_t const hidden_size, int const full_num_experts, int const experts_per_token,
int* unpermuted_token_selected_experts_, int* unpermuted_source_token_ids_, int* permuted_source_token_ids_,
int* permuted_token_selected_experts_, T* permuted_data_, char* sorter_ws_, int64_t* expert_first_token_offset_,
float* permuted_token_final_scales_, int* expanded_source_row_to_expanded_dest_row,
cutlass_kernels::MOEParallelismConfig parallelism_config, kernels::CubKeyValueSorter sorter_, bool use_lora,
kernels::LoraParams& lora_params, bool use_fp8_block_scaling, bool min_latency_mode,
cutlass_kernels::MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(experts_per_token * full_num_experts <= std::numeric_limits<int>::max(),
"experts_per_token * num_experts is too large");

auto const* input_activations = static_cast<T const*>(input_activations_void);
auto const* input_sf = input_sf_void
? reinterpret_cast<tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::ElementSF const*>(input_sf_void)
: nullptr;
int const num_experts_per_node = full_num_experts / parallelism_config.ep_size;
int start_expert = num_experts_per_node * parallelism_config.ep_rank;
int end_expert = start_expert + num_experts_per_node;

bool const needs_num_valid = parallelism_config.ep_size > 1;
// Note: expert_first_token_offset_[num_experts_per_node] stores the total number of expanded tokens
int64_t const* num_valid_tokens_ptr = needs_num_valid ? expert_first_token_offset_ + num_experts_per_node : nullptr;

bool use_w4afp8 = false;
bool fused_prologue_result = false;
if (!use_w4afp8)
{
// WAR: fusedBuildExpertMapsSortFirstToken kernel will lead to illegal memory access for W4AFP8
// input: token_selected_experts, [num_tokens, k]
// output: unpermuted_token_selected_experts_, [num_tokens, k]
// output: permuted_source_token_ids_, [num_tokens, k]
// output: expert_first_token_offset_, [num_experts_per_node + 1]
fused_prologue_result = kernels::fusedBuildExpertMapsSortFirstToken(token_selected_experts,
unpermuted_token_selected_experts_, permuted_source_token_ids_, expert_first_token_offset_, num_rows,
num_experts_per_node, experts_per_token, start_expert, end_expert, stream);
}
if (!fused_prologue_result)
{
TLLM_LOG_TRACE("Falling back to unfused prologue");
kernels::buildExpertMaps(token_selected_experts, unpermuted_token_selected_experts_,
unpermuted_source_token_ids_, num_rows, num_experts_per_node, experts_per_token, start_expert, end_expert,
stream);
sync_check_cuda_error(stream);

kernels::generateTokenPermutation(unpermuted_token_selected_experts_, unpermuted_source_token_ids_,
permuted_token_selected_experts_, permuted_source_token_ids_, expert_first_token_offset_, num_rows,
num_experts_per_node, experts_per_token, sorter_, static_cast<void*>(sorter_ws_), stream);
}
sync_check_cuda_error(stream);

// using ExpandedActivationsType = std::conditional_t<use_w4afp8, BackBoneType, T>;
using ExpandedActivationsType = T;
// input_activations: [num_tokens, hidden_size]
// output: permuted_data_, [num_token * k, hidden_size]
// input: token_topk_unpermuted_scales, [num_tokens, k]
// output: permuted_token_final_scales_, [num_tokens * k]
// input: permuted_source_token_ids_, [num_tokens, k]
// output: expanded_source_row_to_expanded_dest_row, [num_tokens, k]
float const* token_topk_unpermuted_scales = token_final_scales;
kernels::expandInputRowsKernelLauncher(input_activations,
reinterpret_cast<ExpandedActivationsType*>(permuted_data_), token_topk_unpermuted_scales,
permuted_token_final_scales_, permuted_source_token_ids_, expanded_source_row_to_expanded_dest_row, num_rows,
num_valid_tokens_ptr, hidden_size, experts_per_token, num_experts_per_node,
quant_params.fp4.fc1.act_global_scale, expert_first_token_offset_,
/* fc1_fp4_act_scale_ */ nullptr, input_sf, stream);
sync_check_cuda_error(stream);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
moe_permute_op(torch::Tensor const& input, torch::Tensor const& token_selected_experts,
torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& fc1_expert_weights,
torch::Tensor const& fc2_expert_weights, torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales,
torch::optional<torch::Tensor> input_sf, int64_t const num_experts_on_rank, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool min_latency_mode, bool use_fp8_block_scaling)
{
kernels::CubKeyValueSorter sorter_;

TORCH_CHECK(cluster_size == 1 && cluster_rank == 0, "smart_router is supported in min_latency mode");
TORCH_CHECK(min_latency_mode == false, "min_latency_mode is not supported now");

CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
if (token_final_scales)
{
CHECK_INPUT(token_final_scales.value(), at::ScalarType::Float)
}

TORCH_CHECK(input.dim() == 2, "input must be 2D.");
TORCH_CHECK(token_selected_experts.dim() == 2, "token_selected_experts must be 2D.");

TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
"input and token_selected_experts must have the same num tokens.");
if (token_final_scales)
{
TORCH_CHECK(token_final_scales.value().dim() == 2, "token_selected_experts_probs must be 2D.");
TORCH_CHECK(input.sizes()[0] == token_final_scales.value().sizes()[0],
"input and token_selected_experts_probs must have the same num tokens.");
TORCH_CHECK(token_selected_experts.sizes()[1] == token_final_scales.value().sizes()[1],
"token_selected_experts and token_final_scales must have the same number of experts per token.");
}

int experts_per_token = token_selected_experts.sizes()[1];
int64_t num_rows = input.sizes()[0];
int64_t hidden_size = input.sizes()[1];
auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
auto parallelism_config = cutlass_kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
auto activation_type = tensorrt_llm::ActivationType::Swiglu;

int const num_experts_per_node = num_experts_on_rank;
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int64_t num_moe_inputs = static_cast<int64_t>(experts_per_token * num_rows);

auto unpermuted_token_selected_experts_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

auto unpermuted_source_token_ids_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

auto permuted_source_token_ids_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

auto permuted_token_selected_experts_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

auto permuted_data_tensor = torch::empty({num_moe_inputs, hidden_size}, input.options().requires_grad(false));

auto permuted_token_final_scales_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));

auto expert_first_token_offset_tensor = torch::empty(
{num_experts_per_node + 1}, torch::dtype(torch::kInt64).device(torch::kCUDA).requires_grad(false));

size_t const sorter_size = min_latency_mode
? 0
: kernels::CubKeyValueSorter::getWorkspaceSize(num_rows * experts_per_token, num_experts_per_node);
auto sorter_ws_tensor = torch::empty(
{static_cast<int64_t>(sorter_size)}, torch::dtype(torch::kChar).device(torch::kCUDA).requires_grad(false));

auto src_to_dest_map_tensor = torch::empty({static_cast<int64_t>(experts_per_token * num_rows)},
torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

cutlass_kernels::QuantParams quant_params{};
cutlass_kernels::MoeMinLatencyParams min_latency_params{};

kernels::LoraParams lora_params{};

auto data_type = input.scalar_type();
switch (data_type)
{
case torch::kFloat32:
runPermute<float>(input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
/*fc1_expert_weights.const_data_ptr()*/ nullptr, nullptr, activation_type,
/*fc2_expert_weights.const_data_ptr()*/ nullptr, nullptr, quant_params, num_rows, hidden_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<int*>(unpermuted_token_selected_experts_tensor.data_ptr()),
static_cast<int*>(unpermuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_token_selected_experts_tensor.data_ptr()),
static_cast<float*>(permuted_data_tensor.data_ptr()), static_cast<char*>(sorter_ws_tensor.data_ptr()),
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int*>(src_to_dest_map_tensor.data_ptr()), parallelism_config, sorter_, false, lora_params,
use_fp8_block_scaling, min_latency_mode, min_latency_params, stream);
break;
case torch::kBFloat16:
runPermute<__nv_bfloat16>(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
/*fc1_expert_weights.const_data_ptr()*/ nullptr, nullptr, activation_type,
/*fc2_expert_weights.const_data_ptr()*/ nullptr, nullptr, quant_params, num_rows, hidden_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<int*>(unpermuted_token_selected_experts_tensor.data_ptr()),
static_cast<int*>(unpermuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_token_selected_experts_tensor.data_ptr()),
static_cast<__nv_bfloat16*>(permuted_data_tensor.data_ptr()),
static_cast<char*>(sorter_ws_tensor.data_ptr()),
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int*>(src_to_dest_map_tensor.data_ptr()), parallelism_config, sorter_, false, lora_params,
use_fp8_block_scaling, min_latency_mode, min_latency_params, stream);
break;
case torch::kHalf:
runPermute<half>(input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
/*fc1_expert_weights.const_data_ptr()*/ nullptr, nullptr, activation_type,
/*fc2_expert_weights.const_data_ptr()*/ nullptr, nullptr, quant_params, num_rows, hidden_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<int*>(unpermuted_token_selected_experts_tensor.data_ptr()),
static_cast<int*>(unpermuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_token_selected_experts_tensor.data_ptr()),
static_cast<half*>(permuted_data_tensor.data_ptr()), static_cast<char*>(sorter_ws_tensor.data_ptr()),
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int*>(src_to_dest_map_tensor.data_ptr()), parallelism_config, sorter_, false, lora_params,
use_fp8_block_scaling, min_latency_mode, min_latency_params, stream);
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports input tensor with float32, float16 and bfloat16 dtype");
break;
}
return std::make_tuple(unpermuted_token_selected_experts_tensor, unpermuted_source_token_ids_tensor,
permuted_source_token_ids_tensor, permuted_token_selected_experts_tensor, permuted_data_tensor,
expert_first_token_offset_tensor, permuted_token_final_scales_tensor, src_to_dest_map_tensor);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> run_moe_expand_op(torch::Tensor const& input,
torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& permuted_source_token_ids,
int64_t const num_rows, torch::Tensor& expert_first_token_offset_tensor, int64_t const hidden_size,
int64_t const experts_per_token, int64_t const num_experts_per_node, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, bool use_fp8_block_scaling)
{
auto parallelism_config = cutlass_kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);

bool const needs_num_valid = parallelism_config.ep_size > 1;
int64_t const* num_valid_tokens_ptr = needs_num_valid
? static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()) + num_experts_per_node
: nullptr;

int64_t num_moe_inputs = static_cast<int64_t>(experts_per_token * num_rows);
auto permuted_data_tensor = torch::empty({num_moe_inputs, hidden_size}, input.options().requires_grad(false));
auto permuted_token_final_scales_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
auto expanded_source_row_to_expanded_dest_row
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
cutlass_kernels::QuantParams quant_params{};

float const* token_topk_unpermuted_scales = token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr;
auto data_type = input.scalar_type();
switch (data_type)
{
case torch::kFloat32:
kernels::expandInputRowsKernelLauncher<float, float>(static_cast<float const*>(input.const_data_ptr()),
reinterpret_cast<float*>(permuted_data_tensor.data_ptr()), token_topk_unpermuted_scales,
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int const*>(permuted_source_token_ids.const_data_ptr()),
static_cast<int*>(expanded_source_row_to_expanded_dest_row.data_ptr()), num_rows, num_valid_tokens_ptr,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
/* fc1_fp4_act_scale_ */ nullptr, /*input_sf*/ nullptr, stream);
break;
case torch::kBFloat16:
kernels::expandInputRowsKernelLauncher<__nv_bfloat16, __nv_bfloat16>(
static_cast<__nv_bfloat16 const*>(input.const_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(permuted_data_tensor.data_ptr()), token_topk_unpermuted_scales,
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int const*>(permuted_source_token_ids.const_data_ptr()),
static_cast<int*>(expanded_source_row_to_expanded_dest_row.data_ptr()), num_rows, num_valid_tokens_ptr,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
/* fc1_fp4_act_scale_ */ nullptr, /*input_sf*/ nullptr, stream);
break;
case torch::kHalf:
kernels::expandInputRowsKernelLauncher<half, half>(static_cast<half const*>(input.const_data_ptr()),
reinterpret_cast<half*>(permuted_data_tensor.data_ptr()), token_topk_unpermuted_scales,
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int const*>(permuted_source_token_ids.const_data_ptr()),
static_cast<int*>(expanded_source_row_to_expanded_dest_row.data_ptr()), num_rows, num_valid_tokens_ptr,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
/* fc1_fp4_act_scale_ */ nullptr, /*input_sf*/ nullptr, stream);
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports input tensor with float32, float16 and bfloat16 dtype");
break;
}
return std::make_tuple(
permuted_data_tensor, permuted_token_final_scales_tensor, expanded_source_row_to_expanded_dest_row);
}

template <class UnfusedGemmOutputType, class ScaleBiasType, class OutputType>
void runMoEFinalizeScaleOp(UnfusedGemmOutputType const* const gemm2_output,
ScaleBiasType const* const fc2_expert_biases, float const* const unpermuted_final_scales,
int const* const expanded_source_row_to_expanded_dest_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, /*int64_t const expanded_num_rows,*/
int64_t const hidden_size, /*int64_t const inter_size, int const num_experts_per_node,*/
int64_t const experts_per_token, cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream,
OutputType* const final_output)
{
kernels::finalizeMoeRoutingKernelLauncher<OutputType, UnfusedGemmOutputType>(
static_cast<UnfusedGemmOutputType const*>(gemm2_output), final_output, fc2_expert_biases,
unpermuted_final_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, hidden_size,
experts_per_token, num_valid_tokens_ptr, parallelism_config, stream);
}

torch::Tensor run_moe_finalize_scale_op(torch::Tensor const& gemm2_output, torch::Tensor const& fc2_expert_biases,
torch::Tensor const& unpermuted_final_scales, torch::Tensor const& expanded_source_row_to_expanded_dest_row,
torch::Tensor const& expert_for_source_row, torch::Tensor const& expert_first_token_offset_tensor,
c10::SymInt num_rows_param, c10::SymInt hidden_size_param, int64_t const experts_per_token,
int64_t const num_experts_per_node, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
int64_t const ep_rank)
{
int64_t num_rows = num_rows_param.guard_int(__FILE__, __LINE__);
int64_t hidden_size = hidden_size_param.guard_int(__FILE__, __LINE__);

TORCH_CHECK(gemm2_output.dim() == 2, "gemm2_output must be 2D.");
TORCH_CHECK(unpermuted_final_scales.dim() == 2, "unpermuted_final_scales must be 2D.");
TORCH_CHECK(
expanded_source_row_to_expanded_dest_row.dim() == 1, "expanded_source_row_to_expanded_dest_row must be 1D.");
TORCH_CHECK(expert_for_source_row.dim() == 1, "expert_for_source_row must be 1D.");
TORCH_CHECK(expert_first_token_offset_tensor.dim() == 1, "expert_first_token_offset_tensor must be 1D.");

TORCH_CHECK(gemm2_output.sizes()[0] == expert_for_source_row.sizes()[0],
"gemm2_output and expert_for_source_row must have the same expanded num tokens.");
TORCH_CHECK(unpermuted_final_scales.sizes()[0] == num_rows, "unpermuted_final_scales[0] should equal to num_rows.");
TORCH_CHECK(unpermuted_final_scales.sizes()[1] == experts_per_token,
"unpermuted_final_scales[1] should equal to experts_per_token.");
TORCH_CHECK(expert_for_source_row.sizes()[0] == gemm2_output.sizes()[0],
"expert_for_source_row and gemm2_output must have the same expanded num tokens.");
TORCH_CHECK(expert_first_token_offset_tensor.sizes()[0] == num_experts_per_node + 1,
"expert_first_token_offset_tensor[0] should equal to num_experts_per_node + 1.");

auto parallelism_config = cutlass_kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);

bool const needs_num_valid = parallelism_config.ep_size > 1;
int64_t const* num_valid_tokens_ptr = needs_num_valid
? static_cast<int64_t const*>(expert_first_token_offset_tensor.const_data_ptr()) + num_experts_per_node
: nullptr;

auto final_output = torch::empty({num_rows, hidden_size}, gemm2_output.options());

auto stream = at::cuda::getCurrentCUDAStream(gemm2_output.get_device());
auto data_type = gemm2_output.scalar_type();
switch (data_type)
{
case torch::kFloat32:
runMoEFinalizeScaleOp<float, float, float>(static_cast<float const*>(gemm2_output.const_data_ptr()),
// static_cast<float const*>(fc2_expert_biases.const_data_ptr()),
nullptr, static_cast<float const*>(unpermuted_final_scales.const_data_ptr()),
static_cast<int const*>(expanded_source_row_to_expanded_dest_row.const_data_ptr()),
static_cast<int const*>(expert_for_source_row.const_data_ptr()), num_valid_tokens_ptr, num_rows,
hidden_size, experts_per_token, parallelism_config, stream, static_cast<float*>(final_output.data_ptr()));
break;
case torch::kBFloat16:
runMoEFinalizeScaleOp<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>(
static_cast<__nv_bfloat16 const*>(gemm2_output.const_data_ptr()),
// static_cast<__nv_bfloat16 const*>(fc2_expert_biases.const_data_ptr()),
nullptr, static_cast<float const*>(unpermuted_final_scales.const_data_ptr()),
static_cast<int const*>(expanded_source_row_to_expanded_dest_row.const_data_ptr()),
static_cast<int const*>(expert_for_source_row.const_data_ptr()), num_valid_tokens_ptr, num_rows,
hidden_size, experts_per_token, parallelism_config, stream,
static_cast<__nv_bfloat16*>(final_output.data_ptr()));
break;
case torch::kHalf:
runMoEFinalizeScaleOp<half, half, half>(static_cast<half const*>(gemm2_output.const_data_ptr()),
// static_cast<half const*>(fc2_expert_biases.const_data_ptr()),
nullptr, static_cast<float const*>(unpermuted_final_scales.const_data_ptr()),
static_cast<int const*>(expanded_source_row_to_expanded_dest_row.const_data_ptr()),
static_cast<int const*>(expert_for_source_row.const_data_ptr()), num_valid_tokens_ptr, num_rows,
hidden_size, experts_per_token, parallelism_config, stream, static_cast<half*>(final_output.data_ptr()));
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports input tensor with float32, float16 and bfloat16 dtype");
break;
}
return final_output;
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_permute_op(Tensor input, Tensor token_selected_experts, Tensor? token_final_scales, Tensor "
"fc1_expert_weights, Tensor fc2_expert_weights, Tensor[]? quant_scales, Tensor? input_sf, int "
"num_experts_on_rank, int tp_size, int tp_rank, int ep_size, int ep_rank, int cluster_size, int cluster_rank, "
"bool min_latency_mode, bool use_fp8_block_scaling)"
"-> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def(
"moe_finalize_scale_op(Tensor gemm2_output, Tensor fc2_expert_biases, Tensor unpermuted_final_scales, Tensor "
"expanded_source_row_to_expanded_dest_row, Tensor expert_for_source_row, Tensor "
"expert_first_token_offset_tensor, SymInt num_rows, SymInt hidden_size, int experts_per_token, int "
"num_experts_per_node, int tp_size, int tp_rank, int ep_size, int ep_rank)"
"-> (Tensor)");
m.def(
"moe_expand_op(Tensor input, Tensor? token_final_scales, Tensor permuted_source_token_ids, int num_rows, "
"Tensor expert_first_token_offset_tensor, int hidden_size, int experts_per_token, int num_experts_per_node, "
"int tp_size, int tp_rank, int ep_size, int ep_rank, bool use_fp8_block_scaling)"
"-> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_permute_op", &torch_ext::moe_permute_op);
m.impl("moe_finalize_scale_op", &torch_ext::run_moe_finalize_scale_op);
m.impl("moe_expand_op", &torch_ext::run_moe_expand_op);
}
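For orientation, a rough sketch of driving the ops registered above from Python. It assumes the trtllm extension is loaded and a CUDA device is available; the shapes, dtypes, and parallelism values are made up, and the grouped FC1/FC2 GEMMs between permute and finalize are omitted, so this is a shape-level illustration rather than the library's actual call sequence.

# Hypothetical, shape-level usage of the ops registered above.
import torch

num_tokens, hidden_size, top_k, num_experts = 4, 64, 2, 8
x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
experts = torch.randint(0, num_experts, (num_tokens, top_k), device="cuda", dtype=torch.int32)
scales = torch.rand(num_tokens, top_k, device="cuda", dtype=torch.float32)
dummy_w1 = torch.empty(0, device="cuda", dtype=torch.float16)  # weights are not touched by the permute step
dummy_w2 = torch.empty(0, device="cuda", dtype=torch.float16)

(unpermuted_experts, _, permuted_src_ids, _, permuted_data, expert_offsets, permuted_scales,
 src_to_dest) = torch.ops.trtllm.moe_permute_op(
    x, experts, scales, dummy_w1, dummy_w2, None, None, num_experts,
    1, 0, 1, 0, 1, 0, False, False)

# ... grouped FC1/FC2 GEMMs would produce the per-expert outputs here ...
gemm2_output = permuted_data  # stand-in with the right expanded shape

out = torch.ops.trtllm.moe_finalize_scale_op(
    gemm2_output, dummy_w2, scales, src_to_dest, unpermuted_experts, expert_offsets,
    num_tokens, hidden_size, top_k, num_experts, 1, 0, 1, 0)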
@ -306,6 +306,9 @@ protected:
bool mUseLora = false;
bool mUsePrequantScale = false;

+ // Run tests with per-expert act scale
+ bool mUsePerExpertActScale = true;

bool mIsGated = false;
int64_t mGatedMultiplier = 1;
int64_t mGroupSize = -1;
@ -480,12 +483,12 @@ protected:
{
// FP4 uses the same logic as FP8 to generate the global scales
mExpertFPXScale1 = allocBuffer<float>(mNumExperts);
- mExpertFPXScale2 = allocBuffer<float>(1);
+ mExpertFPXScale2 = allocBuffer<float>(mNumExperts); // mNumExperts or 1
mExpertFPXScale3 = allocBuffer<float>(mNumExperts);

if (ANY_FP4)
{
- mExpertFP4ActGlobalScale1 = allocBuffer<float>(1);
+ mExpertFP4ActGlobalScale1 = allocBuffer<float>(mNumExperts); // mNumExperts or 1
mExpertFP4WeightGlobalScale1 = allocBuffer<float>(mNumExperts);
mExpertFP4WeightGlobalScale2 = allocBuffer<float>(mNumExperts);
}
@ -665,23 +668,37 @@ protected:
float scaleAct1 = getFPXActScalar(max_input);

float maxFC1Output = calcMLPVal(max_input, maxIndex) / maxW2;
- float scaleAct2 = getFPXActScalar(maxFC1Output);
+ std::vector<float> scales_1;
+ std::vector<float> scales_2;
+ std::vector<float> scales_3;
+ if (mUsePerExpertActScale)
+ {
+ scales_2 = std::vector<float>(mNumExperts);
+ for (int i = 0; i < mNumExperts; i++)
+ {
+ float maxExpertOutput = calcMLPVal(max_input, i) / applyExpertShift(mExpertWDiag2, i);
+ float scaleAct2 = getFPXActScalar(maxExpertOutput);
+ scales_2[i] = scaleAct2;
+ }
+ }
+ else
+ {
+ float scaleAct2 = getFPXActScalar(maxFC1Output);
+ scales_2 = std::vector<float>(mNumExperts, scaleAct2);
+ }

ASSERT_NE(mExpertFPXScale1, nullptr);
ASSERT_NE(mExpertFPXScale2, nullptr);
ASSERT_NE(mExpertFPXScale3, nullptr);

- std::vector<float> scales_1;
- std::vector<float> scales_2;
- std::vector<float> scales_3;
if (ANY_FP4)
{
std::vector<float> scale_global_w1(mNumExperts);
std::vector<float> scale_global_w2(mNumExperts);

- std::vector<float> scales_0(1, scaleAct1);
+ std::vector<float> scales_0(mUsePerExpertActScale && NVFP4 ? mNumExperts : 1, scaleAct1);
scales_1 = std::vector<float>(mNumExperts);
- scales_2 = std::vector<float>(1, scaleAct2);
scales_3 = std::vector<float>(mNumExperts);

for (int i = 0; i < mNumExperts; i++)
@ -695,7 +712,7 @@ protected:

// TODO Per expert scaling factors
scales_1[i] = 1.f / (scaleAct1 * scaleW1);
- scales_3[i] = 1.f / (scaleAct2 * scaleW2);
+ scales_3[i] = 1.f / (scales_2[i] * scaleW2);
}

ASSERT_NE(mExpertFP4ActGlobalScale1, nullptr);
@ -713,8 +730,17 @@ protected:
mFP8WeightScalar1 = scaleW1;
mFP8WeightScalar2 = scaleW2;
scales_1 = std::vector<float>(mNumExperts, 1.f / (scaleW1 * scaleAct1));
- scales_2 = std::vector<float>(1, scaleAct2);
- scales_3 = std::vector<float>(mNumExperts, 1.f / (scaleW2 * scaleAct2));
+ scales_3 = std::vector<float>(mNumExperts);

+ for (int i = 0; i < mNumExperts; i++)
+ {
+ scales_3[i] = 1.f / (scaleW2 * scales_2[i]);
+ }
+ }

+ if (!mUsePerExpertActScale)
+ {
+ scales_2.resize(1);
}

check_cuda_error(cudaMemcpyAsync(mExpertFPXScale1, scales_1.data(), scales_1.size() * sizeof(float),
@ -893,6 +919,10 @@ protected:
ep_scale_1 = mExpertFPXScale1 + experts_per_node * parallelism_config.ep_rank;
ep_scale_3 = mExpertFPXScale3 + experts_per_node * parallelism_config.ep_rank;
}
+ if (mUsePerExpertActScale)
+ {
+ ep_scale_2 = mExpertFPXScale2 + experts_per_node * parallelism_config.ep_rank;
+ }

// Slice weights for TP
void* scale_1 = ep_scale_1;
@ -1039,18 +1069,22 @@ protected:
else if (FP8)
{
ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr);
- quant_params = QuantParams::FP8(static_cast<float const*>(scale1_ptr),
- static_cast<float const*>(scale2_ptr), static_cast<float const*>(scale3_ptr));
+ quant_params
+ = QuantParams::FP8(static_cast<float const*>(scale1_ptr), static_cast<float const*>(scale2_ptr),
+ static_cast<float const*>(scale3_ptr), nullptr, nullptr, mUsePerExpertActScale);
}
else if (ANY_FP4)
{
ASSERT_TRUE(mExpertFP4ActGlobalScale1);
ASSERT_TRUE(mFP4ScalingFactorsW1 && mFP4ScalingFactorsW2);
ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr);
+ auto fc1_sf_offset = mUsePerExpertActScale && NVFP4
+ ? mNumExperts / parallelism_config.ep_size * parallelism_config.ep_rank
+ : 0;
auto constructor = NVFP4 ? &QuantParams::FP4 : &QuantParams::FP8MXFP4;
- quant_params
- = constructor(mExpertFP4ActGlobalScale1, mFP4ScalingFactorsW1, static_cast<float const*>(scale1_ptr),
- static_cast<float const*>(scale2_ptr), mFP4ScalingFactorsW2, static_cast<float const*>(scale3_ptr));
+ quant_params = constructor(mExpertFP4ActGlobalScale1 + fc1_sf_offset, mFP4ScalingFactorsW1,
+ static_cast<float const*>(scale1_ptr), static_cast<float const*>(scale2_ptr), mFP4ScalingFactorsW2,
+ static_cast<float const*>(scale3_ptr), mUsePerExpertActScale && NVFP4, mUsePerExpertActScale);
}

if constexpr (WEIGHT_FP4)
@ -1497,6 +1531,19 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteNoBias)
this->BasicPermuteTest(3);
}

+ TYPED_TEST(MixtureOfExpertsTest, PermuteSingletonScale)
+ {
+ if (!this->ANY_FPX)
+ {
+ GTEST_SKIP() << "Only FPX cares about per-expert act scale";
+ return;
+ }
+ this->mUsePerExpertActScale = false;
+ this->BasicPermuteTest(1);
+ this->BasicPermuteTest(2);
+ this->BasicPermuteTest(3);
+ }

TYPED_TEST(MixtureOfExpertsTest, PermuteGelu)
{
this->mActType = ActivationType::Gelu;
@ -2071,7 +2118,7 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)
#ifdef USING_OSS_CUTLASS_MOE_GEMM
backend.init(this->mMoERunner, GemmProfilerBackend::GemmToProfile::GEMM_1, nvinfer1::DataType::kHALF,
nvinfer1::DataType::kHALF, nvinfer1::DataType::kHALF, num_experts, k, 1024, 4096, mGroupSize, {}, false,
- mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, MOEParallelismConfig{1, 0, ep, ep - 1},
+ mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, MOEParallelismConfig{1, 0, ep, 0},
/*enable_alltoall=*/false);
#else
backend.init(this->mMoERunner, GemmProfilerBackend::GemmToProfile::GEMM_1, nvinfer1::DataType::kHALF,
@ -2089,34 +2136,47 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)
#define GET_WS_PTR(type, name) auto* name = reinterpret_cast<type>(workspace + workspaces.at(#name).second)

GET_WS_PTR(int64_t*, expert_first_token_offset);
- GET_WS_PTR(int*, source_to_dest);
- GET_WS_PTR(int*, dest_to_source);
+ GET_WS_PTR(int*, unpermuted_row_to_permuted_row);
+ GET_WS_PTR(int*, permuted_row_to_unpermuted_row);
+ #ifdef USING_OSS_CUTLASS_MOE_GEMM
+ GET_WS_PTR(int*, token_selected_experts);
+ #else
GET_WS_PTR(int*, unpermuted_selected_experts);
+ #endif
#undef GET_WS_PTR

for (int sample = 0; sample < backend.NUM_ROUTING_SAMPLES; sample++)
{
auto host_expert_first_token_offset_size = getDataFromDevice(
expert_first_token_offset + sample * (num_experts_per_node + 1), num_experts_per_node + 1);
- auto host_source_to_dest_map
- = getDataFromDevice(source_to_dest + sample * expanded_num_tokens, expanded_num_tokens);
- auto host_dest_to_source_map
- = getDataFromDevice(dest_to_source + sample * expanded_num_tokens, expanded_num_tokens);
+ auto host_unpermuted_row_to_permuted_row_map = getDataFromDevice(
+ unpermuted_row_to_permuted_row + sample * expanded_num_tokens, expanded_num_tokens);
+ auto host_permuted_row_to_unpermuted_row_map = getDataFromDevice(
+ permuted_row_to_unpermuted_row + sample * expanded_num_tokens, expanded_num_tokens);
+ #ifdef USING_OSS_CUTLASS_MOE_GEMM
+ auto host_token_selected_experts
+ = getDataFromDevice(token_selected_experts + sample * expanded_num_tokens, expanded_num_tokens);
+ #else
auto host_token_selected_experts = getDataFromDevice(
unpermuted_selected_experts + sample * expanded_num_tokens, expanded_num_tokens);
+ #endif

std::vector<int64_t> calculated_routing_values(num_experts_per_node + 1, 0);
int skipped = 0;
for (auto v : host_token_selected_experts)
{
+ #ifndef USING_OSS_CUTLASS_MOE_GEMM
ASSERT_TRUE(v < num_experts_per_node || (v == num_experts_per_node && ep > 1))
<< "v " << v << " num_experts_per_node " << num_experts_per_node << " ep " << ep;
- skipped += (v == num_experts_per_node);
+ #endif
if (v < num_experts_per_node)
{
calculated_routing_values[v]++;
}
+ else
+ {
+ skipped++;
+ }
}

if (num_tokens > 1)
@ -2159,14 +2219,18 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)
int64_t idx = token_idx * k + k_idx;
int64_t expert_idx = host_token_selected_experts[idx];

+ #ifdef USING_OSS_CUTLASS_MOE_GEMM
+ if (expert_idx < num_experts_per_node)
+ #else
if (expert_idx < num_experts)
+ #endif
{
- int64_t source_location = k_idx * num_tokens + token_idx;
- int64_t dest_location = host_expert_first_token_offset_size[expert_idx]
+ int64_t unpermuted_row = k_idx * num_tokens + token_idx;
+ int64_t permuted_row = host_expert_first_token_offset_size[expert_idx]
+ calculated_routing_values[expert_idx];

- ASSERT_EQ(host_source_to_dest_map[source_location], dest_location);
- ASSERT_EQ(host_dest_to_source_map[dest_location], source_location);
+ ASSERT_EQ(host_unpermuted_row_to_permuted_row_map[unpermuted_row], permuted_row);
+ ASSERT_EQ(host_permuted_row_to_unpermuted_row_map[permuted_row], unpermuted_row);

calculated_routing_values[expert_idx]++;
}
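The two ASSERT_EQ checks above say the row maps must be mutual inverses. A small plain-Python illustration of that invariant on a made-up routing follows; the names mirror the test, but the data does not come from it.

# Toy model of the permutation the profiler workspace is expected to hold.
num_tokens, k, num_experts = 3, 2, 4
token_selected_experts = [[0, 2], [2, 1], [0, 0]]  # made-up routing

# Expanded (unpermuted) row r = k_idx * num_tokens + token_idx, as in the test loop.
rows = [(e_list[k_idx], k_idx * num_tokens + t)
        for t, e_list in enumerate(token_selected_experts) for k_idx in range(k)]
rows.sort(key=lambda p: p[0])  # stable sort by expert, like the permute kernel

permuted_row_to_unpermuted_row = [unpermuted for _, unpermuted in rows]
unpermuted_row_to_permuted_row = [0] * (num_tokens * k)
for permuted, unpermuted in enumerate(permuted_row_to_unpermuted_row):
    unpermuted_row_to_permuted_row[unpermuted] = permuted

# The invariant asserted in the test: the two maps are inverses of each other.
for permuted, unpermuted in enumerate(permuted_row_to_unpermuted_row):
    assert unpermuted_row_to_permuted_row[unpermuted] == permuted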
@ -21,7 +21,9 @@ def trtllm_fused_moe(
selected_experts,
routing_weights,
w3_w1_stacked_weight,
+ None,  # w3_w1_stacked_bias
w2_stacked_weight,
+ None,  # w2_stacked_bias
x.dtype,
quant_scales,
tp_size=1,
@ -397,3 +397,79 @@ def _register_fake():
|
|||||||
pad_slot_id: int,
|
pad_slot_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@torch.library.register_fake("trtllm::moe_permute_op")
|
||||||
|
def _(
|
||||||
|
input: torch.Tensor,
|
||||||
|
token_selected_experts: torch.Tensor,
|
||||||
|
token_final_scales: torch.Tensor,
|
||||||
|
fc1_expert_weights: torch.Tensor,
|
||||||
|
fc2_expert_weights: torch.Tensor,
|
||||||
|
quant_scales: List[torch.Tensor],
|
||||||
|
input_sf: Optional[torch.Tensor],
|
||||||
|
num_experts_per_node: int,
|
||||||
|
tp_size: int,
|
||||||
|
tp_rank: int,
|
||||||
|
ep_size: int,
|
||||||
|
ep_rank: int,
|
||||||
|
cluster_size: int,
|
||||||
|
cluster_rank: int,
|
||||||
|
min_latency_mode: bool,
|
||||||
|
use_fp8_block_scaling: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
experts_per_token = token_selected_experts.shape[1]
|
||||||
|
num_rows = input.shape[0]
|
||||||
|
hidden_size = input.shape[1]
|
||||||
|
|
||||||
|
num_moe_inputs = experts_per_token * num_rows
|
||||||
|
|
||||||
|
unpermuted_token_selected_experts_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_moe_inputs, ), dtype=torch.int32)
|
||||||
|
unpermuted_source_token_ids_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_moe_inputs, ), dtype=torch.int32)
|
||||||
|
permuted_source_token_ids_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_moe_inputs, ), dtype=torch.int32)
|
||||||
|
permuted_token_selected_experts_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_moe_inputs, ), dtype=torch.int32)
|
||||||
|
permuted_data_tensor = input.new_empty((num_moe_inputs, hidden_size),
|
||||||
|
dtype=torch.float32)
|
||||||
|
expert_first_token_offset_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_experts_per_node + 1, ), dtype=torch.int64)
|
||||||
|
permuted_token_final_scales_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_moe_inputs, ), dtype=torch.float32)
|
||||||
|
src_to_dest_map_tensor = token_selected_experts.new_empty(
|
||||||
|
(num_moe_inputs, ), dtype=torch.int32)
|
||||||
|
|
||||||
|
return (
|
||||||
|
unpermuted_token_selected_experts_tensor,
|
||||||
|
unpermuted_source_token_ids_tensor,
|
||||||
|
permuted_source_token_ids_tensor,
|
||||||
|
permuted_token_selected_experts_tensor,
|
||||||
|
permuted_data_tensor,
|
||||||
|
expert_first_token_offset_tensor,
|
||||||
|
permuted_token_final_scales_tensor,
|
||||||
|
src_to_dest_map_tensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.library.register_fake("trtllm::moe_finalize_scale_op")
|
||||||
|
def _(
|
||||||
|
gemm2_output: torch.Tensor,
|
||||||
|
fc2_expert_biases: torch.Tensor,
|
||||||
|
unpermuted_final_scales: torch.Tensor,
|
||||||
|
expanded_source_row_to_expanded_dest_row: torch.Tensor,
|
||||||
|
expert_for_source_row: torch.Tensor,
|
||||||
|
expert_first_token_offset_tensor: torch.Tensor,
|
||||||
|
num_rows: torch.SymInt,
|
||||||
|
hidden_size: torch.SymInt,
|
||||||
|
experts_per_token: int,
|
||||||
|
num_experts_per_node: int,
|
||||||
|
tp_size: int,
|
||||||
|
tp_rank: int,
|
||||||
|
ep_size: int,
|
||||||
|
ep_rank: int,
|
||||||
|
):
|
||||||
|
num_rows_val = int(num_rows)
|
||||||
|
hidden_size_val = int(hidden_size)
|
||||||
|
return gemm2_output.new_empty((num_rows_val, hidden_size_val),
|
||||||
|
dtype=gemm2_output.dtype)
|
||||||
|
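Aside: the hunk above only registers fake (meta) kernels, which return empty tensors with the right shapes and dtypes so that tracing and `torch.compile` can propagate metadata without running the real CUDA ops. For readers unfamiliar with the pattern, here is a minimal self-contained example using a made-up op name (`mylib::scale_rows`); it is not TensorRT-LLM code and assumes PyTorch 2.4+.

```python
import torch

# Real implementation of a toy custom op (runs on CPU/GPU tensors).
@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return x * scales.unsqueeze(-1)

# Fake kernel: reports only output metadata, no computation happens here.
@torch.library.register_fake("mylib::scale_rows")
def _(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return x.new_empty(x.shape)

out = torch.ops.mylib.scale_rows(torch.randn(4, 8), torch.ones(4))
print(out.shape)  # torch.Size([4, 8])
```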
@@ -81,12 +81,13 @@ class MoERunner(TunableRunner):
         tactic: int = -1,
         do_preparation: bool = False,
     ):
-        x, fc1_expert_weights, fc2_expert_weights = inputs
-        # determine if we should use min latency mode according to the profiled seq len
+        x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
         self.fused_moe_runner.run_gemm_profile(
             x,
             fc1_expert_weights,
+            fc1_expert_biases,
             fc2_expert_weights,
+            fc2_expert_biases,
             self.top_k,
             self.tp_size,
             self.tp_rank,
@@ -117,7 +118,9 @@ def fused_moe(
     token_selected_experts: torch.Tensor,
     token_final_scales: torch.Tensor,
     fc1_expert_weights: torch.Tensor,
+    fc1_expert_biases: Optional[torch.Tensor],
     fc2_expert_weights: torch.Tensor,
+    fc2_expert_biases: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_scales: List[torch.Tensor],
     input_sf: Optional[torch.Tensor] = None,
@@ -159,7 +162,10 @@ def fused_moe(
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights],
+        [
+            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
+            fc2_expert_biases
+        ],
         gemm_idx=1,
     )
 
@@ -167,7 +173,10 @@ def fused_moe(
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights],
+        [
+            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
+            fc2_expert_biases
+        ],
         gemm_idx=2,
     )
 
@@ -177,7 +186,9 @@ def fused_moe(
         token_selected_experts,
         token_final_scales,
         fc1_expert_weights,
+        fc1_expert_biases,
         fc2_expert_weights,
+        fc2_expert_biases,
         quant_scales,
         input_sf,
         tp_size,
@@ -200,7 +211,9 @@ def _(
     token_selected_experts: torch.Tensor,
     token_final_scales: torch.Tensor,
     fc1_expert_weights: torch.Tensor,
+    fc1_expert_biases: Optional[torch.Tensor],
     fc2_expert_weights: torch.Tensor,
+    fc2_expert_biases: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_scales: List[torch.Tensor],
     input_sf: Optional[torch.Tensor] = None,
@@ -1,4 +1,5 @@
 from .create_moe import create_moe, get_moe_cls
+from .fused_moe_cute_dsl import CuteDslFusedMoE
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseMoeRoutingMethod",
|
"BaseMoeRoutingMethod",
|
||||||
"create_moe",
|
"create_moe",
|
||||||
|
"CuteDslFusedMoE",
|
||||||
"CutlassFusedMoE",
|
"CutlassFusedMoE",
|
||||||
"DeepSeekV3MoeRoutingMethod",
|
"DeepSeekV3MoeRoutingMethod",
|
||||||
"DefaultMoeRoutingMethod",
|
"DefaultMoeRoutingMethod",
|
||||||
|
@@ -6,6 +6,7 @@ from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
 from ...model_config import ModelConfig
+from .fused_moe_cute_dsl import CuteDslFusedMoE
 from .fused_moe_cutlass import CutlassFusedMoE
 from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
 from .fused_moe_vanilla import VanillaMoE
@@ -28,6 +29,8 @@ def get_moe_cls(
         return CutlassFusedMoE
     elif moe_backend.upper() == "VANILLA":
         return VanillaMoE
+    elif moe_backend.upper() == "CUTEDSL":
+        return CuteDslFusedMoE
     elif moe_backend.upper() == "TRTLLM":
         if quant_config is not None and (
                 quant_config.quant_mode.has_fp8_block_scales()
@@ -122,5 +125,19 @@ def create_moe(
             weight_loading_mode=weight_loading_mode,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
+    elif moe_cls == CuteDslFusedMoE:
+        return moe_cls(
+            routing_method=routing_method,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            dtype=dtype,
+            reduce_results=reduce_results,
+            model_config=model_config,
+            aux_stream=aux_stream,
+            weight_loading_mode=weight_loading_mode,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            layer_idx=layer_idx,
+        )
     else:
         raise ValueError(f"Unsupported moe backend: {moe_cls}")
New file: tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (262 lines, all added)

import math
from typing import List, Optional, Union

import torch
import torch.nn.functional as F

from tensorrt_llm._utils import get_sm_version

from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor, disable_fp4_allgather, reswizzle_sf
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod


def swiglu_fused_moe(x):
    x, gate = x.chunk(2, dim=-1)
    return F.silu(gate) * x


def cute_dsl_fp8_group_blockwise_gemm_ref(
    a: torch.Tensor,
    b: torch.Tensor,
    a_sf: torch.Tensor,
    b_sf: torch.Tensor,
    offset_array: torch.Tensor,
) -> torch.Tensor:
    m, k = a.shape[0], a.shape[1]
    l, n, k = b.shape[0], b.shape[1], b.shape[2]
    num_group, w_n, w_k = b_sf.shape[0], b_sf.shape[1], b_sf.shape[2]

    # Note: view(int8) will cause error.
    a_tmp = a.as_strided((m, k, 1), (k, 1, m * k))
    b_tmp = b.permute(1, 2, 0)

    # Note: we have different output scale shape for fp8_quantize_1x128, so we need to handle it differently for sm100 and other archs.
    if get_sm_version() == 100:
        input_scale_tmp = a_sf.permute(1, 0).as_strided((m, w_k, 1),
                                                        (1, m, m * w_k))
    else:
        m_padded = (m + 3) // 4 * 4
        input_scale_tmp = a_sf[0:m_padded * w_k]
        input_scale_tmp = input_scale_tmp.reshape(-1, m_padded)
        input_scale_tmp = input_scale_tmp[:w_k, :m].contiguous().permute(1, 0)
        input_scale_tmp = input_scale_tmp.as_strided((m, w_k, 1),
                                                     (1, m, m * w_k))

    weight_scale_tmp = b_sf.permute(1, 2, 0)

    def pad_and_multiply(scale, tensor):
        cm, ck, _ = scale.shape
        m, k, _ = tensor.shape
        IsGroupWise = False
        IsBlockWise = False
        if ck == math.ceil(k / 128):
            IsGroupWise = True
        if cm == math.ceil(m / 128):
            IsBlockWise = True
        if not IsBlockWise and not IsGroupWise:
            raise ValueError("Only support granularity = 128")

        k_idx = torch.arange(k, device=scale.device)
        if IsGroupWise:
            k_idx = k_idx // 128
        m_idx = torch.arange(m, device=scale.device)
        if IsBlockWise:
            m_idx = m_idx // 128
        expanded_scale = scale[m_idx[:, None], k_idx, :]

        result = expanded_scale * tensor

        return result

    updated_a = pad_and_multiply(input_scale_tmp, a_tmp.to(torch.float32))
    updated_b = pad_and_multiply(weight_scale_tmp, b_tmp.to(torch.float32))

    ref = torch.zeros((m, n), device="cuda", dtype=torch.float32)

    len_offset_array = offset_array.shape[0]
    for i in range(len_offset_array - 1):
        start = offset_array[i]
        end = offset_array[i + 1]
        # assert start <= end, f"Invalid group boundaries: start={start} > end={end}"
        ref[start:end, :] = torch.einsum("mk,nk->mn", updated_a[start:end, :,
                                                                0],
                                         updated_b[:, :, i])
    ref = ref.to(torch.bfloat16)
    return ref


class CuteDslFusedMoE(CutlassFusedMoE):
    """
    Python Flow of Fused Mixture of Experts (MoE) Layer.

    Args:
        num_experts (int): Number of experts in the MoE layer.
        top_k (int): Number of top experts to select for each input token.
        hidden_size (int): Size of the hidden state.
        intermediate_size (int): Size of the intermediate state.
        aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
        dtype (Optional[torch.dtype]): Data type for the weights.
        reduce_results (bool): Whether to reduce the results across devices.
        model_config (ModelConfig): Configuration object for the model.

    This backend is composed of multiple custom ops:
    1. moe_permute_op: permute the input tensor and the expert selected tensor.
    2. cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm.
    3. moe_finalize_scale_op: finalize the scale of the output tensor.
    """

    def __init__(
        self,
        *,
        routing_method: BaseMoeRoutingMethod,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        model_config: ModelConfig = ModelConfig(),
        aux_stream: Optional[torch.cuda.Stream] = None,
        weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
        VANILLA,
        apply_router_weight_on_input: bool = False,
        layer_idx: Optional[int] = None,
    ):
        super().__init__(
            routing_method=routing_method,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dtype=dtype,
            reduce_results=reduce_results,
            model_config=model_config,
            aux_stream=aux_stream,
            weight_loading_mode=weight_loading_mode,
            apply_router_weight_on_input=apply_router_weight_on_input,
            layer_idx=layer_idx,
        )

    def forward_chunk(
        self,
        x: Union[torch.Tensor, Fp4QuantizedTensor],
        router_logits: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        all_rank_num_tokens: Optional[List[int]] = None,
        use_dp_padding: Optional[bool] = None,
    ) -> torch.Tensor:
        if isinstance(x, Fp4QuantizedTensor):
            assert output_dtype is not None
            output_dtype = output_dtype
        else:
            output_dtype = x.dtype

        # apply routing
        token_selected_experts, token_final_scales = self.routing_method.apply(
            router_logits)
        assert token_selected_experts.shape[
            1] == self.routing_method.experts_per_token
        assert token_selected_experts.shape == token_final_scales.shape
        assert token_selected_experts.shape[0] == router_logits.shape[0]
        assert token_final_scales.dtype == torch.float32
        assert token_selected_experts.dtype == torch.int32

        if self.apply_router_weight_on_input:
            assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
            assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
            x = x * token_final_scales.to(x.dtype)
            # TODO: remove this once we have correct fusedmoe kernel ready
            token_final_scales = None

        # quantize inputs
        use_deepseek_fp8_block_scale = False
        weight_dtype = self.w3_w1_weight.dtype
        x_sf = None
        if self.has_any_quant:
            if self.has_deepseek_fp8_block_scales:
                use_deepseek_fp8_block_scale = True
            else:
                raise ValueError(
                    f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
                )

        # gather inputs for attention dp
        if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
        ):
            x, x_sf, token_selected_experts, token_final_scales = allgather(
                [x, x_sf, token_selected_experts, token_final_scales],
                self.mapping,
                dim=0,
                sizes=None if use_dp_padding else all_rank_num_tokens)
            # Fp4 gemm has extra scaling factor
            if x_sf is not None:
                x_sf = reswizzle_sf(x_sf, x_row, x_col,
                                    self.scaling_vector_size)

        (
            unpermuted_token_selected_experts_tensor,
            unpermuted_source_token_ids_tensor,
            permuted_source_token_ids_tensor,
            permuted_token_selected_experts_tensor,
            permuted_data_tensor,
            expert_first_token_offset_tensor,
            permuted_token_final_scales_tensor,
            src_to_dest_map_tensor,
        ) = torch.ops.trtllm.moe_permute_op(
            x,
            token_selected_experts,
            token_final_scales,
            None,  # w3_w1_weight.view(weight_dtype),
            None,  # w2_weight.view(weight_dtype),
            None,  # quant_scales,
            input_sf=x_sf,
            num_experts_on_rank=self.expert_size_per_partition,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            cluster_size=self.cluster_size,
            cluster_rank=self.cluster_rank,
            min_latency_mode=False,
            use_fp8_block_scaling=use_deepseek_fp8_block_scale,
        )

        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
            permuted_data_tensor)
        h1 = cute_dsl_fp8_group_blockwise_gemm_ref(
            a=act_input_fp8,
            b=self.w3_w1_weight.view(weight_dtype),
            a_sf=act_input_sf,
            b_sf=self.quant_scales[0],
            offset_array=expert_first_token_offset_tensor,
        )
        h2 = swiglu_fused_moe(h1)
        act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(h2)
        h3 = cute_dsl_fp8_group_blockwise_gemm_ref(
            a=act_input_fp8,
            b=self.w2_weight.view(weight_dtype),
            a_sf=act_input_sf,
            b_sf=self.quant_scales[1],
            offset_array=expert_first_token_offset_tensor,
        )
        final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op(
            h3,
            None,
            token_final_scales,
            src_to_dest_map_tensor,
            unpermuted_token_selected_experts_tensor,
            expert_first_token_offset_tensor,
            x.shape[0],  # num_rows
            x.shape[1],  # hidden_size
            self.routing_method.top_k,
            self.expert_size_per_partition,  # num_experts_per_node
            self.tp_size,
            self.tp_rank,
            self.ep_size,
            self.ep_rank,
        )

        return final_hidden_states
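Aside: the `pad_and_multiply` helper in the new file dequantizes an FP8 operand by broadcasting one scale per 128-wide group along k (and, for weights, one per 128-tall block along m). A standalone sketch of that expansion with hypothetical sizes, using plain PyTorch and a simplified 2-D layout rather than the strided 3-D views above:

```python
import math
import torch

m, k, granularity = 5, 300, 128
tensor = torch.randn(m, k)                         # dequantized left operand (toy)
scale = torch.rand(m, math.ceil(k / granularity))  # one scale per 1x128 group

k_idx = torch.arange(k) // granularity             # map each column to its 128-wide group
expanded = scale[:, k_idx]                         # broadcast group scales to full width
result = expanded * tensor

# Column j is scaled by scale[:, j // 128]; spot-check one column.
j = 200
assert torch.allclose(result[:, j], scale[:, j // granularity] * tensor[:, j])
```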
@@ -269,7 +269,9 @@ class CutlassFusedMoE(MoE):
             token_selected_experts,
             token_final_scales,
             self.w3_w1_weight.view(weight_dtype),
+            None,  # fc1_expert_biases
             self.w2_weight.view(weight_dtype),
+            None,  # fc2_expert_biases
             output_dtype,
             quant_scales=self.quant_scales,
             input_sf=x_sf,
@@ -597,7 +597,9 @@ class WideEPMoE(MoE):
             token_selected_slots,
             token_final_scales,
             w3_w1_weight.view(weight_dtype),
+            None,  # w3_w1_bias
             w2_weight.view(weight_dtype),
+            None,  # w2_bias
             output_dtype,
             quant_scales=quant_scales,
             input_sf=x_sf,
@@ -651,6 +651,68 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @skip_no_hopper
+    @parametrize_with_ids("torch_compile", [False])
+    @parametrize_with_ids(
+        "fp8kv,attention_dp,cuda_graph,overlap_scheduler",
+        [(False, False, False, False)],
+    )
+    @parametrize_with_ids("mtp_nextn", [0])
+    def test_cute_dsl_fp8_block_scales(
+        self,
+        mtp_nextn,
+        fp8kv,
+        attention_dp,
+        cuda_graph,
+        overlap_scheduler,
+        torch_compile,
+    ):
+        if torch_compile and mtp_nextn > 0:
+            pytest.skip("https://nvbugs/5252313")
+        if torch_compile and attention_dp:
+            pytest.skip("https://nvbugs/5252559")
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        torch_compile_config = (TorchCompileConfig(
+            enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph)
+                                if torch_compile else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            use_cuda_graph=cuda_graph,
+            torch_compile_config=torch_compile_config,
+            moe_backend="CUTEDSL",
+        )
+
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+        if fp8kv:
+            quant_config.kv_cache_quant_algo = QuantAlgo.FP8
+            pytorch_config["kv_cache_dtype"] = "fp8"
+
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
+        llm = LLM(
+            f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+            kv_cache_config=kv_cache_config,
+            **pytorch_config,
+            quant_config=quant_config,
+            enable_attention_dp=attention_dp,
+            speculative_config=mtp_config,
+        )
+
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        if fp8kv:
+            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+
+        with llm:
+            # No need to run MMLU for fp8kv
+            if not fp8kv:
+                task = MMLU(self.MODEL_NAME)
+                task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_device_not_contain(["H100"])
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
@@ -775,6 +837,82 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.skip_less_device(4)
+    @skip_no_hopper
+    @parametrize_with_ids("torch_compile", [False])
+    @parametrize_with_ids(
+        "fp8kv,attention_dp,cuda_graph,overlap_scheduler",
+        [(False, False, False, False)],
+    )
+    @parametrize_with_ids("mtp_nextn", [0])
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size",
+        [(4, 1, 1), (4, 1, 4), (2, 2, 1), (1, 4, 1)],
+        ids=["tp4", "ep4", "tp2pp2", "pp4"],
+    )
+    def test_cute_dsl_fp8_block_scales_4gpus(
+        self,
+        tp_size,
+        pp_size,
+        ep_size,
+        mtp_nextn,
+        fp8kv,
+        attention_dp,
+        cuda_graph,
+        overlap_scheduler,
+        torch_compile,
+    ):
+        if torch_compile and mtp_nextn > 0:
+            pytest.skip("https://nvbugs/5252313")
+        if torch_compile and attention_dp:
+            pytest.skip("https://nvbugs/5252559")
+        if torch_compile and pp_size > 1:
+            pytest.skip("PP with torch.compile is not supported yet.")
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        torch_compile_config = (TorchCompileConfig(
+            enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph)
+                                if torch_compile else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            use_cuda_graph=cuda_graph,
+            torch_compile_config=torch_compile_config,
+            moe_backend="CUTEDSL",
+        )
+
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+        if fp8kv:
+            quant_config.kv_cache_quant_algo = QuantAlgo.FP8
+            pytorch_config["kv_cache_dtype"] = "fp8"
+
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
+        llm = LLM(
+            f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+            tensor_parallel_size=tp_size,
+            pipeline_parallel_size=pp_size,
+            moe_expert_parallel_size=ep_size,
+            kv_cache_config=kv_cache_config,
+            **pytorch_config,
+            quant_config=quant_config,
+            enable_attention_dp=attention_dp,
+            speculative_config=mtp_config,
+        )
+
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        if fp8kv:
+            assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
+
+        with llm:
+            # No need to run MMLU for fp8kv
+            if not fp8kv:
+                task = MMLU(self.MODEL_NAME)
+                task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])
     def test_fp8_block_scales_4gpus_static_eplb(self):
@@ -4,14 +4,17 @@ from itertools import product
 from typing import Dict, List, Optional
 from unittest import mock
 
+import _torch.helpers
 import cloudpickle
 import pytest
 import torch
 import torch.nn as nn
+from _torch.helpers import per_block_cast_to_fp8
 from mpi4py import MPI
 from mpi4py.futures import MPIPoolExecutor
 from utils.util import (skip_neither_ada_nor_hopper_unittest,
-                        skip_pre_blackwell, skip_pre_hopper)
+                        skip_non_hopper_unittest, skip_pre_blackwell,
+                        skip_pre_hopper)
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.model_config import ModelConfig
@@ -20,6 +23,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
                                                    DefaultMoeRoutingMethod,
                                                    RenormalizeMoeRoutingMethod,
                                                    VanillaMoE, WideEPMoE)
+from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
+    CuteDslFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
     AlltoallMethodType
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
@@ -28,6 +33,7 @@ from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
 
 cloudpickle.register_pickle_by_value(sys.modules[__name__])
+cloudpickle.register_pickle_by_value(_torch.helpers)
 MPI.pickle.__init__(
     cloudpickle.dumps,
     cloudpickle.loads,
@@ -314,6 +320,196 @@ def test_fused_moe_fp8(dtype):
     torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
 
 
+def set_tensor_value_2(x, num_row, num_cols):
+    # Create 2x2 base pattern matrix
+    pattern = torch.tensor([[0.2, -0.5], [-0.3, 0.1]], device=x.device)
+
+    # Repeat pattern to cover entire matrix
+    repeated = pattern.repeat((num_row + 1) // 2,
+                              (num_cols + 1) // 2)[:num_row, :num_cols]
+
+    x.copy_(repeated)
+
+
+def set_tensor_value_3(x, num_row, num_cols):
+    # Create 3x3 base pattern matrix
+    pattern = torch.tensor(
+        [[0.1, 0.21, 0.31], [0.3, 0.6, 0.1], [0.11, 0.51, 0.62]],
+        device=x.device)
+
+    # Repeat pattern to cover entire matrix
+    repeated = pattern.repeat((num_row + 2) // 3,
+                              (num_cols + 2) // 3)[:num_row, :num_cols]
+
+    x.copy_(repeated)
+
+
+def set_tensor_value_4(x, num_row, num_cols):
+    # Create 4x4 base pattern matrix
+    pattern = torch.tensor(
+        [
+            [0.1, 0.21, 0.31, 0.41],
+            [0.3, 0.6, 0.1, 0.2],
+            [0.11, 0.51, 0.61, 0.71],
+            [0.11, 0.52, 0.62, 0.72],
+        ],
+        device=x.device,
+    )
+
+    # Repeat pattern to cover entire matrix
+    repeated = pattern.repeat((num_row + 3) // 4,
+                              (num_cols + 3) // 4)[:num_row, :num_cols]
+
+    x.copy_(repeated)
+
+
+@skip_non_hopper_unittest
+@pytest.mark.parametrize(
+    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
+    product(
+        [torch.bfloat16],
+        [72],
+        [128, 256, 384, 512, 1024, 2048, 4096, 8192],
+        [2560],
+        [DefaultMoeRoutingMethod],
+    ),
+)
+def test_fused_moe_fp8_blockwise(dtype,
+                                 num_experts,
+                                 seq_len,
+                                 hidden_size,
+                                 RoutingMethodCls,
+                                 mapping=None):
+    SEQ_LEN = seq_len
+    HIDDEN_SIZE = hidden_size
+    INTERMEDIATE_SIZE = 1536
+    NUM_EXPERTS = num_experts
+    TOP_K = 6
+
+    routing_method = RoutingMethodCls(top_k=TOP_K)
+
+    mapping = mapping or Mapping()
+    mapping.rank = mpi_rank()
+    torch.cuda.set_device(mapping.rank)
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+
+    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
+    # Note: we use some special values init x and weight, otherwise the test will false positive failed.
+    set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE)
+
+    x = x.cuda()
+    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()
+
+    weights = {}
+    for expert_id in range(NUM_EXPERTS):
+        w1_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                dtype=dtype).cuda()
+        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
+                                dtype=dtype).cuda()
+        w3_weight = torch.randn((INTERMEDIATE_SIZE, HIDDEN_SIZE),
+                                dtype=dtype).cuda()
+        set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
+        set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE)
+        set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
+
+        w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8(w1_weight)
+        w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8(w2_weight)
+        w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8(w3_weight)
+        w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()
+
+        weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
+        weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
+        weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
+        weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
+        weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
+        weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
+        weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
+        weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
+        weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
+
+    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
+
+    fused_moe = CuteDslFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        reduce_results=True,
+        model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+    )
+    fused_moe.cuda()
+    fused_moe.load_weights([weights])
+
+    fused_moe_origin = CutlassFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        reduce_results=True,
+        model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
+    )
+    fused_moe_origin.cuda()
+    fused_moe_origin.load_weights([weights])
+
+    ref_fused_moe = RefGatedMLPFusedMoE(
+        num_experts=NUM_EXPERTS,
+        routing_method=routing_method,
+        hidden_size=HIDDEN_SIZE,
+        intermediate_size=INTERMEDIATE_SIZE,
+        dtype=dtype,
+        model_config=ModelConfig(quant_config=quant_config),
+    )
+    ref_fused_moe.load_weights([weights])
+    ref_fused_moe.cuda()
+
+    with torch.inference_mode():
+        output = fused_moe.forward(x, router_logits)
+        output_origin = fused_moe_origin.forward(x, router_logits)
+        ref_output = ref_fused_moe.forward(x, router_logits)
+
+    # compare
+    torch.cuda.synchronize()
+    torch.testing.assert_close(output_origin, output, rtol=1e-2, atol=0.1)
+    torch.testing.assert_close(output_origin, ref_output, rtol=1e-2, atol=0.1)
+    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
+    return True
+
+
+@skip_non_hopper_unittest
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="needs 4 GPUs to run this test")
+@pytest.mark.parametrize("ep_size", [1, 2, 4])
+@pytest.mark.parametrize("routing_method", [DefaultMoeRoutingMethod])
+def test_fused_moe_fp8_blockwise_multi_gpu(ep_size, routing_method):
+    world_size = 4
+    with MPIPoolExecutor(max_workers=world_size) as executor:
+        results = executor.map(
+            test_fused_moe_fp8_blockwise,
+            *zip(*[(
+                torch.bfloat16,
+                72,
+                384,
+                384,
+                routing_method,
+                Mapping(
+                    world_size=world_size,
+                    tp_size=world_size,
+                    moe_ep_size=ep_size,
+                    moe_tp_size=world_size // ep_size,
+                ),
+            )] * world_size),
+        )
+        for r in results:
+            assert r is True
+
+
 @skip_pre_blackwell
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe_nvfp4(dtype):
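Aside: the blockwise test above feeds weights through `per_block_cast_to_fp8` from `_torch.helpers`, whose implementation is not part of this diff. The sketch below is an assumed reference for what a 128x128 per-block cast typically does (one scale per tile, values scaled into the `float8_e4m3fn` range); it is included only to make the `weight_scale_inv` entries easier to interpret and may differ from the real helper.

```python
# Assumed reference for a 128x128 per-block FP8 cast; the real helper may differ.
import torch

def per_block_cast_to_fp8_ref(x: torch.Tensor, block: int = 128):
    m, n = x.shape
    pm = (m + block - 1) // block * block
    pn = (n + block - 1) // block * block
    padded = torch.zeros(pm, pn, dtype=torch.float32, device=x.device)
    padded[:m, :n] = x.float()

    tiles = padded.view(pm // block, block, pn // block, block)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448 for e4m3fn
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)

    quantized = (tiles * (fp8_max / amax)).to(torch.float8_e4m3fn)
    scales = (amax / fp8_max).squeeze(1).squeeze(-1)           # one scale per tile
    return quantized.reshape(pm, pn)[:m, :n], scales

w = torch.randn(300, 200)
w_fp8, w_scale = per_block_cast_to_fp8_ref(w)
print(w_fp8.dtype, w_scale.shape)  # torch.float8_e4m3fn torch.Size([3, 2])
```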
@@ -649,6 +845,14 @@ class RefGatedMLPFusedMoE(nn.Module):
                     f"{expert}.w3.weight_scale_2"]
                 down_proj_weights[0]['weight_scale_2'] = weights[
                     f"{expert}.w2.weight_scale_2"]
+            elif (self.quant_config and self.quant_config.quant_algo
+                  == QuantAlgo.FP8_BLOCK_SCALES):
+                gate_up_proj_weights[0]["weight_scale"] = weights[
+                    f"{expert}.w1.weight_scale"]
+                gate_up_proj_weights[1]["weight_scale"] = weights[
+                    f"{expert}.w3.weight_scale"]
+                down_proj_weights[0]["weight_scale"] = weights[
+                    f"{expert}.w2.weight_scale"]
+
             self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights)
             self.experts[expert].down_proj.load_weights(down_proj_weights)