mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Hardware][Power]Add Power VSX Attention Backend and fix l2 Cache Crash (#40451)
Signed-off-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com> Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com> Co-authored-by: Akash Kaothalkar <akashkaothalkar@akashs-mbp.bl1-in.ibm.com> Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com>
This commit is contained in:
@@ -29,6 +29,8 @@ torch::Tensor get_scheduler_metadata(
|
||||
isa = cpu_attention::ISA::NEON;
|
||||
} else if (isa_hint == "vxe") {
|
||||
isa = cpu_attention::ISA::VXE;
|
||||
} else if (isa_hint == "vsx") {
|
||||
isa = cpu_attention::ISA::VSX;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
|
||||
}
|
||||
@@ -129,6 +131,8 @@ void cpu_attn_reshape_and_cache(
|
||||
return cpu_attention::ISA::NEON;
|
||||
} else if (isa == "vxe") {
|
||||
return cpu_attention::ISA::VXE;
|
||||
} else if (isa == "vsx") {
|
||||
return cpu_attention::ISA::VSX;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid ISA type: " + isa);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "cpu/utils.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
enum class ISA { AMX, VEC, VEC16, NEON, VXE };
|
||||
enum class ISA { AMX, VEC, VEC16, NEON, VXE, VSX };
|
||||
|
||||
// Mirrors csrc/attention/dtype_fp8.cuh Fp8KVCacheDataType exactly.
|
||||
enum class Fp8KVCacheDataType {
|
||||
@@ -164,6 +164,9 @@ struct AttentionMetadata {
|
||||
case ISA::VXE:
|
||||
ss << "VXE, ";
|
||||
break;
|
||||
case ISA::VSX:
|
||||
ss << "VSX, ";
|
||||
break;
|
||||
}
|
||||
ss << "workitem_group_num: " << workitem_group_num
|
||||
<< ", reduction_item_num: " << reduction_item_num
|
||||
|
||||
@@ -27,8 +27,8 @@ FORCE_INLINE std::pair<vec_op::FP32Vec16, vec_op::FP32Vec16> load_b_pair_vec(
|
||||
return {vec_op::FP32Vec16(bf16_b_reg, 0), vec_op::FP32Vec16(bf16_b_reg, 1)};
|
||||
} else {
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
return {vec_op::FP32Vec16(load_vec_t(ptr)),
|
||||
vec_op::FP32Vec16(load_vec_t(ptr + 16))};
|
||||
return std::make_pair(vec_op::FP32Vec16(load_vec_t(ptr)),
|
||||
vec_op::FP32Vec16(load_vec_t(ptr + 16)));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,359 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#ifndef CPU_ATTN_VSX_HPP
|
||||
#define CPU_ATTN_VSX_HPP
|
||||
|
||||
#include "cpu_attn_impl.hpp"
|
||||
#include <altivec.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
|
||||
// ppc64le Vector = 16 bytes (128 bits)
|
||||
#define BLOCK_SIZE_ALIGNMENT 32
|
||||
#define HEAD_SIZE_ALIGNMENT 32
|
||||
#define MAX_Q_HEAD_NUM_PER_ITER 16
|
||||
|
||||
template <typename kv_cache_t>
|
||||
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, __vector float& b0,
|
||||
__vector float& b1);
|
||||
|
||||
// [1] Float Specialization
|
||||
template <>
|
||||
FORCE_INLINE void load_row8_B_as_f32<float>(const float* p, __vector float& b0,
|
||||
__vector float& b1) {
|
||||
b0 = vec_xl(0, const_cast<float*>(p));
|
||||
b1 = vec_xl(0, const_cast<float*>(p + 4));
|
||||
}
|
||||
|
||||
// [2] BFloat16 Specialization (Little Endian ppc64le)
|
||||
// On ppc64le (LE): BF16 bits should land in the HIGH 16 bits of each float32.
|
||||
// Byte layout of float32 on LE: [byte0(LSB), byte1, byte2, byte3(MSB)]
|
||||
// We need BF16 in bytes2-3 (high half) with bytes0-1 zeroed.
|
||||
// vec_mergeh on LE interleaves elements 0..3: result_i = {a[i], b[i]}
|
||||
// So vec_mergeh(zeros_u16, raw_u16) gives for each uint16 pair:
|
||||
// uint16[2i] = zeros[i] -> low 16 bits of uint32 -> zeroed mantissa LSBs
|
||||
// uint16[2i+1] = raw[i] -> high 16 bits of uint32 -> BF16 bits
|
||||
// Cast to float32 gives exactly (bf16_bits << 16) per element.
|
||||
template <>
|
||||
FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
|
||||
__vector float& b0,
|
||||
__vector float& b1) {
|
||||
__vector unsigned short raw = vec_xl(
|
||||
0, reinterpret_cast<unsigned short*>(const_cast<c10::BFloat16*>(p)));
|
||||
__vector unsigned short zeros = vec_splat_u16(0);
|
||||
|
||||
// LE: zeros in low 16 bits, raw in high 16 bits → bf16 << 16 == float32
|
||||
b0 = (__vector float)vec_mergeh(zeros, raw);
|
||||
b1 = (__vector float)vec_mergel(zeros, raw);
|
||||
}
|
||||
|
||||
// Note: c10::Half (FP16) is not supported on PowerPC architecture
|
||||
|
||||
template <int32_t M, typename kv_cache_t>
|
||||
FORCE_INLINE void gemm_micro_ppc64le_Mx8_Ku4(
|
||||
const float* __restrict A, // [M x K]
|
||||
const kv_cache_t* __restrict B, // [K x 8]
|
||||
float* __restrict C, // [M x 8]
|
||||
int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
|
||||
static_assert(1 <= M && M <= 8, "M must be in [1,8]");
|
||||
|
||||
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
|
||||
#define IF_M(i) if constexpr (M > (i))
|
||||
|
||||
// 1. Define A pointers
|
||||
#define DECL_A(i) const float* a##i = A + (i) * lda;
|
||||
ROWS_APPLY(DECL_A)
|
||||
#undef DECL_A
|
||||
|
||||
// 2. Define Accumulators (2 vectors covers 8 columns)
|
||||
#define DECL_ACC(i) __vector float acc##i##_0, acc##i##_1;
|
||||
ROWS_APPLY(DECL_ACC)
|
||||
#undef DECL_ACC
|
||||
|
||||
// 3. Initialize Accumulators (Load C or Zero)
|
||||
#define INIT_ACC(i) \
|
||||
IF_M(i) { \
|
||||
if (accumulate) { \
|
||||
acc##i##_0 = vec_xl(0, const_cast<float*>(C + (i) * ldc + 0)); \
|
||||
acc##i##_1 = vec_xl(0, const_cast<float*>(C + (i) * ldc + 4)); \
|
||||
} else { \
|
||||
acc##i##_0 = vec_splats(0.0f); \
|
||||
acc##i##_1 = vec_splats(0.0f); \
|
||||
} \
|
||||
}
|
||||
ROWS_APPLY(INIT_ACC)
|
||||
#undef INIT_ACC
|
||||
|
||||
int32_t k = 0;
|
||||
|
||||
for (; k + 3 < K; k += 4) {
|
||||
// Load 4 values of A for each Row M: A[k...k+3]
|
||||
#define LOAD_A4(i) \
|
||||
__vector float a##i##v; \
|
||||
IF_M(i) a##i##v = vec_xl(0, const_cast<float*>(a##i + k));
|
||||
ROWS_APPLY(LOAD_A4)
|
||||
#undef LOAD_A4
|
||||
|
||||
// FMA for specific lane L of A
|
||||
// ppc64le: vec_madd(b, vec_splat(a, lane), acc)
|
||||
#define FMAS_LANE(i, aiv, L) \
|
||||
IF_M(i) { \
|
||||
__vector float a_broad = vec_splat(aiv, L); \
|
||||
acc##i##_0 = vec_madd(b0, a_broad, acc##i##_0); \
|
||||
acc##i##_1 = vec_madd(b1, a_broad, acc##i##_1); \
|
||||
}
|
||||
|
||||
// Unroll K=0..3
|
||||
{
|
||||
__vector float b0, b1;
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb, b0, b1);
|
||||
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
|
||||
ROWS_APPLY(STEP_K0)
|
||||
#undef STEP_K0
|
||||
}
|
||||
{
|
||||
__vector float b0, b1;
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb, b0, b1);
|
||||
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
|
||||
ROWS_APPLY(STEP_K1)
|
||||
#undef STEP_K1
|
||||
}
|
||||
{
|
||||
__vector float b0, b1;
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb, b0, b1);
|
||||
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
|
||||
ROWS_APPLY(STEP_K2)
|
||||
#undef STEP_K2
|
||||
}
|
||||
{
|
||||
__vector float b0, b1;
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb, b0, b1);
|
||||
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
|
||||
ROWS_APPLY(STEP_K3)
|
||||
#undef STEP_K3
|
||||
}
|
||||
#undef FMAS_LANE
|
||||
}
|
||||
|
||||
for (; k < K; ++k) {
|
||||
__vector float b0, b1;
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb, b0, b1);
|
||||
#define TAIL_ROW(i) \
|
||||
IF_M(i) { \
|
||||
__vector float ai = vec_splats(*(a##i + k)); \
|
||||
acc##i##_0 = vec_madd(b0, ai, acc##i##_0); \
|
||||
acc##i##_1 = vec_madd(b1, ai, acc##i##_1); \
|
||||
}
|
||||
ROWS_APPLY(TAIL_ROW)
|
||||
#undef TAIL_ROW
|
||||
}
|
||||
|
||||
#define STORE_ROW(i) \
|
||||
IF_M(i) { \
|
||||
vec_xst(acc##i##_0, 0, C + (i) * ldc + 0); \
|
||||
vec_xst(acc##i##_1, 0, C + (i) * ldc + 4); \
|
||||
}
|
||||
ROWS_APPLY(STORE_ROW)
|
||||
#undef STORE_ROW
|
||||
|
||||
#undef ROWS_APPLY
|
||||
#undef IF_M
|
||||
}
|
||||
|
||||
template <int32_t N, typename kv_cache_t>
|
||||
FORCE_INLINE void gemm_macro_ppc64le_Mx8_Ku4(const float* __restrict A,
|
||||
const kv_cache_t* __restrict B,
|
||||
float* __restrict C, int32_t M,
|
||||
int32_t K, int64_t lda,
|
||||
int64_t ldb, int64_t ldc,
|
||||
bool accumulate) {
|
||||
static_assert(N % 8 == 0, "N must be a multiple of 8");
|
||||
for (int32_t m = 0; m < M;) {
|
||||
int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
|
||||
const float* Ab = A + m * lda;
|
||||
float* Cb = C + m * ldc;
|
||||
|
||||
for (int32_t n = 0; n < N; n += 8) {
|
||||
const kv_cache_t* Bn = B + n;
|
||||
float* Cn = Cb + n;
|
||||
switch (mb) {
|
||||
case 8:
|
||||
gemm_micro_ppc64le_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
|
||||
K, accumulate);
|
||||
break;
|
||||
case 4:
|
||||
gemm_micro_ppc64le_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
|
||||
K, accumulate);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro_ppc64le_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
|
||||
K, accumulate);
|
||||
break;
|
||||
default:
|
||||
gemm_micro_ppc64le_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc,
|
||||
K, accumulate);
|
||||
break;
|
||||
}
|
||||
}
|
||||
m += mb;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename kv_cache_t>
|
||||
class TileGemmPPC64 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
gemm_macro_ppc64le_Mx8_Ku4<BLOCK_SIZE_ALIGNMENT, kv_cache_t>(
|
||||
a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
|
||||
} else {
|
||||
gemm_macro_ppc64le_Mx8_Ku4<HEAD_SIZE_ALIGNMENT, kv_cache_t>(
|
||||
a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
|
||||
accum_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VSX, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = float;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment = BLOCK_SIZE_ALIGNMENT;
|
||||
constexpr static int64_t HeadDimAlignment = HEAD_SIZE_ALIGNMENT;
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::VSX;
|
||||
constexpr static bool scale_on_logits =
|
||||
false; // Scale is applied to Q during copy
|
||||
|
||||
public:
|
||||
AttentionImpl() {}
|
||||
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
attention<TileGemmPPC64<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
|
||||
// Strides for Memory Layout
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment; // [head_dim, block_size] layout
|
||||
}
|
||||
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return head_dim * BlockSizeAlignment;
|
||||
}
|
||||
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return HeadDimAlignment;
|
||||
}
|
||||
|
||||
static void copy_q_heads_tile(scalar_t* __restrict__ src,
|
||||
float* __restrict__ q_buffer,
|
||||
const int32_t q_num,
|
||||
const int32_t q_heads_per_kv,
|
||||
const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, float scale) {
|
||||
__vector float scale_vec = vec_splats(scale);
|
||||
constexpr bool is_bf16 = std::is_same<scalar_t, c10::BFloat16>::value;
|
||||
|
||||
for (int32_t i = 0; i < q_num; ++i) {
|
||||
for (int32_t h = 0; h < q_heads_per_kv; ++h) {
|
||||
scalar_t* curr_src = src + i * q_num_stride + h * q_head_stride;
|
||||
float* curr_dst =
|
||||
q_buffer + i * q_heads_per_kv * head_dim + h * head_dim;
|
||||
|
||||
int32_t d = 0;
|
||||
for (; d <= head_dim - 8; d += 8) {
|
||||
__vector float v0, v1;
|
||||
load_row8_B_as_f32<scalar_t>(curr_src + d, v0, v1);
|
||||
|
||||
v0 = vec_mul(v0, scale_vec);
|
||||
v1 = vec_mul(v1, scale_vec);
|
||||
|
||||
vec_xst(v0, 0, curr_dst + d);
|
||||
vec_xst(v1, 0, curr_dst + d + 4);
|
||||
}
|
||||
|
||||
for (; d < head_dim; ++d) {
|
||||
float val = static_cast<float>(curr_src[d]);
|
||||
curr_dst[d] = val * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride,
|
||||
const float k_inv = 0.0f, const float v_inv = 0.0f) {
|
||||
// k_inv and v_inv are unused on VSX: FP8 KV cache is not supported on
|
||||
// PowerPC. The parameters are present to match the common interface.
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
const int64_t pos = slot_mapping[token_idx];
|
||||
if (pos < 0) continue;
|
||||
|
||||
const int64_t block_idx = pos / block_size;
|
||||
const int64_t block_offset = pos % block_size;
|
||||
|
||||
{
|
||||
const scalar_t* key_src = key + token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride;
|
||||
scalar_t* key_dst = key_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset;
|
||||
|
||||
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
|
||||
key_dst[j] = key_src[i];
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const scalar_t* val_src = value + token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* val_dst = value_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride +
|
||||
block_offset * head_dim;
|
||||
|
||||
std::memcpy(val_dst, val_src, sizeof(scalar_t) * head_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cpu_attention
|
||||
|
||||
#undef BLOCK_SIZE_ALIGNMENT
|
||||
#undef HEAD_SIZE_ALIGNMENT
|
||||
#undef MAX_Q_HEAD_NUM_PER_ITER
|
||||
|
||||
#endif // CPU_ATTN_VSX_HPP
|
||||
@@ -9,6 +9,10 @@
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
// FP8 tag types for tag dispatch (see cpu_attn_vec.hpp)
|
||||
struct fp8_e4m3_tag {};
|
||||
struct fp8_e5m2_tag {};
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
|
||||
@@ -20,6 +20,7 @@ ISA_TYPES = {
|
||||
"VEC16": 2,
|
||||
"NEON": 3,
|
||||
"VXE": 4,
|
||||
"VSX": 5,
|
||||
}
|
||||
|
||||
# KV cache index: 0 = auto (same as scalar_t), 1 = fp8_e4m3, 2 = fp8_e5m2
|
||||
@@ -37,7 +38,7 @@ KV_CACHE_CPP_TYPES = {
|
||||
}
|
||||
|
||||
# ISAs supported for head_dims divisible by 32
|
||||
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16", "VXE"]
|
||||
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16", "VXE", "VSX"]
|
||||
|
||||
# ISAs supported for head_dims divisible by 16 only
|
||||
ISA_FOR_16 = ["VEC16"]
|
||||
@@ -148,6 +149,10 @@ def generate_header_file() -> str:
|
||||
#include "cpu_attn_vxe.hpp"
|
||||
#endif
|
||||
|
||||
#ifdef __powerpc__
|
||||
#include "cpu_attn_vsx.hpp"
|
||||
#endif
|
||||
|
||||
"""
|
||||
|
||||
header += generate_helper_function()
|
||||
@@ -207,6 +212,11 @@ def generate_header_file() -> str:
|
||||
["VXE", "VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#elif defined(__powerpc__)",
|
||||
["VSX", "VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#elif defined(__AVX512F__)",
|
||||
["VEC", "VEC16"],
|
||||
@@ -223,7 +233,8 @@ def generate_header_file() -> str:
|
||||
fp8=False,
|
||||
)
|
||||
header += (
|
||||
"#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */\n\n"
|
||||
"#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / "
|
||||
"__s390x__ / __powerpc__ */\n\n"
|
||||
"#endif // CPU_ATTN_DISPATCH_GENERATED_H\n"
|
||||
)
|
||||
|
||||
|
||||
+1
-1
@@ -54,7 +54,7 @@ struct Counter {
|
||||
};
|
||||
|
||||
inline int64_t get_available_l2_size() {
|
||||
#if defined(__s390x__)
|
||||
#if defined(__s390x__) || defined(__powerpc__)
|
||||
static int64_t size = []() {
|
||||
uint32_t l2_cache_size = 0;
|
||||
auto caps = at::cpu::get_cpu_capabilities();
|
||||
|
||||
@@ -29,7 +29,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S390X)
|
||||
_CPU_ARCH_PREFER_MIXED_BATCH = (
|
||||
CpuArchEnum.X86,
|
||||
CpuArchEnum.ARM,
|
||||
CpuArchEnum.S390X,
|
||||
CpuArchEnum.POWERPC,
|
||||
)
|
||||
|
||||
|
||||
class CPUAttentionBackend(AttentionBackend):
|
||||
@@ -510,8 +515,10 @@ def _get_attn_isa(
|
||||
)
|
||||
return "vec16"
|
||||
supports_amx = torch.cpu._is_amx_tile_supported()
|
||||
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||
supports_vxe = current_platform.get_cpu_architecture() == CpuArchEnum.S390X
|
||||
arch = current_platform.get_cpu_architecture()
|
||||
supports_arm = arch == CpuArchEnum.ARM
|
||||
supports_vxe = arch == CpuArchEnum.S390X
|
||||
supports_vsx = arch == CpuArchEnum.POWERPC
|
||||
supports_avx512 = torch.cpu._is_avx512_supported()
|
||||
if fp8_kv and not supports_amx and not supports_avx512:
|
||||
raise NotImplementedError(
|
||||
@@ -525,6 +532,8 @@ def _get_attn_isa(
|
||||
return "neon"
|
||||
elif supports_vxe:
|
||||
return "vxe"
|
||||
elif supports_vsx:
|
||||
return "vsx"
|
||||
else:
|
||||
return "vec"
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user