[Kernel][MoE] Add GELU_TANH to CPU, CUTLASS, and WNA16 MoE backends (#42027)

Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com>
Co-authored-by: lesj0610 <lesj0610@users.noreply.github.com>
This commit is contained in:
SeongJun Lee
2026-06-03 06:12:08 +09:00
committed by GitHub
parent e15f20258b
commit 3099de3617
7 changed files with 119 additions and 7 deletions
+49 -1
View File
@@ -30,7 +30,12 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
// Gated activation variants. get_act_type() maps an activation string to one
// of these values and apply_gated_act() switches on it to pick the kernel.
enum class FusedMOEAct {
SiluAndMul,
SwigluOAIAndMul,
// Returned by get_act_type() for "gelu"; dispatched to gelu_and_mul().
GeluAndMul,
// Returned by get_act_type() for "gelu_tanh"; dispatched to
// gelu_tanh_and_mul().
GeluTanhAndMul,
};
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
@@ -39,6 +44,8 @@ FusedMOEAct get_act_type(const std::string& act) {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else if (act == "gelu_tanh") {
return FusedMOEAct::GeluTanhAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -143,6 +150,44 @@ void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_tanh_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride,
const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(0.7978845608028654);
vec_op::FP32Vec16 w2_vec(0.5);
vec_op::FP32Vec16 w3_vec(0.044715);
alignas(64) float temp[16];
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto gate_pow3_vec = gate_vec * gate_vec * gate_vec;
auto inner_vec = w1_vec * (gate_vec + w3_vec * gate_pow3_vec);
inner_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::tanh(temp[i]);
}
vec_op::FP32Vec16 tanh_vec(temp);
auto gelu_tanh = gate_vec * w2_vec * (one_vec + tanh_vec);
auto gated_output_fp32 = up_vec * gelu_tanh;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -160,6 +205,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluTanhAndMul:
gelu_tanh_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}
+6 -1
View File
@@ -20,7 +20,12 @@ EXPERT_NUM = [
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI, MoEActivation.GELU]
ACT = [
MoEActivation.SILU,
MoEActivation.SWIGLUOAI,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
+7
View File
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
CutlassExpertsFp4,
CutlassExpertsFp8,
run_cutlass_moe_fp8,
)
@@ -52,6 +53,12 @@ MNK_FACTORS = [
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
def test_cutlass_moe_supports_gelu_tanh_activation_metadata():
    """The CUTLASS expert classes must report the GELU-tanh activations as supported."""
    expected_support = (
        (CutlassExpertsFp8, MoEActivation.GELU_TANH),
        (CutlassExpertsFp4, MoEActivation.GELU_TANH),
        (CutlassExpertsFp4, MoEActivation.GELU_TANH_NO_MUL),
    )
    for experts_cls, activation in expected_support:
        assert experts_cls._supports_activation(activation)
@dataclasses.dataclass
class MOETensors:
a: torch.Tensor
+49
View File
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import pytest
import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test on CUDA")
def test_moe_wna16_apply_passes_layer_activation(monkeypatch):
    """MoeWNA16Method.apply must pass the layer's activation to fused_experts."""
    seen_kwargs = {}

    def recording_fused_experts(*args, **kwargs):
        # Stand-in for fused_experts: record the keyword arguments and hand
        # back a dummy tensor.
        seen_kwargs.update(kwargs)
        return torch.empty(1, 2)

    monkeypatch.setattr(
        "vllm.model_executor.layers.fused_moe.fused_experts",
        recording_fused_experts,
    )

    # Create the method object without running __init__, then set the two
    # attributes this test provides by hand.
    wna16_method = object.__new__(MoeWNA16Method)
    wna16_method.moe = SimpleNamespace(disable_inplace=False)
    wna16_method.moe_quant_config = object()

    # Minimal stand-in for the MoE layer passed to apply().
    layer_attrs = {
        "w13_qweight": torch.empty(1, 2),
        "w2_qweight": torch.empty(1, 2),
        "activation": MoEActivation.GELU_TANH,
        "apply_router_weight_on_input": False,
        "global_num_experts": 1,
        "expert_map": None,
    }
    fake_layer = SimpleNamespace(**layer_attrs)

    result = wna16_method.apply(
        fake_layer,
        x=torch.empty(1, 2),
        topk_weights=torch.empty(1, 1),
        topk_ids=torch.empty(1, 1, dtype=torch.int32),
        shared_experts=None,
        shared_experts_input=None,
    )

    assert result.shape == (1, 2)
    assert seen_kwargs["activation"] is MoEActivation.GELU_TANH
@@ -53,6 +53,10 @@ _CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
MoEActivation.GELU: _gelu_and_mul,
MoEActivation.GELU_TANH: (
lambda x: F.gelu(x[..., : x.shape[-1] // 2], approximate="tanh")
* x[..., x.shape[-1] // 2 :]
),
}
@@ -322,6 +322,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
]
@@ -719,10 +720,12 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
RoutedExperts,
SharedExperts,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
@@ -367,16 +366,13 @@ class MoeWNA16Method(FusedMoEMethodBase):
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,