mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[XPU] add scale transpose to prepare_fp8_moe_layer_for_xpu and bump up kernels (#43277)
Signed-off-by: mayuyuace <qiming1.zhang@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -17,4 +17,4 @@ torchaudio
|
||||
torchvision
|
||||
|
||||
auto_round_lib>=0.13.0
|
||||
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.8/vllm_xpu_kernels-0.1.8-cp38-abi3-manylinux_2_28_x86_64.whl
|
||||
vllm_xpu_kernels @ https://github.com/vllm-project/vllm-xpu-kernels/releases/download/v0.1.9/vllm_xpu_kernels-0.1.9-cp38-abi3-manylinux_2_28_x86_64.whl
|
||||
|
||||
@@ -29,9 +29,20 @@ if current_platform.is_xpu():
|
||||
|
||||
def prepare_fp8_moe_layer_for_xpu(
|
||||
w13: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return w13.transpose(-1, -2).contiguous(), w2.transpose(-1, -2).contiguous()
|
||||
w2_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if w13_scale is not None and w13_scale.ndim == 3:
|
||||
w13_scale = w13_scale.transpose(-1, -2).contiguous()
|
||||
if w2_scale is not None and w2_scale.ndim == 3:
|
||||
w2_scale = w2_scale.transpose(-1, -2).contiguous()
|
||||
return (
|
||||
w13.transpose(-1, -2).contiguous(),
|
||||
w13_scale,
|
||||
w2.transpose(-1, -2).contiguous(),
|
||||
w2_scale,
|
||||
)
|
||||
|
||||
|
||||
class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
|
||||
@@ -483,7 +483,9 @@ def convert_to_fp8_moe_kernel_format(
|
||||
prepare_fp8_moe_layer_for_xpu,
|
||||
)
|
||||
|
||||
w13, w2 = prepare_fp8_moe_layer_for_xpu(w13, w2)
|
||||
w13, w13_scale, w2, w2_scale = prepare_fp8_moe_layer_for_xpu(
|
||||
w13, w13_scale, w2, w2_scale
|
||||
)
|
||||
elif fp8_backend == Fp8MoeBackend.CPU:
|
||||
from vllm.model_executor.layers.fused_moe.experts.cpu_moe import (
|
||||
prepare_fp8_moe_layer_for_cpu,
|
||||
|
||||
Reference in New Issue
Block a user