Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
feat: Expose bias and FP8_MXFP4 MOE CUTLASS backend features to pytorch (#5410)
Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com>
This commit is contained in: parent ef43b95aa1, commit 83a1f60556
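The change threads two new optional per-expert bias tensors (fc1_expert_biases, fc2_expert_biases) through the CUTLASS MoE custom op and its PyTorch wrappers, and adds an FP8-activation x MXFP4-weight (WFP4AFP8) quantization path. As a quick orientation before the diff, here is a rough sketch of the bias contract seen from the Python side; the argument names mirror the updated fused_moe wrapper further down, while the concrete sizes, dtype, and the omitted parallelism arguments are illustrative assumptions, not part of the commit.

# Hedged sketch: shape/dtype contract of the new bias arguments, derived from the
# TORCH_CHECKs added in this diff. All sizes below are made-up example values.
import torch

num_experts, hidden_size, inter_size = 8, 1024, 2048
dtype = torch.bfloat16

fc1_expert_weights = torch.randn(num_experts, 2 * inter_size, hidden_size, dtype=dtype)
fc2_expert_weights = torch.randn(num_experts, hidden_size, inter_size, dtype=dtype)

# New in this commit: optional 2D per-expert biases; dim 0 must equal the expert
# count and dim 1 must match the corresponding weight's output dimension, and the
# dtype must match the output dtype. Existing callers in this diff simply pass None.
fc1_expert_biases = torch.zeros(num_experts, 2 * inter_size, dtype=dtype)
fc2_expert_biases = torch.zeros(num_experts, hidden_size, dtype=dtype)

# Placement in the updated wrapper signature (remaining arguments unchanged):
#   fused_moe(input, token_selected_experts, token_final_scales,
#             fc1_expert_weights, fc1_expert_biases,
#             fc2_expert_weights, fc2_expert_biases,
#             output_dtype, quant_scales, ...)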
@@ -52,7 +52,7 @@ using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend
 class FusedMoeRunner : public torch::CustomClassHolder
 {
 public:
-    template <typename Type, bool NeedQuant = false>
+    template <typename TypeAct, typename TypeWeight, bool NeedQuant = false>
     std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(c10::ScalarType output_type)
     {
         switch (output_type)
@@ -66,21 +66,22 @@ public:
         case c10::ScalarType::Half:
             if constexpr (NeedQuant)
             {
-                return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, half>>();
+                return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half>>();
             }
             else
             {
-                return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, half, Type>>();
+                return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct>>();
             }
 #ifdef ENABLE_BF16
         case c10::ScalarType::BFloat16:
             if constexpr (NeedQuant)
             {
-                return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, __nv_bfloat16>>();
+                return std::make_unique<
+                    kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16>>();
             }
             else
             {
-                return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type, __nv_bfloat16, Type>>();
+                return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, TypeAct>>();
             }
 #endif
         default:
@@ -121,10 +122,16 @@ public:
 #ifdef ENABLE_FP8
         if (isFp8Quant())
         {
-            mKernelRunner = switch_output_type<__nv_fp8_e4m3>(mOutputDtype);
+            mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype);
         }
 #endif
 #ifdef ENABLE_FP4
+        if (isWFp4AFp8Quant())
+        {
+            mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG
+            mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype);
+        }
+
         if (isNvfp4Quant())
         {
             mInnerDimMultiplier = 16;
@@ -134,9 +141,9 @@ public:
 #ifdef ENABLE_BF16
         case c10::ScalarType::BFloat16:
 #endif
-            mKernelRunner = switch_output_type<__nv_fp4_e2m1, true>(mOutputDtype);
+            mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, true>(mOutputDtype);
             break;
-        default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, false>(mOutputDtype);
+        default: mKernelRunner = switch_output_type<__nv_fp4_e2m1, __nv_fp4_e2m1, false>(mOutputDtype);
         }
     }
 #endif
@@ -204,11 +211,13 @@ public:
     void operator=(FusedMoeRunner const&) = delete;

     torch::Tensor runMoe(torch::Tensor const& input, torch::Tensor const& token_selected_experts,
-        torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& fc1_expert_weights,
-        torch::Tensor const& fc2_expert_weights, torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales,
-        torch::optional<torch::Tensor> input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
-        int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall,
-        bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> profile_ids)
+        torch::optional<torch::Tensor> const& token_final_scales, torch::Tensor const& fc1_expert_weights,
+        torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
+        torch::optional<torch::Tensor> const& fc2_expert_biases,
+        torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
+        torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
+        int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
+        bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
     {
         std::lock_guard<std::mutex> lock(mMutex);
         // Free the profile workspace to save memory
@@ -230,6 +239,23 @@ public:

         TORCH_CHECK(fc1_expert_weights.dim() == 3, "fc1_expert_weights must be 3D.");
         TORCH_CHECK(fc2_expert_weights.dim() == 3, "fc2_expert_weights must be 3D.");
+
+        if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value())
+        {
+            CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype);
+            CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype);
+            TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D.");
+            TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D.");
+            TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0],
+                "fc1_expert_weights and fc1_expert_biases must have the same number of experts.");
+            TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0],
+                "fc2_expert_weights and fc2_expert_biases must have the same number of experts.");
+            TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1],
+                "fc1_expert_biases should match fc1_expert_weights output shape.");
+            TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1],
+                "fc2_expert_biases should match fc2_expert_weights output shape.");
+        }
+
         TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
             "input and token_selected_experts must have the same num tokens.");
         if (token_final_scales)
@@ -275,8 +301,11 @@ public:
             reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
             token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
                                            : nullptr,
-            fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
-            quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+            fc1_expert_weights.const_data_ptr(),
+            fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
+            fc2_expert_weights.const_data_ptr(),
+            fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
+            num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
             static_cast<char*>(workspace_info.workspace), output.data_ptr(),
             static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
             mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -286,8 +315,11 @@ public:
             reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
             token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
                                            : nullptr,
-            fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
-            quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+            fc1_expert_weights.const_data_ptr(),
+            fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
+            fc2_expert_weights.const_data_ptr(),
+            fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
+            num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
             static_cast<char*>(workspace_info.workspace), output.data_ptr(),
             static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
             mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -297,12 +329,13 @@ public:
     }

     std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> runMoeMinLantency(torch::Tensor const& input,
-        torch::Tensor const& token_selected_experts, torch::optional<torch::Tensor> token_final_scales,
-        torch::Tensor const& fc1_expert_weights, torch::Tensor const& fc2_expert_weights,
-        torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales, torch::optional<torch::Tensor> input_sf,
-        int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank,
-        int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
-        torch::optional<c10::ArrayRef<int64_t>> profile_ids)
+        torch::Tensor const& token_selected_experts, torch::optional<torch::Tensor> const& token_final_scales,
+        torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
+        torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
+        torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
+        torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
+        int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
+        bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
     {
         std::lock_guard<std::mutex> lock(mMutex);

@@ -323,6 +356,23 @@ public:

         TORCH_CHECK(fc1_expert_weights.dim() == 3, "fc1_expert_weights must be 3D.");
         TORCH_CHECK(fc2_expert_weights.dim() == 3, "fc2_expert_weights must be 3D.");
+
+        if (fc1_expert_biases.has_value() || fc2_expert_biases.has_value())
+        {
+            CHECK_INPUT(fc1_expert_biases.value(), mOutputDtype);
+            CHECK_INPUT(fc2_expert_biases.value(), mOutputDtype);
+            TORCH_CHECK(fc1_expert_biases.value().dim() == 2, "fc1_expert_biases must be 2D.");
+            TORCH_CHECK(fc2_expert_biases.value().dim() == 2, "fc2_expert_biases must be 2D.");
+            TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc1_expert_biases.value().sizes()[0],
+                "fc1_expert_weights and fc1_expert_biases must have the same number of experts.");
+            TORCH_CHECK(fc2_expert_weights.sizes()[0] == fc2_expert_biases.value().sizes()[0],
+                "fc2_expert_weights and fc2_expert_biases must have the same number of experts.");
+            TORCH_CHECK(fc1_expert_biases.value().sizes()[1] == fc1_expert_weights.sizes()[1],
+                "fc1_expert_biases should match fc1_expert_weights output shape.");
+            TORCH_CHECK(fc2_expert_biases.value().sizes()[1] == fc2_expert_weights.sizes()[1],
+                "fc2_expert_biases should match fc2_expert_weights output shape.");
+        }
+
         TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
             "input and token_selected_experts must have the same num tokens.");
         if (token_final_scales)
@@ -378,8 +428,11 @@ public:
             reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
             token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
                                            : nullptr,
-            fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
-            quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+            fc1_expert_weights.const_data_ptr(),
+            fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
+            fc2_expert_weights.const_data_ptr(),
+            fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
+            num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
             static_cast<char*>(workspace_info.workspace), output.data_ptr(),
             static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
             mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -389,8 +442,11 @@ public:
             reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
             token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
                                            : nullptr,
-            fc1_expert_weights.const_data_ptr(), nullptr, activation_type, fc2_expert_weights.const_data_ptr(), nullptr,
-            quant_params, num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
+            fc1_expert_weights.const_data_ptr(),
+            fc1_expert_biases.has_value() ? fc1_expert_biases.value().const_data_ptr() : nullptr, activation_type,
+            fc2_expert_weights.const_data_ptr(),
+            fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
+            num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
             static_cast<char*>(workspace_info.workspace), output.data_ptr(),
             static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
             mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
@@ -406,10 +462,11 @@ public:
     }

     void runGemmProfile(torch::Tensor const& input, torch::Tensor const& fc1_expert_weights,
-        torch::Tensor const& fc2_expert_weights, int64_t const top_k, int64_t const tp_size, int64_t const tp_rank,
-        int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
-        bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx, int64_t const profile_id,
-        bool const do_preparation)
+        torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
+        torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
+        int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
+        int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
+        int64_t const profile_id, bool const do_preparation)
     {
         std::lock_guard<std::mutex> lock(mMutex);

@@ -447,7 +504,7 @@ public:
             static_cast<int>(tp_rank), static_cast<int>(ep_size), static_cast<int>(ep_rank),
             static_cast<int>(cluster_size), static_cast<int>(cluster_rank));

-        bool const USE_BIAS = false;
+        bool const USE_BIAS = fc1_expert_biases.has_value() || fc2_expert_biases.has_value();
         bool const USE_LORA = false;
         auto activation_dtype = mUseW4A8GroupScaling ? at::ScalarType::Float8_e4m3fn : mActivationDtype;
         activation_dtype = isNvfp4Quant() ? at::ScalarType::Long : activation_dtype;
@@ -592,6 +649,52 @@ private:
             static_cast<float const*>(fc2_quant.data_ptr()), static_cast<float const*>(fc2_dequant.data_ptr()),
             /* fp8 output quant scale */ nullptr, static_cast<float const*>(fc1_input_dequant.data_ptr()));
         }
+
+        else if (isWFp4AFp8Quant())
+        {
+            TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for WFP4AFP8 quantization");
+            TORCH_CHECK(quant_scales.value().size() == 5, "Expecting 5 quant scales for WFP4AFP8 quantization");
+
+            auto const fc1_weight_block = quant_scales.value()[0];
+            auto const fc1_global = quant_scales.value()[1];
+            auto const fc2_act_global = quant_scales.value()[2];
+            auto const fc2_weight_block = quant_scales.value()[3];
+            auto const fc2_global = quant_scales.value()[4];
+
+            // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
+            constexpr int FP8_PER_INT32 = 4;
+            CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int);
+            CHECK_INPUT(fc1_global, c10::ScalarType::Float);
+            CHECK_INPUT(fc2_act_global, c10::ScalarType::Float);
+            CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int);
+            CHECK_INPUT(fc2_global, c10::ScalarType::Float);
+            TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be 3D");
+            TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D");
+            TORCH_CHECK(fc2_act_global.dim() == 0, "fc2 act global must be a scalar tensor");
+            TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D");
+            TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D");
+            TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank
+                    && fc1_weight_block.sizes()[1] == inter_size * 2
+                    && fc1_weight_block.sizes()[2] * FP8_PER_INT32
+                            * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
+                        == hidden_size,
+                "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // "
+                "block_scale_vector_size)");
+            TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)");
+            TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size
+                    && fc2_weight_block.sizes()[2] * FP8_PER_INT32
+                            * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize
+                        == inter_size,
+                "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // "
+                "block_scale_vector_size)");
+            TORCH_CHECK(fc2_global.sizes()[0] == num_experts_on_rank, "fc2 global size must be (num_experts_on_rank,)");
+
+            return kernels::QuantParams::FP8MXFP4(nullptr,
+                static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()),
+                static_cast<float const*>(fc1_global.data_ptr()), static_cast<float const*>(fc2_act_global.data_ptr()),
+                static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()),
+                static_cast<float const*>(fc2_global.data_ptr()));
+        }
         else if (isNvfp4Quant())
         {
             TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization");
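For the new WFP4AFP8 branch above, quant_scales must carry exactly five tensors in a fixed order. A small sketch of that layout, mirroring the checks in the hunk; the block-scale vector size of 32 is an assumed value for MXFPXBlockScaleVectorSize, and all sizes are example values rather than anything mandated by the commit.

# Hedged sketch of the quant_scales layout for the FP8 x MXFP4 (WFP4AFP8) path.
# FP8_PER_INT32 = 4 matches the packing noted in the diff; BLOCK_SCALE_VEC = 32 is an
# assumed value for TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize.
import torch

num_experts_on_rank, hidden_size, inter_size = 8, 1024, 4096
FP8_PER_INT32 = 4
BLOCK_SCALE_VEC = 32

quant_scales = [
    torch.zeros(num_experts_on_rank, inter_size * 2,
                hidden_size // FP8_PER_INT32 // BLOCK_SCALE_VEC,
                dtype=torch.int32),                         # fc1_weight_block (packed FP8 scales)
    torch.ones(num_experts_on_rank, dtype=torch.float32),   # fc1_global
    torch.tensor(1.0, dtype=torch.float32),                 # fc2_act_global (scalar tensor)
    torch.zeros(num_experts_on_rank, hidden_size,
                inter_size // FP8_PER_INT32 // BLOCK_SCALE_VEC,
                dtype=torch.int32),                         # fc2_weight_block
    torch.ones(num_experts_on_rank, dtype=torch.float32),   # fc2_global
]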
@@ -683,7 +786,8 @@ private:

     bool isNvfp4Quant() const
     {
-        return mWeightDtype == c10::ScalarType::Long;
+        return mWeightDtype == c10::ScalarType::Long
+            && mActivationDtype != c10::ScalarType::Float8_e4m3fn; // FP8 activation does not use FP4
     }

     bool isInt4Quant() const
@@ -695,6 +799,11 @@ private:
     {
         return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant();
     }

+    bool isWFp4AFp8Quant() const
+    {
+        return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long;
+    }
+
 };

 } // namespace torch_ext
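The two predicates above decide which quantization path a given weight/activation dtype combination takes: weights packed into Long with FP8 activations now route to the new WFP4AFP8 path, while Long-packed weights with any other activation dtype keep using NVFP4. A tiny Python sketch of that mapping, for illustration only and not part of the commit:

# Illustrative mirror of isNvfp4Quant()/isWFp4AFp8Quant(); torch.int64 stands in for
# c10::ScalarType::Long (FP4 weights packed into 64-bit words).
import torch

def moe_quant_mode(activation_dtype: torch.dtype, weight_dtype: torch.dtype) -> str:
    packed_fp4_weights = weight_dtype == torch.int64
    fp8_activations = activation_dtype == torch.float8_e4m3fn
    if packed_fp4_weights and fp8_activations:
        return "WFP4AFP8"  # new path exposed by this commit
    if packed_fp4_weights:
        return "NVFP4"     # FP8 activations no longer fall into this branch
    return "other"

assert moe_quant_mode(torch.float8_e4m3fn, torch.int64) == "WFP4AFP8"
assert moe_quant_mode(torch.bfloat16, torch.int64) == "NVFP4"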
@@ -21,7 +21,9 @@ def trtllm_fused_moe(
         selected_experts,
         routing_weights,
         w3_w1_stacked_weight,
+        None,  # w3_w1_stacked_bias
         w2_stacked_weight,
+        None,  # w2_stacked_bias
         x.dtype,
         quant_scales,
         tp_size=1,
@@ -81,12 +81,13 @@ class MoERunner(TunableRunner):
         tactic: int = -1,
         do_preparation: bool = False,
     ):
-        x, fc1_expert_weights, fc2_expert_weights = inputs
-        # determine if we should use min latency mode according to the profiled seq len
+        x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs
         self.fused_moe_runner.run_gemm_profile(
             x,
             fc1_expert_weights,
+            fc1_expert_biases,
             fc2_expert_weights,
+            fc2_expert_biases,
             self.top_k,
             self.tp_size,
             self.tp_rank,
@@ -117,7 +118,9 @@ def fused_moe(
     token_selected_experts: torch.Tensor,
     token_final_scales: torch.Tensor,
     fc1_expert_weights: torch.Tensor,
+    fc1_expert_biases: Optional[torch.Tensor],
     fc2_expert_weights: torch.Tensor,
+    fc2_expert_biases: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_scales: List[torch.Tensor],
     input_sf: Optional[torch.Tensor] = None,
@@ -159,7 +162,10 @@
         "trtllm::fused_moe::gemm1",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights],
+        [
+            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
+            fc2_expert_biases
+        ],
         gemm_idx=1,
     )

@@ -167,7 +173,10 @@
         "trtllm::fused_moe::gemm2",
         [moe_runner],
         MoERunner.tuning_config,
-        [input, fc1_expert_weights, fc2_expert_weights],
+        [
+            input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights,
+            fc2_expert_biases
+        ],
         gemm_idx=2,
     )

@@ -177,7 +186,9 @@
         token_selected_experts,
         token_final_scales,
         fc1_expert_weights,
+        fc1_expert_biases,
         fc2_expert_weights,
+        fc2_expert_biases,
         quant_scales,
         input_sf,
         tp_size,
@@ -200,7 +211,9 @@ def _(
     token_selected_experts: torch.Tensor,
     token_final_scales: torch.Tensor,
     fc1_expert_weights: torch.Tensor,
+    fc1_expert_biases: Optional[torch.Tensor],
     fc2_expert_weights: torch.Tensor,
+    fc2_expert_biases: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_scales: List[torch.Tensor],
     input_sf: Optional[torch.Tensor] = None,
@@ -269,7 +269,9 @@ class CutlassFusedMoE(MoE):
             token_selected_experts,
             token_final_scales,
             self.w3_w1_weight.view(weight_dtype),
+            None,  # fc1_expert_biases
             self.w2_weight.view(weight_dtype),
+            None,  # fc2_expert_biases
             output_dtype,
             quant_scales=self.quant_scales,
             input_sf=x_sf,
@@ -592,7 +592,9 @@ class WideEPMoE(MoE):
             token_selected_slots,
             token_final_scales,
             w3_w1_weight.view(weight_dtype),
+            None,  # w3_w1_bias
             w2_weight.view(weight_dtype),
+            None,  # w2_bias
             output_dtype,
             quant_scales=quant_scales,
             input_sf=x_sf,