mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Kernel][MoE] Add GELU_TANH to CPU, CUTLASS, and WNA16 MoE backends (#42027)
Signed-off-by: lesj0610 <lesj0610@users.noreply.github.com> Co-authored-by: lesj0610 <lesj0610@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,12 @@
|
||||
}()
|
||||
|
||||
namespace {
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
|
||||
enum class FusedMOEAct {
|
||||
SiluAndMul,
|
||||
SwigluOAIAndMul,
|
||||
GeluAndMul,
|
||||
GeluTanhAndMul,
|
||||
};
|
||||
|
||||
FusedMOEAct get_act_type(const std::string& act) {
|
||||
if (act == "silu") {
|
||||
@@ -39,6 +44,8 @@ FusedMOEAct get_act_type(const std::string& act) {
|
||||
return FusedMOEAct::SwigluOAIAndMul;
|
||||
} else if (act == "gelu") {
|
||||
return FusedMOEAct::GeluAndMul;
|
||||
} else if (act == "gelu_tanh") {
|
||||
return FusedMOEAct::GeluTanhAndMul;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid act type: " + act);
|
||||
}
|
||||
@@ -143,6 +150,44 @@ void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void gelu_tanh_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const int32_t m_size, const int32_t n_size,
|
||||
const int32_t input_stride,
|
||||
const int32_t output_stride) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
const int32_t dim = n_size / 2;
|
||||
float* __restrict__ gate = input;
|
||||
float* __restrict__ up = input + dim;
|
||||
vec_op::FP32Vec16 one_vec(1.0);
|
||||
vec_op::FP32Vec16 w1_vec(0.7978845608028654);
|
||||
vec_op::FP32Vec16 w2_vec(0.5);
|
||||
vec_op::FP32Vec16 w3_vec(0.044715);
|
||||
alignas(64) float temp[16];
|
||||
|
||||
for (int32_t m = 0; m < m_size; ++m) {
|
||||
for (int32_t n = 0; n < dim; n += 16) {
|
||||
vec_op::FP32Vec16 gate_vec(gate + n);
|
||||
vec_op::FP32Vec16 up_vec(up + n);
|
||||
auto gate_pow3_vec = gate_vec * gate_vec * gate_vec;
|
||||
auto inner_vec = w1_vec * (gate_vec + w3_vec * gate_pow3_vec);
|
||||
|
||||
inner_vec.save(temp);
|
||||
for (int32_t i = 0; i < 16; ++i) {
|
||||
temp[i] = std::tanh(temp[i]);
|
||||
}
|
||||
vec_op::FP32Vec16 tanh_vec(temp);
|
||||
auto gelu_tanh = gate_vec * w2_vec * (one_vec + tanh_vec);
|
||||
auto gated_output_fp32 = up_vec * gelu_tanh;
|
||||
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
|
||||
gated_output.save(output + n);
|
||||
}
|
||||
gate += input_stride;
|
||||
up += input_stride;
|
||||
output += output_stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
float* __restrict__ input,
|
||||
@@ -160,6 +205,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
case FusedMOEAct::GeluAndMul:
|
||||
gelu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
case FusedMOEAct::GeluTanhAndMul:
|
||||
gelu_tanh_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported act type.");
|
||||
}
|
||||
|
||||
@@ -20,7 +20,12 @@ EXPERT_NUM = [
|
||||
HIDDEN_DIM = [128, 2880]
|
||||
INTERMEDIATE_DIM = [128, 2880]
|
||||
BATCH_SIZE = [1, 64, 256]
|
||||
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI, MoEActivation.GELU]
|
||||
ACT = [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.GELU_TANH,
|
||||
]
|
||||
USE_BIAS = [True, False]
|
||||
ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
CutlassExpertsFp8,
|
||||
run_cutlass_moe_fp8,
|
||||
)
|
||||
@@ -52,6 +53,12 @@ MNK_FACTORS = [
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
|
||||
|
||||
def test_cutlass_moe_supports_gelu_tanh_activation_metadata():
|
||||
assert CutlassExpertsFp8._supports_activation(MoEActivation.GELU_TANH)
|
||||
assert CutlassExpertsFp4._supports_activation(MoEActivation.GELU_TANH)
|
||||
assert CutlassExpertsFp4._supports_activation(MoEActivation.GELU_TANH_NO_MUL)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MOETensors:
|
||||
a: torch.Tensor
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test on CUDA")
|
||||
def test_moe_wna16_apply_passes_layer_activation(monkeypatch):
|
||||
captured_kwargs = {}
|
||||
|
||||
def fake_fused_experts(*args, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return torch.empty(1, 2)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"vllm.model_executor.layers.fused_moe.fused_experts",
|
||||
fake_fused_experts,
|
||||
)
|
||||
|
||||
method = object.__new__(MoeWNA16Method)
|
||||
method.moe = SimpleNamespace(disable_inplace=False)
|
||||
method.moe_quant_config = object()
|
||||
layer = SimpleNamespace(
|
||||
w13_qweight=torch.empty(1, 2),
|
||||
w2_qweight=torch.empty(1, 2),
|
||||
activation=MoEActivation.GELU_TANH,
|
||||
apply_router_weight_on_input=False,
|
||||
global_num_experts=1,
|
||||
expert_map=None,
|
||||
)
|
||||
|
||||
output = method.apply(
|
||||
layer,
|
||||
x=torch.empty(1, 2),
|
||||
topk_weights=torch.empty(1, 1),
|
||||
topk_ids=torch.empty(1, 1, dtype=torch.int32),
|
||||
shared_experts=None,
|
||||
shared_experts_input=None,
|
||||
)
|
||||
|
||||
assert output.shape == (1, 2)
|
||||
assert captured_kwargs["activation"] is MoEActivation.GELU_TANH
|
||||
@@ -53,6 +53,10 @@ _CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
|
||||
MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
|
||||
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
|
||||
MoEActivation.GELU: _gelu_and_mul,
|
||||
MoEActivation.GELU_TANH: (
|
||||
lambda x: F.gelu(x[..., : x.shape[-1] // 2], approximate="tanh")
|
||||
* x[..., x.shape[-1] // 2 :]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -322,6 +322,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.GELU_TANH,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
]
|
||||
|
||||
@@ -719,10 +720,12 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
return activation in [
|
||||
MoEActivation.SILU,
|
||||
MoEActivation.GELU,
|
||||
MoEActivation.GELU_TANH,
|
||||
MoEActivation.SWIGLUOAI,
|
||||
MoEActivation.SWIGLUSTEP,
|
||||
MoEActivation.SILU_NO_MUL,
|
||||
MoEActivation.GELU_NO_MUL,
|
||||
MoEActivation.GELU_TANH_NO_MUL,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
]
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
RoutedExperts,
|
||||
SharedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
int4_w4a16_moe_quant_config,
|
||||
@@ -367,16 +366,13 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
assert layer.activation == MoEActivation.SILU, (
|
||||
f"Only SiLU activation is supported, not {layer.activation}."
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
|
||||
Reference in New Issue
Block a user