mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[CPU][Zen] Route W8A8 and W4A16 linear inference through zentorch on AMD Zen CPUs (#41813)
Signed-off-by: R <Ganesh.R@amd.com> Signed-off-by: Harshal Adhav <harshal.adhav@amd.com> Signed-off-by: Aakar Dwivedi <aadwived@amd.com> Co-authored-by: R <Ganesh.R@amd.com> Co-authored-by: Harshal Adhav <harshal.adhav@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -1165,9 +1165,7 @@ setup(
|
||||
install_requires=get_requirements(),
|
||||
extras_require={
|
||||
# AMD Zen CPU optimizations via zentorch
|
||||
"zen": [
|
||||
"zentorch-weekly==5.2.1.dev20260408"
|
||||
], # Zentorch has weekly releases. This pulls the known-good version.
|
||||
"zen": ["zentorch==2.11.0.0"],
|
||||
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
|
||||
|
||||
@@ -58,6 +58,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUW4A8IntLinearKernel,
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.zentorch import (
|
||||
ZentorchWNA16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mxfp4 import (
|
||||
MxFp4LinearKernel,
|
||||
MxFp4LinearLayerConfig,
|
||||
@@ -157,6 +160,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
|
||||
XPUFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.zentorch import (
|
||||
ZentorchInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
@@ -254,7 +260,7 @@ def _filter_kernels_by_backend(
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
|
||||
PlatformEnum.CPU: [ZentorchInt8ScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [
|
||||
CutlassInt8ScaledMMLinearKernel,
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
@@ -348,6 +354,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
],
|
||||
PlatformEnum.CPU: [
|
||||
Dynamic4bitLinearKernel,
|
||||
ZentorchWNA16LinearKernel,
|
||||
CPUWNA16LinearKernel,
|
||||
],
|
||||
}
|
||||
@@ -1018,6 +1025,8 @@ __all__ = [
|
||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||
"ROCmFP8ScaledMMLinearKernel",
|
||||
"TritonInt8ScaledMMLinearKernel",
|
||||
"ZentorchInt8ScaledMMLinearKernel",
|
||||
"ZentorchWNA16LinearKernel",
|
||||
"MPLinearKernel",
|
||||
"MPLinearLayerConfig",
|
||||
"AllSparkLinearKernel",
|
||||
|
||||
@@ -36,6 +36,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
|
||||
XPUW4A8IntLinearKernel,
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.mixed_precision.zentorch import (
|
||||
ZentorchWNA16LinearKernel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MPLinearKernel",
|
||||
@@ -51,4 +54,5 @@ __all__ = [
|
||||
"TritonW4A16LinearKernel",
|
||||
"XPUW4A8IntLinearKernel",
|
||||
"XPUwNa16LinearKernel",
|
||||
"ZentorchWNA16LinearKernel",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Zentorch W4A16 GPTQ weight-only-quantized linear kernel for AMD Zen CPUs.
|
||||
|
||||
Selected by ``choose_mp_linear_kernel`` ahead of the generic oneDNN-backed
|
||||
``CPUWNA16LinearKernel``. When ``can_implement`` rejects a layer, the selector
|
||||
falls through to the next kernel in ``_POSSIBLE_KERNELS[PlatformEnum.CPU]``.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear.zentorch_utils import has_zentorch_op
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .cpu import CPUWNA16LinearKernel
|
||||
from .MPLinearKernel import MPLinearLayerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _import_unpack_from_int32():
|
||||
"""Import compressed-tensors' ``unpack_from_int32`` across versions."""
|
||||
try:
|
||||
from compressed_tensors.compressors.pack_quantized.helpers import (
|
||||
unpack_from_int32,
|
||||
)
|
||||
except ImportError:
|
||||
from compressed_tensors.compressors.quantized_compressors.pack_quantized import ( # type: ignore[import-not-found] # noqa: E501
|
||||
unpack_from_int32,
|
||||
)
|
||||
return unpack_from_int32
|
||||
|
||||
|
||||
class ZentorchWNA16LinearKernel(CPUWNA16LinearKernel):
|
||||
"""W4A16 GPTQ kernel backed by ``torch.ops.zentorch.zentorch_woq_linear``."""
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
ok, reason = super().can_implement(c)
|
||||
if not ok:
|
||||
return ok, reason
|
||||
|
||||
if not current_platform.is_zen_cpu():
|
||||
return False, "ZentorchWNA16 requires an AMD Zen CPU."
|
||||
|
||||
if not has_zentorch_op(["zentorch_woq_repack_weight", "zentorch_woq_linear"]):
|
||||
return (
|
||||
False,
|
||||
"torch.ops.zentorch.{zentorch_woq_repack_weight, "
|
||||
"zentorch_woq_linear} are not registered.",
|
||||
)
|
||||
|
||||
if c.has_g_idx:
|
||||
return False, "ZentorchWNA16 does not support activation re-ordering."
|
||||
return True, None
|
||||
|
||||
def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
|
||||
"""Eligibility predicate for the zentorch W4A16 GPTQ fast path.
|
||||
|
||||
Constraints (any failure -> ``cpu_gemm_wna16`` path via ``super()``
|
||||
with ``layer`` untouched).
|
||||
"""
|
||||
if (
|
||||
self.w_gidx_name is not None
|
||||
and getattr(layer, self.w_gidx_name, None) is not None
|
||||
) or (getattr(self.config, "has_g_idx", False)):
|
||||
return False
|
||||
|
||||
weight_packed = getattr(layer, self.w_q_name, None)
|
||||
weight_scale = getattr(layer, self.w_s_name, None)
|
||||
if weight_packed is None or weight_scale is None:
|
||||
return False
|
||||
|
||||
bits = self.config.weight_type.mantissa
|
||||
pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
|
||||
# 4-bit -> 8 values per int32;
|
||||
if pack_factor != 8:
|
||||
return False
|
||||
|
||||
# GPTQ-only. AWQ packs along the output dim instead.
|
||||
in_dim = getattr(weight_packed, "input_dim", None)
|
||||
pk_dim = getattr(weight_packed, "packed_dim", None)
|
||||
if in_dim is None or pk_dim is None or in_dim != pk_dim:
|
||||
return False
|
||||
|
||||
is_ct_format = in_dim == pk_dim == 1
|
||||
if not is_ct_format:
|
||||
return False
|
||||
|
||||
if weight_packed.dim() != 2 or weight_scale.dim() != 2:
|
||||
return False
|
||||
|
||||
# 4-bit -> 8 values per int32; in_features must be divisible by num_groups.
|
||||
in_features = weight_packed.shape[1] * 8
|
||||
num_groups = weight_scale.shape[1]
|
||||
return num_groups > 0 and in_features % num_groups == 0
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Repack CT GPTQ weights into the zentorch WOQ layout.
|
||||
|
||||
Falls back to ``CPUWNA16LinearKernel.process_weights_after_loading``
|
||||
via ``super()`` when the layer doesn't satisfy
|
||||
``_zentorch_woq_eligible``.
|
||||
|
||||
On success, ``layer._zentorch_processed_weights`` is set to ``True``
|
||||
"""
|
||||
if getattr(layer, "_zentorch_processed_weights", False):
|
||||
return
|
||||
|
||||
if not self._zentorch_woq_eligible(layer):
|
||||
logger.info_once(
|
||||
"[zen_cpu] ZentorchWNA16 fast path not eligible for this "
|
||||
"layer (AWQ pack layout, g_idx, or non-int32 storage); "
|
||||
"falling back to CPUWNA16LinearKernel (cpu_gemm_wna16)."
|
||||
)
|
||||
super().process_weights_after_loading(layer)
|
||||
return
|
||||
|
||||
if (not self.config.zero_points) and (self.w_zp_name is not None):
|
||||
setattr(layer, self.w_zp_name, None)
|
||||
|
||||
if (not self.config.has_g_idx) and (self.w_gidx_name is not None):
|
||||
setattr(layer, self.w_gidx_name, None)
|
||||
|
||||
weight_q = getattr(layer, self.w_q_name)
|
||||
weight_s = getattr(layer, self.w_s_name)
|
||||
weight_packed = weight_q.data if hasattr(weight_q, "data") else weight_q
|
||||
weight_scale = weight_s.data if hasattr(weight_s, "data") else weight_s
|
||||
|
||||
bits = self.config.weight_type.mantissa
|
||||
pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
|
||||
out_features, num_groups = weight_scale.shape[0], weight_scale.shape[1]
|
||||
in_features = weight_packed.shape[1] * pack_factor
|
||||
original_shape = torch.Size([out_features, in_features])
|
||||
unpack_from_int32 = _import_unpack_from_int32()
|
||||
repack_op = torch.ops.zentorch.zentorch_woq_repack_weight.default
|
||||
|
||||
weight_unpacked = unpack_from_int32(
|
||||
weight_packed,
|
||||
bits,
|
||||
original_shape,
|
||||
packed_dim=weight_q.packed_dim,
|
||||
)
|
||||
|
||||
zp_param = (
|
||||
getattr(layer, self.w_zp_name, None) if self.w_zp_name is not None else None
|
||||
)
|
||||
needs_unsigned_offset = self.config.weight_type == scalar_types.uint4
|
||||
|
||||
if needs_unsigned_offset:
|
||||
weight_unpacked = (weight_unpacked.to(torch.int32) + 8).clamp(0, 15)
|
||||
repacked = repack_op(weight_unpacked.to(torch.int8).contiguous())
|
||||
|
||||
if zp_param is None:
|
||||
zp_tc = None
|
||||
else:
|
||||
zp_tensor = zp_param.data if hasattr(zp_param, "data") else zp_param
|
||||
zp = unpack_from_int32(
|
||||
zp_tensor,
|
||||
bits,
|
||||
(out_features, num_groups),
|
||||
packed_dim=zp_param.packed_dim,
|
||||
)
|
||||
if needs_unsigned_offset:
|
||||
zp = (zp.to(torch.int32) + 8).clamp(0, 15)
|
||||
zp_tc = zp.to(torch.int8).t().contiguous()
|
||||
|
||||
layer._zentorch_woq_packed = repacked.t()
|
||||
layer._zentorch_woq_scale = weight_scale.t().contiguous()
|
||||
layer._zentorch_woq_zero_point = zp_tc
|
||||
|
||||
for param_name in (self.w_q_name, self.w_s_name, self.w_zp_name):
|
||||
if param_name is None:
|
||||
continue
|
||||
param = getattr(layer, param_name, None)
|
||||
if param is None:
|
||||
continue
|
||||
if hasattr(param, "data"):
|
||||
param.data = torch.empty(0)
|
||||
else:
|
||||
setattr(layer, param_name, torch.empty(0))
|
||||
|
||||
layer._zentorch_kind = "compressed_tensors_w4a16_gptq"
|
||||
layer._zentorch_processed_weights = True
|
||||
logger.info_once(
|
||||
"[zen_cpu] Using zentorch_woq_linear for W4A16 GPTQ "
|
||||
"(weight_type=%s, has_zp=%s)",
|
||||
self.config.weight_type,
|
||||
zp_tc is not None,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if getattr(layer, "_zentorch_processed_weights", False):
|
||||
return torch.ops.zentorch.zentorch_woq_linear.default(
|
||||
x,
|
||||
layer._zentorch_woq_packed,
|
||||
layer._zentorch_woq_scale,
|
||||
layer._zentorch_woq_zero_point,
|
||||
bias,
|
||||
)
|
||||
return super().apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
__all__ = ["ZentorchWNA16LinearKernel"]
|
||||
@@ -39,6 +39,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
|
||||
TritonInt8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.zentorch import (
|
||||
ZentorchInt8ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FP8ScaledMMLinearKernel",
|
||||
@@ -58,6 +61,7 @@ __all__ = [
|
||||
"RowWiseTorchFP8ScaledMMLinearKernel",
|
||||
"ROCmFP8ScaledMMLinearKernel",
|
||||
"TritonInt8ScaledMMLinearKernel",
|
||||
"ZentorchInt8ScaledMMLinearKernel",
|
||||
"Fp8BlockScaledMMLinearKernel",
|
||||
"CPUFp8BlockScaledMMKernel",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Zentorch dynamic-symmetric W8A8 int8 linear kernel for AMD Zen CPUs.
|
||||
|
||||
Selected by ``choose_scaled_mm_linear_kernel`` ahead of the generic
|
||||
oneDNN-backed ``CPUInt8ScaledMMLinearKernel``. When ``is_supported`` or
|
||||
``can_implement`` rejects a layer, the selector falls through to the next
|
||||
kernel in ``_POSSIBLE_INT8_KERNELS[PlatformEnum.CPU]``.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear.zentorch_utils import has_zentorch_op
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ZentorchInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "requires CPU."
|
||||
if not current_platform.is_zen_cpu():
|
||||
return False, "requires AMD Zen CPU."
|
||||
if not has_zentorch_op(["zentorch_dynamic_qlinear"]):
|
||||
return (
|
||||
False,
|
||||
"torch.ops.zentorch.zentorch_dynamic_qlinear is not registered.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.is_static_input_scheme:
|
||||
return False, "requires dynamic activation quantization."
|
||||
if not c.input_symmetric:
|
||||
return False, "requires symmetric activation quantization."
|
||||
if not c.is_channelwise:
|
||||
return False, "requires per-channel weight quantization."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Prepare weights for ``zentorch_dynamic_qlinear``.
|
||||
|
||||
Keeps weight in [N, K] layout (int8, contiguous) and converts the
|
||||
per-channel weight scale to bf16 with shape ``(N,)``.
|
||||
"""
|
||||
w_q_name, w_s_name, _, _, _ = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
n = weight.shape[0]
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.data.contiguous(), requires_grad=False),
|
||||
)
|
||||
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
ws = weight_scale.data
|
||||
if ws.dim() == 2 and ws.shape[-1] == 1:
|
||||
ws = ws.squeeze(-1)
|
||||
ws = ws.to(torch.bfloat16).contiguous()
|
||||
assert ws.shape == (n,), (
|
||||
f"[zen_cpu] expected weight scale shape ({n},), got {tuple(ws.shape)}"
|
||||
)
|
||||
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(ws, requires_grad=False),
|
||||
)
|
||||
logger.info_once(
|
||||
"[zen_cpu] Using zentorch_dynamic_qlinear for W8A8 (dynamic-symmetric)"
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q_name, w_s_name, _, _, _ = self.layer_param_names
|
||||
return torch.ops.zentorch.zentorch_dynamic_qlinear(
|
||||
x,
|
||||
getattr(layer, w_q_name),
|
||||
getattr(layer, w_s_name),
|
||||
bias,
|
||||
zentorch_op_name="zentorch::zentorch_dynamic_qlinear",
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Gates zentorch CPU linear dispatch on platform/op availability."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["has_zentorch_op"]
|
||||
|
||||
|
||||
def has_zentorch_op(op_names: list[str]) -> bool:
|
||||
"""Return ``True`` when running on Zen CPU with all named ops registered."""
|
||||
if not op_names:
|
||||
raise ValueError("has_zentorch_op requires at least one op name")
|
||||
if not current_platform.is_zen_cpu():
|
||||
return False
|
||||
ns = getattr(torch.ops, "zentorch", None)
|
||||
if ns is None:
|
||||
return False
|
||||
return all(hasattr(ns, op_name) for op_name in op_names)
|
||||
Reference in New Issue
Block a user