From c68c55d43e504745dbfc2d46b552e80acb74d4b9 Mon Sep 17 00:00:00 2001 From: velonica0 <47554626+velonica0@users.noreply.github.com> Date: Thu, 21 May 2026 19:50:49 +0800 Subject: [PATCH] [CPU][RISC-V] Add VLEN=256 support to RVV attention kernels (#42943) Signed-off-by: velonica0 Signed-off-by: velonica0 <47554626+velonica0@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Li, Jiang --- csrc/cpu/cpu_attn_rvv.hpp | 169 ++++++++++--------------- csrc/cpu/cpu_types_riscv_defs.hpp | 4 + csrc/cpu/generate_cpu_attn_dispatch.py | 19 +-- vllm/v1/attention/backends/cpu_attn.py | 25 ++-- 4 files changed, 90 insertions(+), 127 deletions(-) diff --git a/csrc/cpu/cpu_attn_rvv.hpp b/csrc/cpu/cpu_attn_rvv.hpp index ec1230547eb..396cc55c59e 100644 --- a/csrc/cpu/cpu_attn_rvv.hpp +++ b/csrc/cpu/cpu_attn_rvv.hpp @@ -4,17 +4,18 @@ #ifndef CPU_ATTN_RVV_HPP #define CPU_ATTN_RVV_HPP -// This kernel is currently hardcoded to VLEN=128 (m1/m2 intrinsics, vl=8). -// The fixed-width typedefs below use `riscv_rvv_vector_bits(128)`, which -// only matches `vfloat16m1_t`/`vuint16m1_t` register layout when VLEN==128; -// at VLEN>=256 those typedefs fail to compile. Scalar RISC-V builds -// (-march=rv64gc) additionally don't have . For both -// cases we omit the file entirely and let the dispatcher fall back to the -// scalar VEC / VEC16 implementations. TODO: migrate to RVVI() macros + -// semantic names in cpu_types_riscv_defs.hpp to support VLEN>=256 natively. -#if defined(__riscv_v_min_vlen) && __riscv_v_min_vlen == 128 +// RVV attention kernel using VLEN-agnostic RVVI() macros from +// cpu_types_riscv_defs.hpp. 
The Mx8 tile GEMM uses 8 FP32 elements +// per vector (LMUL_256 bits of FP32 data), which maps to: +// VLEN=128: m2 (256 bits = 8 x FP32) +// VLEN=256: m1 (256 bits = 8 x FP32) +// Only VLEN=128 and VLEN=256 are supported; other VLENs (512, 1024) +// and scalar RISC-V builds fall back to VEC/VEC16. +#if defined(__riscv_v_min_vlen) && \ + (__riscv_v_min_vlen == 128 || __riscv_v_min_vlen == 256) #include "cpu_attn_impl.hpp" + #include "cpu_types_riscv_defs.hpp" #include #include @@ -22,73 +23,50 @@ namespace cpu_attention { namespace { -// File-local concrete-LMUL typedefs. The shared _defs.hpp exposes -// VLEN-independent semantic names (fixed_fp32x8_t, fixed_fp16x8_t, ...), -// but this kernel is currently hardcoded to VLEN=128 (m1/m2 intrinsics), -// so keep the legacy concrete aliases scoped to this file. -typedef vfloat16m1_t fixed_vfloat16m1_t - __attribute__((riscv_rvv_vector_bits(128))); -typedef vfloat32m2_t fixed_vfloat32m2_t - __attribute__((riscv_rvv_vector_bits(256))); -typedef vuint16m1_t fixed_vuint16m1_t - __attribute__((riscv_rvv_vector_bits(128))); -typedef vuint32m2_t fixed_vuint32m2_t - __attribute__((riscv_rvv_vector_bits(256))); - #ifdef __riscv_zvfbfmin -typedef vbfloat16m1_t fixed_vbfloat16m1_t - __attribute__((riscv_rvv_vector_bits(128))); - #endif - #define BLOCK_SIZE_ALIGNMENT 32 #define HEAD_SIZE_ALIGNMENT 32 #define MAX_Q_HEAD_NUM_PER_ITER 16 // ============================================================================ -// B-matrix row loading: load 8 elements as FP32 (using m2 LMUL at VLEN=128) +// B-matrix row loading: load 8 elements as FP32 // ============================================================================ template -FORCE_INLINE fixed_vfloat32m2_t load_row8_B_as_f32(const kv_cache_t* p); +FORCE_INLINE fixed_fp32x8_t load_row8_B_as_f32(const kv_cache_t* p); template <> -FORCE_INLINE fixed_vfloat32m2_t load_row8_B_as_f32(const float* p) { - return __riscv_vle32_v_f32m2(p, 8); +FORCE_INLINE fixed_fp32x8_t 
load_row8_B_as_f32(const float* p) { + return RVVI(__riscv_vle32_v_f32, LMUL_256)(p, 8); } template <> -FORCE_INLINE fixed_vfloat32m2_t -load_row8_B_as_f32(const c10::Half* p) { +FORCE_INLINE fixed_fp32x8_t load_row8_B_as_f32(const c10::Half* p) { #ifdef __riscv_zvfh - fixed_vfloat16m1_t h = - __riscv_vle16_v_f16m1(reinterpret_cast(p), 8); - return __riscv_vfwcvt_f_f_v_f32m2(h, 8); + fixed_fp16x8_t h = RVVI(__riscv_vle16_v_f16, LMUL_128)( + reinterpret_cast(p), 8); + return RVVI(__riscv_vfwcvt_f_f_v_f32, LMUL_256)(h, 8); #else - // Fallback for hardware without Zvfh: scalar half->float conversion. - // c10::Half provides operator float() so this is correct on any RVV CPU - // that has only the base V extension. Slower than the Zvfh path, but - // keeps the kernel buildable on Zvfhmin-only / no-fp16 hardware. alignas(16) float tmp[8]; for (int i = 0; i < 8; ++i) { tmp[i] = static_cast(p[i]); } - return __riscv_vle32_v_f32m2(tmp, 8); + return RVVI(__riscv_vle32_v_f32, LMUL_256)(tmp, 8); #endif } template <> -FORCE_INLINE fixed_vfloat32m2_t +FORCE_INLINE fixed_fp32x8_t load_row8_B_as_f32(const c10::BFloat16* p) { #ifdef __riscv_zvfbfmin - fixed_vbfloat16m1_t bf = - __riscv_vle16_v_bf16m1(reinterpret_cast(p), 8); - return __riscv_vfwcvtbf16_f_f_v_f32m2(bf, 8); + fixed_bf16x8_t bf = RVVI(__riscv_vle16_v_bf16, LMUL_128)( + reinterpret_cast(p), 8); + return RVVI(__riscv_vfwcvtbf16_f_f_v_f32, LMUL_256)(bf, 8); #else - // Fallback: load as uint16, zero-extend to uint32, shift left by 16 - fixed_vuint16m1_t raw = - __riscv_vle16_v_u16m1(reinterpret_cast(p), 8); - fixed_vuint32m2_t wide = __riscv_vzext_vf2_u32m2(raw, 8); - fixed_vuint32m2_t shifted = __riscv_vsll_vx_u32m2(wide, 16, 8); - return __riscv_vreinterpret_v_u32m2_f32m2(shifted); + fixed_u16x8_t raw = RVVI(__riscv_vle16_v_u16, LMUL_128)( + reinterpret_cast(p), 8); + fixed_u32x8_t wide = RVVI(__riscv_vzext_vf2_u32, LMUL_256)(raw, 8); + fixed_u32x8_t shifted = RVVI(__riscv_vsll_vx_u32, LMUL_256)(wide, 16, 8); + return 
RVVI4(__riscv_vreinterpret_v_u32, LMUL_256, _f32, LMUL_256)(shifted); #endif } @@ -96,14 +74,12 @@ load_row8_B_as_f32(const c10::BFloat16* p) { // Micro kernel: Mx8 tile, K unrolled by 4, RVV scalar-broadcast FMA // ============================================================================ // -// NEON uses vfmaq_laneq_f32 (lane-indexed FMA from a preloaded A vector). // RVV has no lane-indexed FMA; instead we load A elements as scalars and -// use __riscv_vfmacc_vf (scalar * vector + accumulator), which is equally -// efficient and avoids the need for vrgather/vslidedown. +// use vfmacc_vf (scalar * vector + accumulator). // -// At VLEN=128, m2 holds 8 x FP32, matching the 8-column tile width. -// Register budget: M accumulators (m2 each) + 1 B temp = 2M+2 regs. -// M=8 => 18 regs out of 32 available — no spills. +// The 8-column tile uses LMUL_256 bits of FP32 data: +// VLEN=128: m2 (2 regs per accumulator), M=8 => 18 of 32 regs +// VLEN=256: m1 (1 reg per accumulator), M=8 => 9 of 32 regs template FORCE_INLINE void gemm_micro_rvv_fma_Mx8_Ku4( @@ -115,94 +91,90 @@ FORCE_INLINE void gemm_micro_rvv_fma_Mx8_Ku4( constexpr size_t vl = 8; - // helpers for per-M codegen #define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7) #define IF_M(i) if constexpr (M > (i)) - // A row base pointers #define DECL_A(i) const float* a##i = A + (i) * lda; ROWS_APPLY(DECL_A) #undef DECL_A - // declare one m2 accumulator per row - #define DECL_ACC(i) fixed_vfloat32m2_t acc##i; + #define DECL_ACC(i) fixed_fp32x8_t acc##i; ROWS_APPLY(DECL_ACC) #undef DECL_ACC - // initialize accumulators - #define INIT_ACC(i) \ - IF_M(i) { \ - if (accumulate) { \ - acc##i = __riscv_vle32_v_f32m2(C + (i) * ldc, vl); \ - } else { \ - acc##i = __riscv_vfmv_v_f_f32m2(0.f, vl); \ - } \ + #define INIT_ACC(i) \ + IF_M(i) { \ + if (accumulate) { \ + acc##i = RVVI(__riscv_vle32_v_f32, LMUL_256)(C + (i) * ldc, vl); \ + } else { \ + acc##i = RVVI(__riscv_vfmv_v_f_f32, LMUL_256)(0.f, vl); \ + } \ } 
ROWS_APPLY(INIT_ACC) #undef INIT_ACC int32_t k = 0; - // K unrolled by 4 for (; k + 3 < K; k += 4) { - // k + 0 { - fixed_vfloat32m2_t b = + fixed_fp32x8_t b = load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb); - #define STEP_K0(i) \ - IF_M(i) { \ - acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 0), b, vl); \ + #define STEP_K0(i) \ + IF_M(i) { \ + acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 0), \ + b, vl); \ } ROWS_APPLY(STEP_K0) #undef STEP_K0 } - // k + 1 { - fixed_vfloat32m2_t b = + fixed_fp32x8_t b = load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb); - #define STEP_K1(i) \ - IF_M(i) { \ - acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 1), b, vl); \ + #define STEP_K1(i) \ + IF_M(i) { \ + acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 1), \ + b, vl); \ } ROWS_APPLY(STEP_K1) #undef STEP_K1 } - // k + 2 { - fixed_vfloat32m2_t b = + fixed_fp32x8_t b = load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb); - #define STEP_K2(i) \ - IF_M(i) { \ - acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 2), b, vl); \ + #define STEP_K2(i) \ + IF_M(i) { \ + acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 2), \ + b, vl); \ } ROWS_APPLY(STEP_K2) #undef STEP_K2 } - // k + 3 { - fixed_vfloat32m2_t b = + fixed_fp32x8_t b = load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb); - #define STEP_K3(i) \ - IF_M(i) { \ - acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k + 3), b, vl); \ + #define STEP_K3(i) \ + IF_M(i) { \ + acc##i = RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k + 3), \ + b, vl); \ } ROWS_APPLY(STEP_K3) #undef STEP_K3 } } - // K tail for (; k < K; ++k) { - fixed_vfloat32m2_t b = load_row8_B_as_f32(B + (int64_t)k * ldb); - #define TAIL_ROW(i) \ - IF_M(i) { acc##i = __riscv_vfmacc_vf_f32m2(acc##i, *(a##i + k), b, vl); } + fixed_fp32x8_t b = load_row8_B_as_f32(B + (int64_t)k * ldb); + #define TAIL_ROW(i) \ + IF_M(i) { \ + acc##i = \ + RVVI(__riscv_vfmacc_vf_f32, LMUL_256)(acc##i, *(a##i + k), b, vl); \ 
+ } ROWS_APPLY(TAIL_ROW) #undef TAIL_ROW } - // store accumulators to C #define STORE_ROW(i) \ - IF_M(i) { __riscv_vse32_v_f32m2(C + (i) * ldc, acc##i, vl); } + IF_M(i) { RVVI(__riscv_vse32_v_f32, LMUL_256)(C + (i) * ldc, acc##i, vl); } ROWS_APPLY(STORE_ROW) #undef STORE_ROW @@ -381,7 +353,6 @@ class AttentionImpl { const int64_t block_idx = pos / block_size; const int64_t block_offset = pos % block_size; { - // Write Key (transpose to column-major: [head_dim, block_size]) const scalar_t* key_start_ptr = key + token_idx * key_token_num_stride + head_idx * key_head_num_stride; @@ -389,8 +360,6 @@ class AttentionImpl { key_cache + block_idx * num_blocks_stride + head_idx * cache_head_num_stride + block_offset; - // Strided vector store for efficient transpose. - // Load contiguous key elements, store with stride = block_size. { const ptrdiff_t byte_stride = block_size * sizeof(scalar_t); int64_t i = 0; @@ -405,7 +374,6 @@ class AttentionImpl { i * block_size), byte_stride, v, vl); } else { - // Half and BFloat16 are both 16-bit types vl = __riscv_vsetvl_e16m1(head_dim - i); vuint16m1_t v = __riscv_vle16_v_u16m1( reinterpret_cast(key_start_ptr + i), vl); @@ -419,7 +387,6 @@ class AttentionImpl { } } { - // Write Value (row-major: [block_size, head_dim]) const scalar_t* value_start_ptr = value + token_idx * value_token_num_stride + head_idx * value_head_num_stride; @@ -440,6 +407,6 @@ class AttentionImpl { #undef HEAD_SIZE_ALIGNMENT #undef MAX_Q_HEAD_NUM_PER_ITER -#endif // __riscv_v_min_vlen == 128 +#endif // __riscv_v_min_vlen == 128 || 256 #endif // CPU_ATTN_RVV_HPP diff --git a/csrc/cpu/cpu_types_riscv_defs.hpp b/csrc/cpu/cpu_types_riscv_defs.hpp index c3e4f3af843..8871617f05f 100644 --- a/csrc/cpu/cpu_types_riscv_defs.hpp +++ b/csrc/cpu/cpu_types_riscv_defs.hpp @@ -71,6 +71,10 @@ typedef RVVTYPE(vuint16, LMUL_256, _t) fixed_u16x16_t typedef RVVTYPE(vuint16, LMUL_512, _t) fixed_u16x32_t __attribute__((riscv_rvv_vector_bits(512))); +// uint32 +typedef 
RVVTYPE(vuint32, LMUL_256, _t) fixed_u32x8_t + __attribute__((riscv_rvv_vector_bits(256))); + // bfloat16 #ifdef __riscv_zvfbfmin typedef RVVTYPE(vbfloat16, LMUL_128, _t) fixed_bf16x8_t diff --git a/csrc/cpu/generate_cpu_attn_dispatch.py b/csrc/cpu/generate_cpu_attn_dispatch.py index 8d9fcf5c755..7c7123a6def 100644 --- a/csrc/cpu/generate_cpu_attn_dispatch.py +++ b/csrc/cpu/generate_cpu_attn_dispatch.py @@ -150,12 +150,10 @@ def generate_header_file() -> str: #include "cpu_attn_vxe.hpp" #endif -// cpu_attn_rvv.hpp is hardcoded to VLEN==128 (m1/m2 intrinsics, vl=8) and -// itself includes , which is unavailable on scalar -// (-march=rv64gc) builds. Gate the include the same way as the dispatch -// macro below, so non-128 / scalar RISC-V builds skip it entirely. +// cpu_attn_rvv.hpp supports VLEN=128 and VLEN=256 via RVVI() macros. +// Other VLENs and scalar RISC-V builds skip it entirely. #if defined(__riscv) && defined(__riscv_v_min_vlen) && \ - __riscv_v_min_vlen == 128 + (__riscv_v_min_vlen == 128 || __riscv_v_min_vlen == 256) #include "cpu_attn_rvv.hpp" #endif @@ -222,15 +220,12 @@ def generate_header_file() -> str: ["VXE", "VEC", "VEC16"], fp8=False, ) - # RISC-V with RVV. cpu_attn_rvv.hpp is hardcoded to VLEN==128 - # (riscv_rvv_vector_bits(128) typedefs + vl=8 m1/m2 intrinsics), so - # we split the dispatch into two top-level branches: VLEN==128 builds - # get the full RVV+VEC+VEC16 case set, other VLEN builds get a - # VEC/VEC16-only fallback. Preprocessor directives cannot appear - # inside a #define body, so this duplication is necessary. + # RISC-V with RVV. cpu_attn_rvv.hpp supports VLEN=128 and VLEN=256 + # via RVVI() macros. Builds with a supported VLEN get + # RVV+VEC+VEC16; other RISC-V builds fall back to VEC/VEC16 only. 
header += _macro_block( "#elif defined(__riscv) && defined(__riscv_v_min_vlen) " - "&& __riscv_v_min_vlen == 128", + "&& (__riscv_v_min_vlen == 128 || __riscv_v_min_vlen == 256)", ["RVV", "VEC", "VEC16"], fp8=False, ) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 801a16d319e..005975c4775 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -514,26 +514,23 @@ def _make_sliding_window_bias( @functools.lru_cache(maxsize=1) -def _riscv_supports_rvv_vlen128() -> bool: - """Whether the C++ RVV attention path (hardcoded to VLEN==128) is usable. +def _riscv_supports_rvv() -> bool: + """Whether the C++ RVV attention path is usable. - The kernel in csrc/cpu/cpu_attn_rvv.hpp uses riscv_rvv_vector_bits(128) - typedefs and m1/m2 intrinsics with vl=8; CMake's auto-detect picks the - largest zvl<N>b advertised by /proc/cpuinfo, so the binary contains the - RVV path only when the build host advertised exactly zvl128b. Mirror - that here so the Python dispatch doesn't request ISA::RVV on builds - where it wasn't compiled in (would TORCH_CHECK at first attention call). + The kernel in csrc/cpu/cpu_attn_rvv.hpp uses VLEN-agnostic RVVI() + macros and supports VLEN=128 and VLEN=256. CMake auto-detects the + largest zvl<N>b from /proc/cpuinfo and passes it via -mrvv-vector-bits. + The RVV path is compiled only when __riscv_v_min_vlen is 128 or 256, so + we require zvl128b or zvl256b and reject hosts advertising zvl512b/zvl1024b. """ try: with open("/proc/cpuinfo") as f: cpuinfo = f.read() except OSError: return False - if "zvl128b" not in cpuinfo: - return False - # CMake auto-detect picks the largest advertised VLEN; if the host - # advertises zvl256b or higher, the build skipped the RVV-128 path. 
- return all(f"zvl{n}b" not in cpuinfo for n in (256, 512, 1024)) + return any(f"zvl{n}b" in cpuinfo for n in (128, 256)) and all( + f"zvl{n}b" not in cpuinfo for n in (512, 1024) + ) def _get_attn_isa( @@ -566,7 +563,7 @@ def _get_attn_isa( if supports_arm: # support ARM NEON FMLA and BFMMLA (bf16) for block size 32 return "neon" - elif supports_riscv and _riscv_supports_rvv_vlen128(): + elif supports_riscv and _riscv_supports_rvv(): return "rvv" elif supports_vxe: return "vxe"