diff --git a/setup.py b/setup.py
index 7665978c6b0..f6276616a19 100644
--- a/setup.py
+++ b/setup.py
@@ -1085,7 +1085,9 @@ setup(
     install_requires=get_requirements(),
     extras_require={
         # AMD Zen CPU optimizations via zentorch
-        "zen": ["zentorch"],
+        "zen": [
+            "zentorch-weekly==5.2.1.dev20260408"
+        ],  # Zentorch has weekly releases. This pulls the known-good version.
         "bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
         "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.2.2"],
diff --git a/vllm/platforms/zen_cpu.py b/vllm/platforms/zen_cpu.py
index 481eec1cb4e..2af64e5e9f5 100644
--- a/vllm/platforms/zen_cpu.py
+++ b/vllm/platforms/zen_cpu.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import torch
+
 from vllm.logger import init_logger
 from vllm.platforms.cpu import CpuPlatform
 
@@ -22,3 +24,9 @@ class ZenCpuPlatform(CpuPlatform):
     def is_zen_cpu(self) -> bool:
         # is_cpu() also returns True for this platform (inherited from CpuPlatform).
         return True
+
+    # Currently, AMD CPUs do not support float16 compute.
+    # Hence explicitly return bfloat16 and float32.
+    @property
+    def supported_dtypes(self) -> list[torch.dtype]:
+        return [torch.bfloat16, torch.float32]