[CPU][Zen] Route W8A8 and W4A16 linear inference through zentorch on AMD Zen CPUs (#41813)

Signed-off-by: R <Ganesh.R@amd.com>
Signed-off-by: Harshal Adhav <harshal.adhav@amd.com>
Signed-off-by: Aakar Dwivedi <aadwived@amd.com>
Co-authored-by: R <Ganesh.R@amd.com>
Co-authored-by: Harshal Adhav <harshal.adhav@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Aakar Dwivedi
2026-05-31 00:47:21 +05:30
committed by khluu
parent 1be7a57a18
commit 682ffebfef
7 changed files with 351 additions and 4 deletions
+1 -3
View File
@@ -1165,9 +1165,7 @@ setup(
install_requires=get_requirements(),
extras_require={
# AMD Zen CPU optimizations via zentorch
"zen": [
"zentorch-weekly==5.2.1.dev20260408"
], # Zentorch has weekly releases. This pulls the known-good version.
"zen": ["zentorch==2.11.0.0"],
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
+10 -1
View File
@@ -58,6 +58,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.zentorch import (
ZentorchWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mxfp4 import (
MxFp4LinearKernel,
MxFp4LinearLayerConfig,
@@ -157,6 +160,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.triton import (
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
XPUFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.zentorch import (
ZentorchInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform
@@ -254,7 +260,7 @@ def _filter_kernels_by_backend(
# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
PlatformEnum.CPU: [ZentorchInt8ScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel],
PlatformEnum.CUDA: [
CutlassInt8ScaledMMLinearKernel,
TritonInt8ScaledMMLinearKernel,
@@ -348,6 +354,7 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
],
PlatformEnum.CPU: [
Dynamic4bitLinearKernel,
ZentorchWNA16LinearKernel,
CPUWNA16LinearKernel,
],
}
@@ -1018,6 +1025,8 @@ __all__ = [
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
"ZentorchInt8ScaledMMLinearKernel",
"ZentorchWNA16LinearKernel",
"MPLinearKernel",
"MPLinearLayerConfig",
"AllSparkLinearKernel",
@@ -36,6 +36,9 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.zentorch import (
ZentorchWNA16LinearKernel,
)
__all__ = [
"MPLinearKernel",
@@ -51,4 +54,5 @@ __all__ = [
"TritonW4A16LinearKernel",
"XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel",
"ZentorchWNA16LinearKernel",
]
@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Zentorch W4A16 GPTQ weight-only-quantized linear kernel for AMD Zen CPUs.
Selected by ``choose_mp_linear_kernel`` ahead of the generic oneDNN-backed
``CPUWNA16LinearKernel``. When ``can_implement`` rejects a layer, the selector
falls through to the next kernel in ``_POSSIBLE_KERNELS[PlatformEnum.CPU]``.
"""
import torch
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.zentorch_utils import has_zentorch_op
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .cpu import CPUWNA16LinearKernel
from .MPLinearKernel import MPLinearLayerConfig
logger = init_logger(__name__)
def _import_unpack_from_int32():
"""Import compressed-tensors' ``unpack_from_int32`` across versions."""
try:
from compressed_tensors.compressors.pack_quantized.helpers import (
unpack_from_int32,
)
except ImportError:
from compressed_tensors.compressors.quantized_compressors.pack_quantized import ( # type: ignore[import-not-found] # noqa: E501
unpack_from_int32,
)
return unpack_from_int32
class ZentorchWNA16LinearKernel(CPUWNA16LinearKernel):
    """W4A16 GPTQ kernel backed by ``torch.ops.zentorch.zentorch_woq_linear``.

    ``can_implement`` gates the kernel on platform, op registration and the
    absence of activation re-ordering. Individual layers are then checked by
    ``_zentorch_woq_eligible`` at weight-processing time; layers that fail
    that check are processed and executed by the parent
    ``CPUWNA16LinearKernel`` implementation instead.
    """

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        """Decide whether this kernel can serve the layer config ``c``.

        Args:
            c: Mixed-precision linear layer configuration.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
            A rejection from the parent class is returned unchanged.
        """
        ok, reason = super().can_implement(c)
        if not ok:
            return ok, reason
        if not current_platform.is_zen_cpu():
            return False, "ZentorchWNA16 requires an AMD Zen CPU."
        if not has_zentorch_op(["zentorch_woq_repack_weight", "zentorch_woq_linear"]):
            return (
                False,
                "torch.ops.zentorch.{zentorch_woq_repack_weight, "
                "zentorch_woq_linear} are not registered.",
            )
        if c.has_g_idx:
            return False, "ZentorchWNA16 does not support activation re-ordering."
        return True, None

    def _zentorch_woq_eligible(self, layer: torch.nn.Module) -> bool:
        """Eligibility predicate for the zentorch W4A16 GPTQ fast path.

        Any failed constraint returns ``False`` without modifying ``layer``,
        so the caller can take the ``cpu_gemm_wna16`` path via ``super()``.

        Args:
            layer: Module holding the packed weight, scale and optional
                g_idx tensors under the names configured on this kernel.

        Returns:
            ``True`` only when the layer has no g_idx, stores 4-bit weights
            in 2-D int32 tensors packed along dim 1 (which is also the input
            dim), has a 2-D scale, and its input feature count is divisible
            by the number of groups.
        """
        has_gidx_tensor = (
            self.w_gidx_name is not None
            and getattr(layer, self.w_gidx_name, None) is not None
        )
        if has_gidx_tensor or getattr(self.config, "has_g_idx", False):
            return False

        weight_packed = getattr(layer, self.w_q_name, None)
        weight_scale = getattr(layer, self.w_s_name, None)
        if weight_packed is None or weight_scale is None:
            return False

        # Check the storage dtype explicitly: ``torch.iinfo`` raises on
        # non-integer dtypes, which would turn a fallback case into a crash,
        # and the unpack helper used later is named for int32 storage.
        if weight_packed.dtype != torch.int32:
            return False
        # 4-bit -> 8 values per int32. The unsigned offset (+8, clamp to
        # 0..15) applied during repacking is only meaningful for 4-bit data.
        if self.config.weight_type.mantissa != 4:
            return False

        # GPTQ-only. AWQ packs along the output dim instead.
        in_dim = getattr(weight_packed, "input_dim", None)
        pk_dim = getattr(weight_packed, "packed_dim", None)
        if in_dim is None or pk_dim is None or in_dim != pk_dim:
            return False
        # Only the layout with input dim == packed dim == 1 is accepted.
        if in_dim != 1:
            return False

        if weight_packed.dim() != 2 or weight_scale.dim() != 2:
            return False

        # in_features must be divisible by num_groups.
        in_features = weight_packed.shape[1] * 8
        num_groups = weight_scale.shape[1]
        return num_groups > 0 and in_features % num_groups == 0

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack CT GPTQ weights into the zentorch WOQ layout.

        Falls back to ``CPUWNA16LinearKernel.process_weights_after_loading``
        via ``super()`` when the layer doesn't satisfy
        ``_zentorch_woq_eligible``.

        On success the layer gains ``_zentorch_woq_packed``,
        ``_zentorch_woq_scale``, ``_zentorch_woq_zero_point`` and
        ``_zentorch_kind``; the original weight/scale/zero-point storage is
        replaced with empty tensors; and
        ``layer._zentorch_processed_weights`` is set to ``True``. Calling
        this again on an already processed layer is a no-op.

        Args:
            layer: Module whose quantized parameters are to be repacked.
        """
        if getattr(layer, "_zentorch_processed_weights", False):
            return

        if not self._zentorch_woq_eligible(layer):
            logger.info_once(
                "[zen_cpu] ZentorchWNA16 fast path not eligible for this "
                "layer (AWQ pack layout, g_idx, or non-int32 storage); "
                "falling back to CPUWNA16LinearKernel (cpu_gemm_wna16)."
            )
            super().process_weights_after_loading(layer)
            return

        # Drop parameters the config says are unused, so the zero-point
        # lookup below sees ``None`` for symmetric schemes.
        if (not self.config.zero_points) and (self.w_zp_name is not None):
            setattr(layer, self.w_zp_name, None)
        if (not self.config.has_g_idx) and (self.w_gidx_name is not None):
            setattr(layer, self.w_gidx_name, None)

        weight_q = getattr(layer, self.w_q_name)
        weight_s = getattr(layer, self.w_s_name)
        weight_packed = weight_q.data if hasattr(weight_q, "data") else weight_q
        weight_scale = weight_s.data if hasattr(weight_s, "data") else weight_s

        bits = self.config.weight_type.mantissa
        pack_factor = torch.iinfo(weight_packed.dtype).bits // bits
        out_features, num_groups = weight_scale.shape[0], weight_scale.shape[1]
        in_features = weight_packed.shape[1] * pack_factor
        original_shape = torch.Size([out_features, in_features])

        unpack_from_int32 = _import_unpack_from_int32()
        repack_op = torch.ops.zentorch.zentorch_woq_repack_weight.default

        weight_unpacked = unpack_from_int32(
            weight_packed,
            bits,
            original_shape,
            packed_dim=weight_q.packed_dim,
        )

        zp_param = (
            getattr(layer, self.w_zp_name, None) if self.w_zp_name is not None else None
        )

        # NOTE(review): presumably the unpack helper yields values centred on
        # zero, and the +8 shift maps them back to 0..15 for ``uint4`` --
        # confirm against the compressed-tensors implementation.
        needs_unsigned_offset = self.config.weight_type == scalar_types.uint4
        if needs_unsigned_offset:
            weight_unpacked = (weight_unpacked.to(torch.int32) + 8).clamp(0, 15)

        repacked = repack_op(weight_unpacked.to(torch.int8).contiguous())

        if zp_param is None:
            zp_tc = None
        else:
            zp_tensor = zp_param.data if hasattr(zp_param, "data") else zp_param
            zp = unpack_from_int32(
                zp_tensor,
                bits,
                (out_features, num_groups),
                packed_dim=zp_param.packed_dim,
            )
            if needs_unsigned_offset:
                zp = (zp.to(torch.int32) + 8).clamp(0, 15)
            zp_tc = zp.to(torch.int8).t().contiguous()

        # Tensors consumed by ``apply_weights``; weight, scale and zero point
        # are all stored transposed relative to the loaded layout.
        layer._zentorch_woq_packed = repacked.t()
        layer._zentorch_woq_scale = weight_scale.t().contiguous()
        layer._zentorch_woq_zero_point = zp_tc

        # Replace the original storage with empty tensors now that the
        # repacked copies exist.
        for param_name in (self.w_q_name, self.w_s_name, self.w_zp_name):
            if param_name is None:
                continue
            param = getattr(layer, param_name, None)
            if param is None:
                continue
            if hasattr(param, "data"):
                param.data = torch.empty(0)
            else:
                setattr(layer, param_name, torch.empty(0))

        layer._zentorch_kind = "compressed_tensors_w4a16_gptq"
        layer._zentorch_processed_weights = True

        logger.info_once(
            "[zen_cpu] Using zentorch_woq_linear for W4A16 GPTQ "
            "(weight_type=%s, has_zp=%s)",
            self.config.weight_type,
            zp_tc is not None,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the linear layer on ``x``.

        Uses ``zentorch_woq_linear`` when ``process_weights_after_loading``
        repacked the layer; otherwise defers to the parent implementation.

        Args:
            layer: Module previously passed to
                ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias forwarded to the underlying op.

        Returns:
            The output tensor produced by the selected op.
        """
        if getattr(layer, "_zentorch_processed_weights", False):
            return torch.ops.zentorch.zentorch_woq_linear.default(
                x,
                layer._zentorch_woq_packed,
                layer._zentorch_woq_scale,
                layer._zentorch_woq_zero_point,
                bias,
            )
        return super().apply_weights(layer, x, bias)
__all__ = ["ZentorchWNA16LinearKernel"]
@@ -39,6 +39,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.zentorch import (
ZentorchInt8ScaledMMLinearKernel,
)
__all__ = [
"FP8ScaledMMLinearKernel",
@@ -58,6 +61,7 @@ __all__ = [
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
"ZentorchInt8ScaledMMLinearKernel",
"Fp8BlockScaledMMLinearKernel",
"CPUFp8BlockScaledMMKernel",
]
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Zentorch dynamic-symmetric W8A8 int8 linear kernel for AMD Zen CPUs.
Selected by ``choose_scaled_mm_linear_kernel`` ahead of the generic
oneDNN-backed ``CPUInt8ScaledMMLinearKernel``. When ``is_supported`` or
``can_implement`` rejects a layer, the selector falls through to the next
kernel in ``_POSSIBLE_INT8_KERNELS[PlatformEnum.CPU]``.
"""
import torch
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.zentorch_utils import has_zentorch_op
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
logger = init_logger(__name__)
class ZentorchInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    """Int8 scaled-MM linear kernel that calls ``zentorch_dynamic_qlinear``."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check platform and op availability.

        Args:
            compute_capability: Not consulted by this kernel.

        Returns:
            ``(True, None)`` when running on a Zen CPU with
            ``zentorch_dynamic_qlinear`` available, otherwise
            ``(False, reason)``.
        """
        failure: str | None = None
        if not current_platform.is_cpu():
            failure = "requires CPU."
        elif not current_platform.is_zen_cpu():
            failure = "requires AMD Zen CPU."
        elif not has_zentorch_op(["zentorch_dynamic_qlinear"]):
            failure = "torch.ops.zentorch.zentorch_dynamic_qlinear is not registered."
        return failure is None, failure

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Check that ``c`` describes a scheme this kernel handles.

        Args:
            c: Int8 scaled-MM layer configuration.

        Returns:
            ``(True, None)`` for dynamic, symmetric activation quantization
            with per-channel weight quantization; otherwise
            ``(False, reason)`` naming the first unmet requirement.
        """
        failure: str | None = None
        if c.is_static_input_scheme:
            failure = "requires dynamic activation quantization."
        elif not c.input_symmetric:
            failure = "requires symmetric activation quantization."
        elif not c.is_channelwise:
            failure = "requires per-channel weight quantization."
        return failure is None, failure

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare weights for ``zentorch_dynamic_qlinear``.

        The weight parameter is replaced by a contiguous, non-trainable copy
        of itself. The weight scale is flattened from ``(N, 1)`` to ``(N,)``
        when needed, cast to bf16, and re-registered as a non-trainable
        parameter, where ``N`` is the weight's first dimension.

        Args:
            layer: Module holding the weight and weight-scale parameters.
        """
        weight_name, scale_name, _, _, _ = self.layer_param_names

        qweight = getattr(layer, weight_name)
        out_channels = qweight.shape[0]
        frozen_weight = torch.nn.Parameter(
            qweight.data.contiguous(), requires_grad=False
        )
        replace_parameter(layer, weight_name, frozen_weight)

        scale = getattr(layer, scale_name).data
        # Collapse a trailing singleton dimension so the scale is 1-D.
        if scale.dim() == 2 and scale.shape[-1] == 1:
            scale = scale.squeeze(-1)
        scale = scale.to(torch.bfloat16).contiguous()
        assert scale.shape == (out_channels,), (
            f"[zen_cpu] expected weight scale shape ({out_channels},), "
            f"got {tuple(scale.shape)}"
        )
        frozen_scale = torch.nn.Parameter(scale, requires_grad=False)
        replace_parameter(layer, scale_name, frozen_scale)

        logger.info_once(
            "[zen_cpu] Using zentorch_dynamic_qlinear for W8A8 (dynamic-symmetric)"
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        Args:
            layer: Module previously passed to
                ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias forwarded to the op.

        Returns:
            The tensor returned by ``zentorch_dynamic_qlinear``.
        """
        weight_name, scale_name, _, _, _ = self.layer_param_names
        qlinear = torch.ops.zentorch.zentorch_dynamic_qlinear
        qweight = getattr(layer, weight_name)
        weight_scale = getattr(layer, scale_name)
        return qlinear(
            x,
            qweight,
            weight_scale,
            bias,
            zentorch_op_name="zentorch::zentorch_dynamic_qlinear",
        )
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gates zentorch CPU linear dispatch on platform/op availability."""
from __future__ import annotations
import torch
from vllm.platforms import current_platform
__all__ = ["has_zentorch_op"]
def has_zentorch_op(op_names: list[str]) -> bool:
    """Tell whether every op in ``op_names`` is available on this host.

    Args:
        op_names: Attribute names to look up on ``torch.ops.zentorch``.

    Returns:
        ``True`` only if the current platform reports a Zen CPU and each
        name resolves as an attribute of the ``zentorch`` op namespace.

    Raises:
        ValueError: If ``op_names`` is empty.
    """
    if not op_names:
        raise ValueError("has_zentorch_op requires at least one op name")

    # The namespace is only looked up on Zen CPUs; elsewhere it is treated
    # as missing.
    zen_ns = (
        getattr(torch.ops, "zentorch", None)
        if current_platform.is_zen_cpu()
        else None
    )
    if zen_ns is None:
        return False

    for name in op_names:
        if not hasattr(zen_ns, name):
            return False
    return True