mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[CPU][RISC-V] Add VLEN=256 support to RVV attention kernels (#42943)
Signed-off-by: velonica0 <like@mail.nankai.edu.cn> Signed-off-by: velonica0 <47554626+velonica0@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com>
This commit is contained in:
+68
-101
@@ -4,17 +4,18 @@
|
||||
#ifndef CPU_ATTN_RVV_HPP
|
||||
#define CPU_ATTN_RVV_HPP
|
||||
|
||||
// This kernel is currently hardcoded to VLEN=128 (m1/m2 intrinsics, vl=8).
|
||||
// The fixed-width typedefs below use `riscv_rvv_vector_bits(128)`, which
|
||||
// only matches `vfloat16m1_t`/`vuint16m1_t` register layout when VLEN==128;
|
||||
// at VLEN>=256 those typedefs fail to compile. Scalar RISC-V builds
|
||||
// (-march=rv64gc) additionally don't have <riscv_vector.h>. For both
|
||||
// cases we omit the file entirely and let the dispatcher fall back to the
|
||||
// scalar VEC / VEC16 implementations. TODO: migrate to RVVI() macros +
|
||||
// semantic names in cpu_types_riscv_defs.hpp to support VLEN>=256 natively.
|
||||
#if defined(__riscv_v_min_vlen) && __riscv_v_min_vlen == 128
|
||||
// RVV attention kernel using VLEN-agnostic RVVI() macros from
|
||||
// cpu_types_riscv_defs.hpp. The Mx8 tile GEMM uses 8 FP32 elements
|
||||
// per vector (256 bits of FP32 data, selected via LMUL_256), which maps to:
|
||||
// VLEN=128: m2 (256 bits = 8 x FP32)
|
||||
// VLEN=256: m1 (256 bits = 8 x FP32)
|
||||
// Only VLEN=128 and VLEN=256 are supported; other VLENs (512, 1024)
|
||||
// and scalar RISC-V builds fall back to VEC/VEC16.
|
||||
#if defined(__riscv_v_min_vlen) && \
|
||||
(__riscv_v_min_vlen == 128 || __riscv_v_min_vlen == 256)
|
||||
|
||||
#include "cpu_attn_impl.hpp"
|
||||
#include "cpu_types_riscv_defs.hpp"
|
||||
#include <riscv_vector.h>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -22,73 +23,50 @@ namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
|
||||
// File-local concrete-LMUL typedefs. The shared _defs.hpp exposes
|
||||
// VLEN-independent semantic names (fixed_fp32x8_t, fixed_fp16x8_t, ...),
|
||||
// but this kernel is currently hardcoded to VLEN=128 (m1/m2 intrinsics),
|
||||
// so keep the legacy concrete aliases scoped to this file.
|
||||
typedef vfloat16m1_t fixed_vfloat16m1_t
|
||||
__attribute__((riscv_rvv_vector_bits(128)));
|
||||
typedef vfloat32m2_t fixed_vfloat32m2_t
|
||||
__attribute__((riscv_rvv_vector_bits(256)));
|
||||
typedef vuint16m1_t fixed_vuint16m1_t
|
||||
__attribute__((riscv_rvv_vector_bits(128)));
|
||||
typedef vuint32m2_t fixed_vuint32m2_t
|
||||
__attribute__((riscv_rvv_vector_bits(256)));
|
||||
#ifdef __riscv_zvfbfmin
|
||||
typedef vbfloat16m1_t fixed_vbfloat16m1_t
|
||||
__attribute__((riscv_rvv_vector_bits(128)));
|
||||
#endif
|
||||
|
||||
#define BLOCK_SIZE_ALIGNMENT 32
|
||||
#define HEAD_SIZE_ALIGNMENT 32
|
||||
#define MAX_Q_HEAD_NUM_PER_ITER 16
|
||||
|
||||
// ============================================================================
|
||||
// B-matrix row loading: load 8 elements as FP32 (using m2 LMUL at VLEN=128)
|
||||
// B-matrix row loading: load 8 elements as FP32
|
||||
// ============================================================================
|
||||
|
||||
template <typename kv_cache_t>
|
||||
FORCE_INLINE fixed_vfloat32m2_t load_row8_B_as_f32(const kv_cache_t* p);
|
||||
FORCE_INLINE fixed_fp32x8_t load_row8_B_as_f32(const kv_cache_t* p);
|
||||
|
||||
template <>
|
||||
FORCE_INLINE fixed_vfloat32m2_t load_row8_B_as_f32<float>(const float* p) {
|
||||
return __riscv_vle32_v_f32m2(p, 8);
|
||||
FORCE_INLINE fixed_fp32x8_t load_row8_B_as_f32<float>(const float* p) {
|
||||
return RVVI(__riscv_vle32_v_f32, LMUL_256)(p, 8);
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCE_INLINE fixed_vfloat32m2_t
|
||||
load_row8_B_as_f32<c10::Half>(const c10::Half* p) {
|
||||
FORCE_INLINE fixed_fp32x8_t load_row8_B_as_f32<c10::Half>(const c10::Half* p) {
|
||||
#ifdef __riscv_zvfh
|
||||
fixed_vfloat16m1_t h =
|
||||
__riscv_vle16_v_f16m1(reinterpret_cast<const _Float16*>(p), 8);
|
||||
return __riscv_vfwcvt_f_f_v_f32m2(h, 8);
|
||||
fixed_fp16x8_t h = RVVI(__riscv_vle16_v_f16, LMUL_128)(
|
||||
reinterpret_cast<const _Float16*>(p), 8);
|
||||
return RVVI(__riscv_vfwcvt_f_f_v_f32, LMUL_256)(h, 8);
|
||||
#else
|
||||
// Fallback for hardware without Zvfh: scalar half->float conversion.
|
||||
// c10::Half provides operator float() so this is correct on any RVV CPU
|
||||
// that has only the base V extension. Slower than the Zvfh path, but
|
||||
// keeps the kernel buildable on Zvfhmin-only / no-fp16 hardware.
|
||||
alignas(16) float tmp[8];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
tmp[i] = static_cast<float>(p[i]);
|
||||
}
|
||||
return __riscv_vle32_v_f32m2(tmp, 8);
|
||||
return RVVI(__riscv_vle32_v_f32, LMUL_256)(tmp, 8);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCE_INLINE fixed_vfloat32m2_t
|
||||
FORCE_INLINE fixed_fp32x8_t
|
||||
load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p) {
|
||||
#ifdef __riscv_zvfbfmin
|
||||
fixed_vbfloat16m1_t bf =
|
||||
__riscv_vle16_v_bf16m1(reinterpret_cast<const __bf16*>(p), 8);
|
||||
return __riscv_vfwcvtbf16_f_f_v_f32m2(bf, 8);
|
||||
fixed_bf16x8_t bf = RVVI(__riscv_vle16_v_bf16, LMUL_128)(
|
||||
reinterpret_cast<const __bf16*>(p), 8);
|
||||
return RVVI(__riscv_vfwcvtbf16_f_f_v_f32, LMUL_256)(bf, 8);
|
||||
#else
|
||||
// Fallback: load as uint16, zero-extend to uint32, shift left by 16
|
||||
fixed_vuint16m1_t raw =
|
||||
__riscv_vle16_v_u16m1(reinterpret_cast<const uint16_t*>(p), 8);
|
||||
fixed_vuint32m2_t wide = __riscv_vzext_vf2_u32m2(raw, 8);
|
||||
fixed_vuint32m2_t shifted = __riscv_vsll_vx_u32m2(wide, 16, 8);
|
||||
return __riscv_vreinterpret_v_u32m2_f32m2(shifted);
|
||||
fixed_u16x8_t raw = RVVI(__riscv_vle16_v_u16, LMUL_128)(
|
||||
reinterpret_cast<const uint16_t*>(p), 8);
|
||||
fixed_u32x8_t wide = RVVI(__riscv_vzext_vf2_u32, LMUL_256)(raw, 8);
|
||||
fixed_u32x8_t shifted = RVVI(__riscv_vsll_vx_u32, LMUL_256)(wide, 16, 8);
|
||||
return RVVI4(__riscv_vreinterpret_v_u32, LMUL_256, _f32, LMUL_256)(shifted);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -96,14 +74,12 @@ load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p) {
|
||||
// Micro kernel: Mx8 tile, K unrolled by 4, RVV scalar-broadcast FMA
|
||||
// ============================================================================
|
||||
//
|
||||
// NEON uses vfmaq_laneq_f32 (lane-indexed FMA from a preloaded A vector).
|
||||
// RVV has no lane-indexed FMA; instead we load A elements as scalars and
|
||||
// use __riscv_vfmacc_vf (scalar * vector + accumulator), which is equally
|
||||
// efficient and avoids the need for vrgather/vslidedown.
|
||||
// use vfmacc_vf (scalar * vector + accumulator).
|
||||
//
|
||||
// At VLEN=128, m2 holds 8 x FP32, matching the 8-column tile width.
|
||||
// Register budget: M accumulators (m2 each) + 1 B temp = 2M+2 regs.
|
||||
// M=8 => 18 regs out of 32 available — no spills.
|
||||
// The 8-column tile uses 256 bits of FP32 data (selected via LMUL_256):
|
||||
// VLEN=128: m2 (2 regs per accumulator), M=8 => 18 of 32 regs
|
||||
// VLEN=256: m1 (1 reg per accumulator), M=8 => 9 of 32 regs
|
||||
|
||||
template <int32_t M, typename kv_cache_t>
|
||||
FORCE_INLINE void gemm_micro_rvv_fma_Mx8_Ku4(
|
||||
@@ -115,94 +91,90 @@ FORCE_INLINE void gemm_micro_rvv_fma_Mx8_Ku4(
|
||||
|
||||
constexpr size_t vl = 8;
|
||||
|
||||
// helpers for per-M codegen
|
||||
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
|
||||
#define IF_M(i) if constexpr (M > (i))
|
||||
|
||||
// A row base pointers
|
||||
#define DECL_A(i) const float* a##i = A + (i) * lda;
|
||||
ROWS_APPLY(DECL_A)
|
||||
#undef DECL_A
|
||||
|
||||
// declare one 8 x FP32 accumulator per row (m2 at VLEN=128, m1 at VLEN=256)
|
||||
#define DECL_ACC(i) fixed_vfloat32m2_t acc##i;
|
||||
#define DECL_ACC(i) fixed_fp32x8_t acc##i;
|
||||
ROWS_APPLY(DECL_ACC)
|
||||
#undef DECL_ACC
|
||||
|
||||
// initialize accumulators
|
||||
#define INIT_ACC(i) \
|
||||
IF_M(i) { \
|
||||
if (accumulate) { \
|
||||
acc##i = __riscv_vle32_v_f32m2(C + (i) * ldc, vl); \
|
||||
} else { \
|
||||
acc##i = __riscv_vfmv_v_f_f32m2(0.f, vl); \
|
||||
} \
|
||||
#define INIT_ACC(i) \
|
||||
IF_M(i) { \
|
||||
if (accumulate) { \
|
||||
acc##i = RVVI(__riscv_vle32_v_f32, LMUL_256)(C + (i) * ldc, vl); \
|
||||
} else { \
|
||||
acc##i = RVVI(__riscv_vfmv_v_f_f32, LMUL_256)(0.f, vl); \
|
||||
} \
|
||||
}
|
||||
ROWS_APPLY(INIT_ACC)
|
||||
#undef INIT_ACC
|
||||
|
||||
int32_t k = 0;
|
||||
|
||||
// K unrolled by 4
|
||||
for (; k + 3 < K; k += 4) {
|
||||
// k + 0
|
||||
{
|
||||
fixed_vfloat32m2_t b =
|
||||
fixed_fp32x8_t b =
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb);
|
||||
#define STEP_K0(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 0), b, vl); \
|
||||
#define STEP_K0(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 0), \
|
||||
b, vl); \
|
||||
}
|
||||
ROWS_APPLY(STEP_K0)
|
||||
#undef STEP_K0
|
||||
}
|
||||
// k + 1
|
||||
{
|
||||
fixed_vfloat32m2_t b =
|
||||
fixed_fp32x8_t b =
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb);
|
||||
#define STEP_K1(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 1), b, vl); \
|
||||
#define STEP_K1(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 1), \
|
||||
b, vl); \
|
||||
}
|
||||
ROWS_APPLY(STEP_K1)
|
||||
#undef STEP_K1
|
||||
}
|
||||
// k + 2
|
||||
{
|
||||
fixed_vfloat32m2_t b =
|
||||
fixed_fp32x8_t b =
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb);
|
||||
#define STEP_K2(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 2), b, vl); \
|
||||
#define STEP_K2(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 2), \
|
||||
b, vl); \
|
||||
}
|
||||
ROWS_APPLY(STEP_K2)
|
||||
#undef STEP_K2
|
||||
}
|
||||
// k + 3
|
||||
{
|
||||
fixed_vfloat32m2_t b =
|
||||
fixed_fp32x8_t b =
|
||||
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb);
|
||||
#define STEP_K3(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 3), b, vl); \
|
||||
#define STEP_K3(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 3), \
|
||||
b, vl); \
|
||||
}
|
||||
ROWS_APPLY(STEP_K3)
|
||||
#undef STEP_K3
|
||||
}
|
||||
}
|
||||
|
||||
// K tail
|
||||
for (; k < K; ++k) {
|
||||
fixed_vfloat32m2_t b = load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb);
|
||||
#define TAIL_ROW(i) \
|
||||
IF_M(i) { acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k), b, vl); }
|
||||
fixed_fp32x8_t b = load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb);
|
||||
#define TAIL_ROW(i) \
|
||||
IF_M(i) { \
|
||||
acc##i = \
|
||||
RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k), b, vl); \
|
||||
}
|
||||
ROWS_APPLY(TAIL_ROW)
|
||||
#undef TAIL_ROW
|
||||
}
|
||||
|
||||
// store accumulators to C
|
||||
#define STORE_ROW(i) \
|
||||
IF_M(i) { __riscv_vse32_v_f32m2(C + (i) * ldc, acc##i, vl); }
|
||||
IF_M(i) { RVVI(__riscv_vse32_v_f32, LMUL_256)(C + (i) * ldc, acc##i, vl); }
|
||||
ROWS_APPLY(STORE_ROW)
|
||||
#undef STORE_ROW
|
||||
|
||||
@@ -381,7 +353,6 @@ class AttentionImpl<ISA::RVV, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
const int64_t block_idx = pos / block_size;
|
||||
const int64_t block_offset = pos % block_size;
|
||||
{
|
||||
// Write Key (transpose to column-major: [head_dim, block_size])
|
||||
const scalar_t* key_start_ptr = key +
|
||||
token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride;
|
||||
@@ -389,8 +360,6 @@ class AttentionImpl<ISA::RVV, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
key_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset;
|
||||
|
||||
// Strided vector store for efficient transpose.
|
||||
// Load contiguous key elements, store with stride = block_size.
|
||||
{
|
||||
const ptrdiff_t byte_stride = block_size * sizeof(scalar_t);
|
||||
int64_t i = 0;
|
||||
@@ -405,7 +374,6 @@ class AttentionImpl<ISA::RVV, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
i * block_size),
|
||||
byte_stride, v, vl);
|
||||
} else {
|
||||
// Half and BFloat16 are both 16-bit types
|
||||
vl = __riscv_vsetvl_e16m1(head_dim - i);
|
||||
vuint16m1_t v = __riscv_vle16_v_u16m1(
|
||||
reinterpret_cast<const uint16_t*>(key_start_ptr + i), vl);
|
||||
@@ -419,7 +387,6 @@ class AttentionImpl<ISA::RVV, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
}
|
||||
}
|
||||
{
|
||||
// Write Value (row-major: [block_size, head_dim])
|
||||
const scalar_t* value_start_ptr = value +
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
@@ -440,6 +407,6 @@ class AttentionImpl<ISA::RVV, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
#undef HEAD_SIZE_ALIGNMENT
|
||||
#undef MAX_Q_HEAD_NUM_PER_ITER
|
||||
|
||||
#endif // __riscv_v_min_vlen == 128
|
||||
#endif // __riscv_v_min_vlen == 128 || 256
|
||||
|
||||
#endif // CPU_ATTN_RVV_HPP
|
||||
|
||||
@@ -71,6 +71,10 @@ typedef RVVTYPE(vuint16, LMUL_256, _t) fixed_u16x16_t
|
||||
typedef RVVTYPE(vuint16, LMUL_512, _t) fixed_u16x32_t
|
||||
__attribute__((riscv_rvv_vector_bits(512)));
|
||||
|
||||
// uint32
|
||||
typedef RVVTYPE(vuint32, LMUL_256, _t) fixed_u32x8_t
|
||||
__attribute__((riscv_rvv_vector_bits(256)));
|
||||
|
||||
// bfloat16
|
||||
#ifdef __riscv_zvfbfmin
|
||||
typedef RVVTYPE(vbfloat16, LMUL_128, _t) fixed_bf16x8_t
|
||||
|
||||
@@ -150,12 +150,10 @@ def generate_header_file() -> str:
|
||||
#include "cpu_attn_vxe.hpp"
|
||||
#endif
|
||||
|
||||
// cpu_attn_rvv.hpp is hardcoded to VLEN==128 (m1/m2 intrinsics, vl=8) and
|
||||
// itself includes <riscv_vector.h>, which is unavailable on scalar
|
||||
// (-march=rv64gc) builds. Gate the include the same way as the dispatch
|
||||
// macro below, so non-128 / scalar RISC-V builds skip it entirely.
|
||||
// cpu_attn_rvv.hpp supports VLEN=128 and VLEN=256 via RVVI() macros.
|
||||
// Other VLENs and scalar RISC-V builds skip it entirely.
|
||||
#if defined(__riscv) && defined(__riscv_v_min_vlen) && \
|
||||
__riscv_v_min_vlen == 128
|
||||
(__riscv_v_min_vlen == 128 || __riscv_v_min_vlen == 256)
|
||||
#include "cpu_attn_rvv.hpp"
|
||||
#endif
|
||||
|
||||
@@ -222,15 +220,12 @@ def generate_header_file() -> str:
|
||||
["VXE", "VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
# RISC-V with RVV. cpu_attn_rvv.hpp is hardcoded to VLEN==128
|
||||
# (riscv_rvv_vector_bits(128) typedefs + vl=8 m1/m2 intrinsics), so
|
||||
# we split the dispatch into two top-level branches: VLEN==128 builds
|
||||
# get the full RVV+VEC+VEC16 case set, other VLEN builds get a
|
||||
# VEC/VEC16-only fallback. Preprocessor directives cannot appear
|
||||
# inside a #define body, so this duplication is necessary.
|
||||
# RISC-V with RVV. cpu_attn_rvv.hpp supports VLEN=128 and VLEN=256
|
||||
# via RVVI() macros. Builds with a supported VLEN get
|
||||
# RVV+VEC+VEC16; other RISC-V builds fall back to VEC/VEC16 only.
|
||||
header += _macro_block(
|
||||
"#elif defined(__riscv) && defined(__riscv_v_min_vlen) "
|
||||
"&& __riscv_v_min_vlen == 128",
|
||||
"&& (__riscv_v_min_vlen == 128 || __riscv_v_min_vlen == 256)",
|
||||
["RVV", "VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
|
||||
@@ -514,26 +514,23 @@ def _make_sliding_window_bias(
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _riscv_supports_rvv_vlen128() -> bool:
|
||||
"""Whether the C++ RVV attention path (hardcoded to VLEN==128) is usable.
|
||||
def _riscv_supports_rvv() -> bool:
|
||||
"""Whether the C++ RVV attention path is usable.
|
||||
|
||||
The kernel in csrc/cpu/cpu_attn_rvv.hpp uses riscv_rvv_vector_bits(128)
|
||||
typedefs and m1/m2 intrinsics with vl=8; CMake's auto-detect picks the
|
||||
largest zvl<N>b advertised by /proc/cpuinfo, so the binary contains the
|
||||
RVV path only when the build host advertised exactly zvl128b. Mirror
|
||||
that here so the Python dispatch doesn't request ISA::RVV on builds
|
||||
where it wasn't compiled in (would TORCH_CHECK at first attention call).
|
||||
The kernel in csrc/cpu/cpu_attn_rvv.hpp uses VLEN-agnostic RVVI()
|
||||
macros and supports VLEN=128 and VLEN=256. CMake auto-detects the
|
||||
largest zvl<N>b from /proc/cpuinfo and passes it via -mrvv-vector-bits.
|
||||
The RVV path is compiled only when __riscv_v_min_vlen is 128 or 256, so
|
||||
we check that zvl128b or zvl256b is advertised and no larger zvl<N>b is.
|
||||
"""
|
||||
try:
|
||||
with open("/proc/cpuinfo") as f:
|
||||
cpuinfo = f.read()
|
||||
except OSError:
|
||||
return False
|
||||
if "zvl128b" not in cpuinfo:
|
||||
return False
|
||||
# CMake auto-detect picks the largest advertised VLEN; if the host
|
||||
# advertises zvl256b or higher, the build skipped the RVV-128 path.
|
||||
return all(f"zvl{n}b" not in cpuinfo for n in (256, 512, 1024))
|
||||
return any(f"zvl{n}b" in cpuinfo for n in (128, 256)) and all(
|
||||
f"zvl{n}b" not in cpuinfo for n in (512, 1024)
|
||||
)
|
||||
|
||||
|
||||
def _get_attn_isa(
|
||||
@@ -566,7 +563,7 @@ def _get_attn_isa(
|
||||
if supports_arm:
|
||||
# support ARM NEON FMLA and BFMMLA (bf16) for block size 32
|
||||
return "neon"
|
||||
elif supports_riscv and _riscv_supports_rvv_vlen128():
|
||||
elif supports_riscv and _riscv_supports_rvv():
|
||||
return "rvv"
|
||||
elif supports_vxe:
|
||||
return "vxe"
|
||||
|
||||
Reference in New Issue
Block a user