[CPU][Perf] Enable fused kernels for GDN's gated delta rules (#43534)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
This commit is contained in:
Fadi Arafeh
2026-06-02 09:00:48 +01:00
committed by GitHub
parent dcdfe66bfa
commit 0b25cf4419
11 changed files with 824 additions and 597 deletions
+3 -1
View File
@@ -16,6 +16,7 @@ steps:
- tests/kernels/test_onednn.py
- tests/kernels/test_awq_int4_to_int8.py
- tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
- tests/kernels/mamba/cpu/test_cpu_gdn_ops.py
commands:
- |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 30m "
@@ -24,7 +25,8 @@ steps:
pytest -x -v -s tests/kernels/moe/test_cpu_quant_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py"
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
- label: CPU-Compatibility Tests
depends_on: []
@@ -37,7 +37,8 @@ function cpu_tests() {
pytest -x -v -s tests/kernels/test_onednn.py
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
pytest -x -v -s tests/kernels/core/test_cpu_activation.py
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic"
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
# skip tests requiring model downloads if HF_TOKEN is not set
# due to rate-limits
+18 -1
View File
@@ -369,6 +369,18 @@ else()
add_compile_definitions(-DVLLM_NUMA_DISABLED)
endif()
# check if the pytorch wheel ships libopenblas.so.
# VLLM_OPENBLAS_LIB is left empty on x86 builds and whenever no library is
# found; otherwise it holds the path of the first match.
set(VLLM_OPENBLAS_LIB "")
if (NOT ENABLE_X86_ISA)
# The trailing `*` after `.so` also matches versioned file names.
file(GLOB _VLLM_TORCH_OPENBLAS_LIBS
"${TORCH_INSTALL_PREFIX}/lib/libopenblas*.so*")
# Note: we don't link openblas directly to _C extension, as it's available through libtorch.so
if (_VLLM_TORCH_OPENBLAS_LIBS)
# Only the first glob result is recorded; the path is reported in the
# configure log below.
list(GET _VLLM_TORCH_OPENBLAS_LIBS 0 VLLM_OPENBLAS_LIB)
message(STATUS "CPU OpenBLAS library: ${VLLM_OPENBLAS_LIB}")
endif()
endif()
#
# Generate CPU attention dispatch header
#
@@ -387,6 +399,7 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/layernorm.cpp"
@@ -418,7 +431,6 @@ endif()
if (ENABLE_X86_ISA)
set(VLLM_EXT_SRC_SGL
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/sgl-kernels/conv.cpp"
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
@@ -430,6 +442,7 @@ if (ENABLE_X86_ISA)
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
set(VLLM_EXT_SRC_AVX512
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
@@ -446,6 +459,7 @@ if (ENABLE_X86_ISA)
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
set(VLLM_EXT_SRC_AVX2
"csrc/cpu/sgl-kernels/fla.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/spec_decode_utils.cpp"
"csrc/cpu/cpu_attn.cpp"
@@ -519,6 +533,9 @@ else()
USE_SABI 3
WITH_SOABI
)
if (VLLM_OPENBLAS_LIB)
target_compile_definitions(_C PRIVATE VLLM_HAS_OPENBLAS)
endif()
endif()
message(STATUS "Enabling C extension.")
+82
View File
@@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/native/CPUBlas.h>
// Unlike brgemm, PyTorch does not publicly expose at::native::cpublas::gemm.
// If OpenBLAS is available in the PyTorch wheel, we rely on it for fast
// bf16:bf16->fp32 GEMMs. Otherwise, we fall back to PyTorch's reference BLAS
// path.
#if defined(VLLM_HAS_OPENBLAS)
// BLAS-style GEMM entry points, declared by hand because no header providing
// them is included here. Every argument, scalars included, is passed by
// pointer.
// NOTE(review): presumably resolved at load time through the OpenBLAS bundled
// with PyTorch, and column-major like standard BLAS -- confirm against the
// OpenBLAS build in use.
//
// sbgemm_: bf16 inputs `a`/`b`, fp32 `alpha`/`beta` and fp32 output `c`.
extern "C" void sbgemm_(char* transa, char* transb, int* m, int* n, int* k,
float* alpha, const at::BFloat16* a, int* lda,
const at::BFloat16* b, int* ldb, float* beta, float* c,
int* ldc);
// sgemm_: fp32 inputs and fp32 output.
extern "C" void sgemm_(char* transa, char* transb, int* m, int* n, int* k,
float* alpha, const float* a, int* lda, const float* b,
int* ldb, float* beta, float* c, int* ldc);
// Maps an ATen transpose tag to the single-character op code passed to the
// BLAS entry points: 't' for Transpose, 'c' for ConjTranspose, and 'n' for
// NoTranspose (also used for any value outside the enum).
inline char blas_transpose(at::native::TransposeType trans) {
  if (trans == at::native::TransposeType::Transpose) {
    return 't';
  }
  if (trans == at::native::TransposeType::ConjTranspose) {
    return 'c';
  }
  return 'n';
}
inline void blas_gemm(at::native::TransposeType transa,
at::native::TransposeType transb, int64_t m, int64_t n,
int64_t k, float alpha, const at::BFloat16* a,
int64_t lda, const at::BFloat16* b, int64_t ldb,
float beta, float* c, int64_t ldc) {
char transa_ = blas_transpose(transa);
char transb_ = blas_transpose(transb);
int m_ = static_cast<int>(m);
int n_ = static_cast<int>(n);
int k_ = static_cast<int>(k);
int lda_ = static_cast<int>(lda);
int ldb_ = static_cast<int>(ldb);
int ldc_ = static_cast<int>(ldc);
sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta,
c, &ldc_);
}
inline void blas_gemm(at::native::TransposeType transa,
at::native::TransposeType transb, int64_t m, int64_t n,
int64_t k, float alpha, const float* a, int64_t lda,
const float* b, int64_t ldb, float beta, float* c,
int64_t ldc) {
char transa_ = blas_transpose(transa);
char transb_ = blas_transpose(transb);
int m_ = static_cast<int>(m);
int n_ = static_cast<int>(n);
int k_ = static_cast<int>(k);
int lda_ = static_cast<int>(lda);
int ldb_ = static_cast<int>(ldb);
int ldc_ = static_cast<int>(ldc);
sgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta,
c, &ldc_);
}
// fp16 overload. It exists so callers instantiated for at::Half still
// compile; invoking it always raises, since no half-precision GEMM is wired
// up on this path.
inline void blas_gemm(at::native::TransposeType /*transa*/,
                      at::native::TransposeType /*transb*/, int64_t /*m*/,
                      int64_t /*n*/, int64_t /*k*/, float /*alpha*/,
                      const at::Half* /*a*/, int64_t /*lda*/,
                      const at::Half* /*b*/, int64_t /*ldb*/, float /*beta*/,
                      float* /*c*/, int64_t /*ldc*/) {
  TORCH_CHECK(false, "CPU OpenBLAS hgemm is not available.");
}
#else
// Fallback used when VLLM_HAS_OPENBLAS is not defined: forwards the GEMM to
// the DEFAULT entry of PyTorch's gemm_no_downcast_stub, with the scalar type
// tag derived from `scalar_t` and alpha/beta wrapped as at::Scalar. The
// output `c` is fp32 regardless of the input type.
template <typename scalar_t>
inline void blas_gemm(at::native::TransposeType transa,
                      at::native::TransposeType transb, int64_t m, int64_t n,
                      int64_t k, float alpha, const scalar_t* a, int64_t lda,
                      const scalar_t* b, int64_t ldb, float beta, float* c,
                      int64_t ldc) {
  constexpr auto input_dtype = c10::CppTypeToScalarType<scalar_t>::value;
  const at::Scalar alpha_scalar(alpha);
  const at::Scalar beta_scalar(beta);
  at::native::cpublas::gemm_no_downcast_stub.DEFAULT(
      input_dtype, transa, transb, m, n, k, alpha_scalar, a, lda, b, ldb,
      beta_scalar, c, ldc);
}
#endif
+278 -141
View File
@@ -301,25 +301,42 @@ void chunk_gated_delta_rule_kernel_impl(
// attn = k_beta @ key.transpose(-1, -2)
// attn: [B, HV, num_chunk, chunk_size, chunk_size]
// transpose and pack for key
pack_vnni<scalar_t>(
/* dst */ k_transpose,
/* src */ curr_k_pad,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// k_beta @ key.transpose(-1, -2)
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ curr_k_beta,
/* B */ k_transpose,
/* C */ curr_attn);
if constexpr (brgemm_supported()) {
pack_vnni<scalar_t>(
/* dst */ k_transpose,
/* src */ curr_k_pad,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// k_beta @ key.transpose(-1, -2)
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ curr_k_beta,
/* B */ k_transpose,
/* C */ curr_attn);
} else {
blas_gemm(
at::native::TransposeType::Transpose,
at::native::TransposeType::NoTranspose,
chunk_size,
chunk_size,
qk_head_size,
1.0f,
curr_k_pad,
qk_head_size,
curr_k_beta,
qk_head_size,
0.0f,
curr_attn,
chunk_size);
}
// attn = attn * decay_mask
for (int64_t m = 0; m < chunk_size; m++) {
at::vec::map2<float>(
@@ -413,25 +430,42 @@ void chunk_gated_delta_rule_kernel_impl(
// k_beta_g = k_beta * g: [B, HV, num_chunk, chunk_size, EK]
// k_cumdecay: [B, HV, num_chunk, chunk_size, EK]
// pack for value
pack_vnni2<scalar_t>(
/* dst */ v_pack,
/* src */ curr_v_beta,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// value = attn @ v_beta
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ v_pack,
/* C */ curr_value);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ v_pack,
/* src */ curr_v_beta,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// value = attn @ v_beta
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ v_pack,
/* C */ curr_value);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
chunk_size,
1.0f,
curr_v_beta,
v_head_size,
curr_attn_reduced,
chunk_size,
0.0f,
curr_value,
v_head_size);
}
// k_beta_g = k_beta * g.exp().unsqueeze(-1)
for (int64_t j = 0; j < chunk_size; j++) {
int64_t i = 0;
@@ -445,25 +479,42 @@ void chunk_gated_delta_rule_kernel_impl(
}
}
// pack for k_beta_g
pack_vnni2<scalar_t>(
/* dst */ k_beta_g_pack,
/* src */ k_beta_g,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ qk_head_size);
// k_cumdecay = attn @ k_beta_g
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ qk_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ qk_head_size,
/* ldc */ qk_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ k_beta_g_pack,
/* C */ k_cumdecay);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ k_beta_g_pack,
/* src */ k_beta_g,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ qk_head_size);
// k_cumdecay = attn @ k_beta_g
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ qk_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ qk_head_size,
/* ldc */ qk_head_size,
/* add_C */ false,
/* A */ curr_attn_reduced,
/* B */ k_beta_g_pack,
/* C */ k_cumdecay);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
qk_head_size,
chunk_size,
chunk_size,
1.0f,
k_beta_g,
qk_head_size,
curr_attn_reduced,
chunk_size,
0.0f,
k_cumdecay,
qk_head_size);
}
for (int i = 0; i < chunk_size; i++) {
at::vec::map<scalar_t>(
[](fVec x) { return x; },
@@ -551,25 +602,42 @@ void chunk_gated_delta_rule_kernel_impl(
// attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
// k_transpose_i = k_i.transpose(-1, -2)
pack_vnni<scalar_t>(
/* dst */ k_transpose_i,
/* src */ k_i,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// attn_i = q_i @ k_transpose_i
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ q_i,
/* B */ k_transpose_i,
/* C */ attn_i);
if constexpr (brgemm_supported()) {
pack_vnni<scalar_t>(
/* dst */ k_transpose_i,
/* src */ k_i,
/* N */ chunk_size,
/* K */ qk_head_size,
/* ld_src */ qk_head_size,
/* ld_dst */ chunk_size);
// attn_i = q_i @ k_transpose_i
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ chunk_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ chunk_size,
/* ldc */ chunk_size,
/* add_C */ false,
/* A */ q_i,
/* B */ k_transpose_i,
/* C */ attn_i);
} else {
blas_gemm(
at::native::TransposeType::Transpose,
at::native::TransposeType::NoTranspose,
chunk_size,
chunk_size,
qk_head_size,
1.0f,
k_i,
qk_head_size,
q_i,
qk_head_size,
0.0f,
attn_i,
chunk_size);
}
// attn_i = attn_i * decay_mask_i
for (int64_t m = 0; m < chunk_size; m++) {
auto attn_i_m = attn_i + m * chunk_size;
@@ -609,28 +677,45 @@ void chunk_gated_delta_rule_kernel_impl(
}
// pack for curr_last_recurrent_state
pack_vnni2<scalar_t>(
/* dst */ curr_last_recurrent_state_pack_reduced,
/* src */ curr_last_recurrent_state_reduced,
/* N */ qk_head_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ curr_last_recurrent_state_pack_reduced,
/* src */ curr_last_recurrent_state_reduced,
/* N */ qk_head_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// v_prime = k_cumdecay_i @ curr_last_recurrent_state: [chunk_size, EV]
// k_cumdecay_i: [chunk_size, EK]
// curr_last_recurrent_state: [EK, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ k_cumdecay_i_reduced,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ v_prime);
// v_prime = k_cumdecay_i @ curr_last_recurrent_state: [chunk_size, EV]
// k_cumdecay_i: [chunk_size, EK]
// curr_last_recurrent_state: [EK, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ k_cumdecay_i_reduced,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ v_prime);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
qk_head_size,
1.0f,
curr_last_recurrent_state_reduced,
v_head_size,
k_cumdecay_i_reduced,
qk_head_size,
0.0f,
v_prime,
v_head_size);
}
// v_new = v_prime = v_i - v_prime
// v_i: [chunk_size, EV]
@@ -663,41 +748,75 @@ void chunk_gated_delta_rule_kernel_impl(
}
// attn_inter = qg @ curr_last_recurrent_state: [chunk_size, EV]
// curr_last_recurrent_state: [EK, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ qg,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ attn_inter);
if constexpr (brgemm_supported()) {
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ qk_head_size,
/* lda */ qk_head_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ qg,
/* B */ curr_last_recurrent_state_pack_reduced,
/* C */ attn_inter);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
qk_head_size,
1.0f,
curr_last_recurrent_state_reduced,
v_head_size,
qg,
qk_head_size,
0.0f,
attn_inter,
v_head_size);
}
// core_attn_out[:, :, i] = attn_inter + attn_i @ v_new
// pack for v_prime
pack_vnni2<scalar_t>(
/* dst */ v_prime_pack_reduced,
/* src */ v_prime_reduced,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// attn_inter = attn_inter + attn_i @ v_new: [chunk_size, EV]
// attn_i: [chunk_size, chunk_size]
// v_new: [chunk_size, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ true,
/* A */ attn_i_reduced,
/* B */ v_prime_pack_reduced,
/* C */ attn_inter);
if constexpr (brgemm_supported()) {
pack_vnni2<scalar_t>(
/* dst */ v_prime_pack_reduced,
/* src */ v_prime_reduced,
/* N */ chunk_size,
/* K */ v_head_size,
/* ld_src */ v_head_size,
/* ld_dst */ v_head_size);
// attn_inter = attn_inter + attn_i @ v_new: [chunk_size, EV]
// attn_i: [chunk_size, chunk_size]
// v_new: [chunk_size, EV]
at::native::cpublas::brgemm(
/* M */ chunk_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ true,
/* A */ attn_i_reduced,
/* B */ v_prime_pack_reduced,
/* C */ attn_inter);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
chunk_size,
chunk_size,
1.0f,
v_prime_reduced,
v_head_size,
attn_i_reduced,
chunk_size,
1.0f,
attn_inter,
v_head_size);
}
// core_attn_out[:, :, i] = attn_inter
for (int64_t m = 0; m < chunk_size; m++) {
@@ -762,17 +881,34 @@ void chunk_gated_delta_rule_kernel_impl(
/* ld_dst */ chunk_size);
// kgv = kg.transpose(-1, -2) @ v_new
// v_new: [chunk_size, EV]
at::native::cpublas::brgemm(
/* M */ qk_head_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ kg_transpose,
/* B */ v_prime_pack_reduced,
/* C */ kgv);
if constexpr (brgemm_supported()) {
at::native::cpublas::brgemm(
/* M */ qk_head_size,
/* N */ v_head_size,
/* K */ chunk_size,
/* lda */ chunk_size,
/* ldb */ v_head_size,
/* ldc */ v_head_size,
/* add_C */ false,
/* A */ kg_transpose,
/* B */ v_prime_pack_reduced,
/* C */ kgv);
} else {
blas_gemm(
at::native::TransposeType::NoTranspose,
at::native::TransposeType::NoTranspose,
v_head_size,
qk_head_size,
chunk_size,
1.0f,
v_prime_reduced,
v_head_size,
kg_transpose,
chunk_size,
0.0f,
kgv,
v_head_size);
}
// last_recurrent_state = 1) + 2)
for (int64_t m = 0; m < qk_head_size; m++) {
at::vec::map2<float>(
@@ -921,7 +1057,8 @@ void fused_sigmoid_gating_delta_rule_update_kernel_impl(
float k_scale = use_qk_l2norm_in_kernel ? qk_scale_buf[k_scale_offset] : 1.0f;
int64_t v_offset = si * v_strideS + bi * v_strideB + ni * v_strideH;
int64_t o_offset = ((bi * seq_len + si) * v_num_heads + ni) * v_head_dim;
float beta_val = 1 / (1 + std::exp(-b_ptr[ni]));
// See: https://github.com/sgl-project/sglang/pull/26634
float beta_val = 1 / (1 + std::exp(-b_ptr[bi * v_num_heads + ni]));
fVec beta_vec = fVec(beta_val);
int64_t dvi = 0;
for (; dvi <= v_head_dim - VecSize; dvi += VecSize) {
+18 -7
View File
@@ -4,9 +4,12 @@
// clang-format off
#pragma once
#include <ATen/native/CPUBlas.h>
#include "common.h"
#include "blas_gemm.h"
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endif
// amx-bf16
#define TILE_M 16
@@ -21,31 +24,39 @@ constexpr int block_size_n() {
return 2 * TILE_N;
}
// Compile-time switch for the brgemm code paths: true only when
// CPU_CAPABILITY_AVX512 is defined for this translation unit, false
// otherwise. Being constexpr, it can be used as an `if constexpr` condition.
constexpr bool brgemm_supported() {
#if !defined(CPU_CAPABILITY_AVX512)
  return false;
#else
  return true;
#endif
}
// define threshold using brgemm (intel AMX)
template <typename T>
inline bool can_use_brgemm(int M);
template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
template <>
inline bool can_use_brgemm<at::Half>(int M) {
return true;
return brgemm_supported();
}
// this requires PyTorch 2.7 or above
template <>
inline bool can_use_brgemm<int8_t>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
template <>
inline bool can_use_brgemm<uint8_t>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
return M > 4;
return brgemm_supported() && M > 4;
}
// work around compiler internal error
+2
View File
@@ -11,7 +11,9 @@
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#if defined(CPU_CAPABILITY_AVX512)
#include <immintrin.h>
#endif
namespace {
using namespace at::vec;
+19 -19
View File
@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"bool is_vnni) -> Tensor");
ops.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
// Adapted from sglang: causal_conv1d kernels
ops.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor");
ops.impl("causal_conv1d_weight_pack", torch::kCPU,
&causal_conv1d_weight_pack);
ops.def(
"causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? "
"conv_states, Tensor? query_start_loc,"
"Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, "
"int pad_slot_id, bool is_vnni) -> "
"Tensor");
ops.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu);
ops.def(
"causal_conv1d_update_cpu(Tensor x, Tensor(a!) conv_states, Tensor "
"weight, Tensor? bias, bool silu_activation,"
"Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, "
"bool is_vnni) -> Tensor");
ops.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu);
#endif
// Adapted from sglang: GDN kernels
ops.def(
"chunk_gated_delta_rule_cpu(Tensor query, Tensor key, Tensor value, "
@@ -470,25 +489,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> (Tensor, Tensor)");
ops.impl("fused_gdn_gating_cpu", torch::kCPU, &fused_gdn_gating_cpu);
// Adapted from sglang: casual_conv1d kernels
ops.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor");
ops.impl("causal_conv1d_weight_pack", torch::kCPU,
&causal_conv1d_weight_pack);
ops.def(
"causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? "
"conv_states, Tensor? query_start_loc,"
"Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, "
"int pad_slot_id, bool is_vnni) -> "
"Tensor");
ops.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu);
ops.def(
"causal_conv1d_update_cpu(Tensor x, Tensor(a!) conv_states, Tensor "
"weight, Tensor? bias, bool silu_activation,"
"Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, "
"bool is_vnni) -> Tensor");
ops.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu);
#endif
// CPU attention kernels
ops.def(
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
+314
View File
@@ -0,0 +1,314 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pytest
import torch
import torch.nn.functional as F
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
set_random_seed(12345)
NUM_HEADS = [
(2, 4),
(4, 4),
]
HEAD_DIMS = [
(32, 32),
(64, 32),
]
CHUNK_SIZE = 64
PREFILL_SEQ_LENS = [
[1],
[1, 2, 3],
[CHUNK_SIZE - 1],
[CHUNK_SIZE],
[CHUNK_SIZE + 1],
[CHUNK_SIZE - 1, CHUNK_SIZE, CHUNK_SIZE + 1],
[2 * CHUNK_SIZE - 1, 2 * CHUNK_SIZE, 2 * CHUNK_SIZE + 1],
[4 * CHUNK_SIZE + 17],
]
DECODE_BATCH_SIZES = [1, 3, 5]
@functools.lru_cache(maxsize=128, typed=False)
def tensor_cache(
    elem_num: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return a flat uniform-random tensor, memoized on ``(elem_num, dtype)``.

    Repeated calls with the same arguments return the *same* tensor object,
    so callers that need independent or mutable data must copy it.
    """
    return torch.rand(elem_num, dtype=dtype)
def ref_l2norm(
    x: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-5,
) -> torch.Tensor:
    """Reference L2 normalization: ``x / sqrt(sum(x^2, dim) + eps)``.

    ``eps`` is added to the squared sum (inside the square root), so an
    all-zero slice maps to zeros rather than NaN.
    """
    squared_sum = torch.sum(x * x, dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
def ref_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference gating for the gated delta rule.

    Returns ``(g, beta)`` where
    ``g = -exp(A_log) * softplus(a + dt_bias)`` is computed and returned in
    float32, and ``beta = sigmoid(b)`` is computed in float32 and cast back
    to ``b``'s dtype.
    """
    gate_input = a.float() + dt_bias.float()
    decay_rate = torch.exp(A_log.float())
    g = -decay_rate * F.softplus(gate_input, beta=1.0, threshold=20.0)
    beta = b.float().sigmoid().to(dtype=b.dtype)
    return g, beta
def ref_gated_delta_rule(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    initial_state: torch.Tensor,
    cu_seqlens: torch.Tensor,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Token-by-token float32 reference for the gated delta rule.

    Sequences are packed along dim 1 of ``query``/``key``/``value`` and
    delimited by ``cu_seqlens``; sequence ``i`` starts from
    ``initial_state[i]``. Gates are derived from ``a``/``b`` with
    ``ref_gdn_gating`` and indexed by token along their first dim.

    Returns ``(out, final_state)``: ``out`` has the shape and dtype of
    ``value``; ``final_state`` has the shape and dtype of ``initial_state``.
    """
    g, beta = ref_gdn_gating(A_log, a, b, dt_bias)
    out = torch.empty_like(value)
    final_state = torch.empty_like(initial_state)
    for seq_idx in range(cu_seqlens.numel() - 1):
        # Token range of this sequence inside the packed batch.
        begin = int(cu_seqlens[seq_idx].item())
        end = int(cu_seqlens[seq_idx + 1].item())
        q_seq = query[:, begin:end]
        k_seq = key[:, begin:end]
        v_seq = value[:, begin:end]
        g_seq = g[begin:end].unsqueeze(0)
        beta_seq = beta[begin:end].unsqueeze(0)
        # Remembered so the output can be cast back after float32 math.
        initial_dtype = q_seq.dtype
        if use_qk_l2norm_in_kernel:
            q_seq = ref_l2norm(q_seq, dim=-1)
            k_seq = ref_l2norm(k_seq, dim=-1)
        # Fewer q/k heads than v heads: repeat each q/k head so the head
        # counts match.
        if q_seq.shape[2] != v_seq.shape[2]:
            repeat_factor = v_seq.shape[2] // q_seq.shape[2]
            q_seq = q_seq.repeat_interleave(repeat_factor, dim=2)
            k_seq = k_seq.repeat_interleave(repeat_factor, dim=2)
        # Move heads ahead of tokens and promote everything to float32.
        q_seq, k_seq, v_seq, beta_seq, g_seq = [
            x.transpose(1, 2).contiguous().to(torch.float32)
            for x in (q_seq, k_seq, v_seq, beta_seq, g_seq)
        ]
        batch_size, num_heads, seq_len, head_dim = q_seq.shape
        v_head_dim = v_seq.shape[-1]
        # Queries are scaled by 1/sqrt(head_dim).
        q_seq = q_seq * (1 / (head_dim**0.5))
        out_seq = torch.empty(
            batch_size,
            num_heads,
            seq_len,
            v_head_dim,
            dtype=v_seq.dtype,
        )
        # The broadcasts below treat the last state dim as the key dim and
        # the second-to-last as the value dim.
        state = initial_state[seq_idx : seq_idx + 1].to(v_seq)
        for token_idx in range(seq_len):
            q_t = q_seq[:, :, token_idx]
            k_t = k_seq[:, :, token_idx]
            v_t = v_seq[:, :, token_idx]
            g_t = g_seq[:, :, token_idx].exp().unsqueeze(-1).unsqueeze(-1)
            beta_t = beta_seq[:, :, token_idx].unsqueeze(-1)
            # 1) decay the state with the per-head gate exp(g)
            state = state * g_t
            # 2) what the decayed state currently predicts for this key
            kv_mem = (state * k_t.unsqueeze(-2)).sum(dim=-1)
            # 3) beta-scaled prediction error, written back as an outer
            #    product with the key
            delta = (v_t - kv_mem) * beta_t
            state = state + delta.unsqueeze(-1) * k_t.unsqueeze(-2)
            # 4) read the updated state with the scaled query
            out_seq[:, :, token_idx] = (state * q_t.unsqueeze(-2)).sum(dim=-1)
        out[:, begin:end] = out_seq.transpose(1, 2).contiguous().to(initial_dtype)
        final_state[seq_idx] = state.squeeze(0)
    return out, final_state
def gdn_inputs(
    num_tokens: int,
    num_heads: tuple[int, int],
    head_dims: tuple[int, int],
) -> tuple[torch.Tensor, ...]:
    """Build independent random inputs for the GDN kernel tests.

    Args:
        num_tokens: total number of packed tokens.
        num_heads: ``(num_qk_heads, num_v_heads)``.
        head_dims: ``(head_dim, v_head_dim)``.

    Returns:
        ``(q, k, v, a, b, A_log, dt_bias)`` with shapes
        ``q``/``k``: ``(1, num_tokens, num_qk_heads, head_dim)``,
        ``v``: ``(1, num_tokens, num_v_heads, v_head_dim)``,
        ``a``/``b``: ``(num_tokens, num_v_heads)``,
        ``A_log``/``dt_bias``: ``(num_v_heads,)``.
        ``A_log`` is float32; everything else is bfloat16.

    Every tensor is drawn separately. A cache keyed only on
    ``(numel, dtype)`` would return the same tensor for ``q`` and ``k`` (and
    for ``a`` and ``b``). That makes ``q @ k^T`` symmetric, which hides
    swapped or transposed operands, and hides swapped gate arguments. Fresh
    storage also keeps ops under test from touching shared cached data.
    """
    num_qk_heads, num_v_heads = num_heads
    head_dim, v_head_dim = head_dims

    def rand(
        shape: tuple[int, ...], dtype: torch.dtype = torch.bfloat16
    ) -> torch.Tensor:
        return torch.rand(shape, dtype=dtype)

    q_shape = (1, num_tokens, num_qk_heads, head_dim)
    q = rand(q_shape)
    k = rand(q_shape)

    v = rand((1, num_tokens, num_v_heads, v_head_dim))

    gate_shape = (num_tokens, num_v_heads)
    a = rand(gate_shape)
    b = rand(gate_shape)

    A_log = rand((num_v_heads,), torch.float32)
    dt_bias = rand((num_v_heads,))
    return q, k, v, a, b, A_log, dt_bias
@pytest.mark.parametrize("num_tokens", [1, 9])
@pytest.mark.parametrize("num_v_heads", [4, 8])
@torch.inference_mode()
def test_fused_gdn_gating_cpu(
    num_tokens: int,
    num_v_heads: int,
) -> None:
    """Compare ``fused_gdn_gating_cpu`` against ``ref_gdn_gating``.

    The fused op's outputs are compared against the reference with a leading
    singleton dim added, so only the reference results are unsqueezed.
    """
    per_head_bias = tensor_cache(num_v_heads, torch.bfloat16)
    per_head_A_log = tensor_cache(num_v_heads, torch.float32)
    flat_len = num_tokens * num_v_heads
    a = tensor_cache(flat_len, torch.bfloat16).view((num_tokens, num_v_heads))
    b = tensor_cache(flat_len, torch.bfloat16).view((num_tokens, num_v_heads))

    expected_g, expected_beta = ref_gdn_gating(per_head_A_log, a, b, per_head_bias)
    actual_g, actual_beta = ops.fused_gdn_gating_cpu(
        per_head_A_log, a, b, per_head_bias
    )

    torch.testing.assert_close(
        actual_g, expected_g.unsqueeze(0), atol=1e-4, rtol=1e-4
    )
    # beta is compared in float32 with a looser tolerance than g.
    torch.testing.assert_close(
        actual_beta.float(),
        expected_beta.unsqueeze(0).float(),
        atol=5e-3,
        rtol=5e-3,
    )
# decode path
@pytest.mark.parametrize("batch_size", DECODE_BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_dims", HEAD_DIMS)
@torch.inference_mode()
def test_fused_sigmoid_gating_delta_rule_update_cpu(
    batch_size: int,
    num_heads: tuple[int, int],
    head_dims: tuple[int, int],
) -> None:
    """Check the fused decode op against ``ref_gated_delta_rule``.

    Each request contributes exactly one token, so ``cu_seqlens`` is
    ``0..batch_size`` and request ``i`` uses state slot ``i``.
    """
    q, k, v, a, b, A_log, dt_bias = gdn_inputs(
        num_tokens=batch_size,
        num_heads=num_heads,
        head_dims=head_dims,
    )
    _, num_v_heads = num_heads
    head_dim, v_head_dim = head_dims
    state_indices = torch.arange(batch_size, dtype=torch.int32)
    cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32)
    state_shape = (batch_size, num_v_heads, head_dim, v_head_dim)
    state = tensor_cache(
        batch_size * num_v_heads * head_dim * v_head_dim, torch.float32
    ).view(state_shape)
    # The reference is given the state with its last two dims swapped
    # relative to what is handed to the op.
    state_ref = state[state_indices].transpose(-1, -2).contiguous()
    out_ref, final_state_ref = ref_gated_delta_rule(
        query=q,
        key=k,
        value=v,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        initial_state=state_ref,
        cu_seqlens=cu_seqlens,
        use_qk_l2norm_in_kernel=True,
    )
    # Swap the first two dims of the reference output before comparing it
    # with the op's output.
    out_ref = out_ref.transpose(0, 1).contiguous()
    # Work on a copy: `state` is a view of a cached tensor, and the op is
    # expected to write the final states back into `initial_state_source`
    # (they are read from `state_out` below).
    state_out = state.clone()
    out = ops.fused_sigmoid_gating_delta_rule_update_cpu(
        A_log=A_log,
        dt_bias=dt_bias,
        q=q,
        k=k,
        v=v,
        a=a,
        b=b,
        initial_state_source=state_out,
        initial_state_indices=state_indices,
        cu_seqlens=cu_seqlens,
        use_qk_l2norm_in_kernel=True,
    )
    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
    # Swap the last two dims again so the layouts match the reference.
    torch.testing.assert_close(
        state_out[state_indices].transpose(-1, -2),
        final_state_ref,
        atol=1e-2,
        rtol=1e-2,
    )
# prefill path
@pytest.mark.parametrize("seq_lens", PREFILL_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_dims", HEAD_DIMS)
@torch.inference_mode()
def test_chunk_gated_delta_rule_cpu(
    seq_lens: list[int],
    num_heads: tuple[int, int],
    head_dims: tuple[int, int],
) -> None:
    """Check the chunked prefill op against ``ref_gated_delta_rule``.

    ``seq_lens`` lists the lengths of the variable-length sequences packed
    into one batch; the parametrized cases sit below, at and above
    ``CHUNK_SIZE`` and its multiples.
    """
    total_tokens = sum(seq_lens)
    q, k, v, a, b, A_log, dt_bias = gdn_inputs(
        num_tokens=total_tokens,
        num_heads=num_heads,
        head_dims=head_dims,
    )
    _, num_v_heads = num_heads
    head_dim, v_head_dim = head_dims
    # Cumulative sequence boundaries with a leading 0.
    cu_seqlens = torch.tensor(
        [0, *torch.tensor(seq_lens).cumsum(0).tolist()], dtype=torch.int32
    )
    initial_state_shape = (len(seq_lens), num_v_heads, head_dim, v_head_dim)
    initial_state = tensor_cache(
        len(seq_lens) * num_v_heads * head_dim * v_head_dim, torch.float32
    ).view(initial_state_shape)
    # The reference is given the state with its last two dims swapped
    # relative to what is handed to the op.
    initial_state_ref = initial_state.transpose(-1, -2).contiguous()
    out_ref, final_state_ref = ref_gated_delta_rule(
        query=q,
        key=k,
        value=v,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        initial_state=initial_state_ref,
        cu_seqlens=cu_seqlens,
        use_qk_l2norm_in_kernel=True,
    )
    # The op takes precomputed gates; they are passed with a leading
    # singleton dim added.
    g, beta = ref_gdn_gating(A_log, a, b, dt_bias)
    # NOTE(review): `initial_state` is a view of a cached tensor and is passed
    # without cloning -- confirm the op does not modify it in place.
    out, final_state = ops.chunk_gated_delta_rule_cpu(
        query=q,
        key=k,
        value=v,
        g=g.unsqueeze(0),
        beta=beta.unsqueeze(0),
        initial_state=initial_state,
        output_final_state=True,
        cu_seqlens=cu_seqlens,
        head_first=False,
        use_qk_l2norm_in_kernel=True,
    )
    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
    # Swap the last two dims again so the layouts match the reference.
    torch.testing.assert_close(
        final_state.transpose(-1, -2),
        final_state_ref,
        atol=1e-2,
        rtol=1e-2,
    )
@@ -12,11 +12,6 @@ from vllm.model_executor.layers.mamba.ops.cpu.causal_conv1d import (
causal_conv1d_torch,
causal_conv1d_update_torch,
)
from vllm.model_executor.layers.mamba.ops.cpu.recurrent_gated_delta_rule import (
chunk_gated_delta_rule,
gdn_gating,
recurrent_gated_delta_rule,
)
from vllm.utils.torch_utils import (
LayerNameType,
_resolve_layer_name,
@@ -55,88 +50,91 @@ def cpu_gdn_attention_core(
attn_metadata_i.spec_sequence_masks is None
and attn_metadata_i.num_accepted_tokens is None
), "speculative decode not supported in CPU GDN attention."
if torch.cpu._is_amx_tile_supported():
return cpu_gdn_attention_core_amx(
mixed_qkv,
b,
a,
core_attn_out,
attn_metadata_i,
layer,
)
assert mixed_qkv.dtype == torch.bfloat16, "CPU GDN attention requires BF16."
state_indices_tensor = attn_metadata_i.non_spec_state_indices_tensor
query_start_loc = attn_metadata_i.non_spec_query_start_loc
assert state_indices_tensor is not None
assert query_start_loc is not None
# [num_allocated_slots, conv_dim, kernel - 1]
is_amx = torch.cpu._is_amx_tile_supported()
conv_state = layer.kv_cache[0]
if not is_conv_state_dim_first():
conv_state = conv_state.transpose(-1, -2)
if is_amx:
# AMX causal conv requires [num_allocated_slots, kernel - 1, conv_dim].
if is_conv_state_dim_first():
raise RuntimeError("AMX GDN attention requires `SD` conv_state layout.")
conv_state = conv_state.transpose(1, 2)
else:
if not is_conv_state_dim_first():
conv_state = conv_state.transpose(-1, -2)
conv_weights = layer.conv1d.weight.view(
layer.conv1d.weight.size(0), layer.conv1d.weight.size(2)
)
# [num_allocated_slots, num_v_heads / tp_size, v_dim, k_dim]
ssm_state = layer.kv_cache[1]
mixed_qkv = mixed_qkv.contiguous()
a = a.contiguous()
b = b.contiguous()
num_allocated_slots, head_num, v_dim, k_dim = ssm_state.size()
ssm_state = ssm_state.view(
num_allocated_slots,
head_num,
k_dim,
v_dim,
)
num_decodes = attn_metadata_i.num_decodes
num_decode_tokens = attn_metadata_i.num_decode_tokens
num_prefills = attn_metadata_i.num_prefills
num_prefill_tokens = attn_metadata_i.num_prefill_tokens
conv_weights = layer.conv1d.weight.view(
layer.conv1d.weight.size(0), layer.conv1d.weight.size(2)
)
# all decode requests (batched)
if num_decodes > 0:
decode_mixed_qkv = mixed_qkv[:num_decode_tokens]
decode_b = b[:num_decode_tokens]
decode_a = a[:num_decode_tokens]
decode_state_indices = state_indices_tensor[:num_decodes]
decode_conv_state = conv_state[decode_state_indices].contiguous()
if is_amx:
decode_mixed_qkv = ops.causal_conv1d_update_cpu(
x=decode_mixed_qkv,
conv_states=conv_state,
weight=layer.conv1d.weight,
bias=layer.conv1d.bias,
silu_activation=layer.activation == "silu",
conv_state_indices=decode_state_indices,
is_vnni=True,
)
else:
decode_conv_state = conv_state[decode_state_indices].contiguous()
decode_mixed_qkv = causal_conv1d_update_torch(
# [B, dim] -> [B, dim, 1]
x=decode_mixed_qkv.unsqueeze(-1),
conv_state=decode_conv_state,
weight=conv_weights,
bias=layer.conv1d.bias,
activation=layer.activation,
).squeeze(-1)
conv_state[decode_state_indices] = decode_conv_state
decode_mixed_qkv = causal_conv1d_update_torch(
# [B, dim] -> [B, dim, 1]
x=decode_mixed_qkv.unsqueeze(-1),
conv_state=decode_conv_state,
weight=conv_weights,
bias=layer.conv1d.bias,
activation=layer.activation,
).squeeze(-1)
conv_state[decode_state_indices] = decode_conv_state
query, key, value = layer.rearrange_mixed_qkv(decode_mixed_qkv)
# [1, L, H, D] -> [B, 1, H, D] for batched decode
query = query.transpose(0, 1).contiguous()
key = key.transpose(0, 1).contiguous()
value = value.transpose(0, 1).contiguous()
g, beta_output = gdn_gating(
attn_out = ops.fused_sigmoid_gating_delta_rule_update_cpu(
A_log=layer.A_log,
dt_bias=layer.dt_bias,
q=query,
k=key,
v=value,
a=decode_a,
b=decode_b,
dt_bias=layer.dt_bias,
)
if g.ndim == 2:
g = g.unsqueeze(1)
beta_output = beta_output.unsqueeze(1)
initial_state = ssm_state[decode_state_indices].contiguous()
attn_out, last_recurrent_state = recurrent_gated_delta_rule(
query=query,
key=key,
value=value,
g=g,
beta=beta_output,
initial_state=initial_state,
scale=None,
initial_state_source=ssm_state,
initial_state_indices=decode_state_indices,
cu_seqlens=query_start_loc[: num_decodes + 1],
use_qk_l2norm_in_kernel=True,
)
ssm_state[decode_state_indices] = last_recurrent_state.to(
ssm_state.dtype
).contiguous()
core_attn_out[:num_decode_tokens] = attn_out.squeeze(1)
# all prefill requests: (varlen) currently naively loops over sequences
@@ -160,154 +158,29 @@ def cpu_gdn_attention_core(
num_decodes : num_decodes + num_prefills
]
prefill_mixed_qkv = causal_conv1d_torch(
x=prefill_mixed_qkv.transpose(0, 1),
weight=conv_weights,
bias=layer.conv1d.bias,
conv_states=conv_state,
query_start_loc=prefill_query_start_loc,
cache_indices=prefill_state_indices,
has_initial_state=prefill_has_initial_state,
activation=layer.activation,
).transpose(0, 1)
query, key, value = layer.rearrange_mixed_qkv(prefill_mixed_qkv)
g, beta = gdn_gating(layer.A_log, prefill_a, prefill_b, layer.dt_bias)
if g.ndim == 2:
g = g.unsqueeze(0)
beta = beta.unsqueeze(0)
initial_state = ssm_state[prefill_state_indices].contiguous()
initial_state[~prefill_has_initial_state, ...] = 0
attn_out, last_recurrent_state = chunk_gated_delta_rule(
q=query,
k=key,
v=value,
g=g,
beta=beta,
scale=None,
initial_state=initial_state,
cu_seqlens=prefill_query_start_loc,
use_qk_l2norm_in_kernel=True,
)
ssm_state[prefill_state_indices] = last_recurrent_state.to(ssm_state.dtype)
core_attn_out[prefill_token_start:prefill_token_end] = attn_out.squeeze(0)
def cpu_gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: LayerNameType,
) -> None:
    """Fake (no-op) implementation for torch.compile.

    Accepts the arguments listed in the signature, touches none of them,
    and returns ``None``.
    """
    return None
def cpu_gdn_attention_core_amx(
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
attn_metadata_i: GDNAttentionMetadata,
layer: torch.nn.Module,
):
state_indices_tensor = attn_metadata_i.non_spec_state_indices_tensor
query_start_loc = attn_metadata_i.non_spec_query_start_loc
assert state_indices_tensor is not None
assert query_start_loc is not None
# [num_allocated_slots, kernel - 1, conv_dim]
conv_state = layer.kv_cache[0]
if is_conv_state_dim_first():
raise RuntimeError("AMX GDN attention requires `SD` conv_state layout.")
# reshape to [num_allocated_slots, conv_dim, kernel - 1]
conv_state_t = conv_state.transpose(1, 2)
# [num_allocated_slots, num_v_heads / tp_size, v_dim, k_dim]
ssm_state = layer.kv_cache[1]
# reshape to [num_allocated_slots, num_v_heads / tp_size, k_dim, v_dim]
num_allocated_slots, head_num, v_dim, k_dim = ssm_state.size()
ssm_state = ssm_state.view(
num_allocated_slots,
head_num,
k_dim,
v_dim,
)
mixed_qkv = mixed_qkv.contiguous()
a = a.contiguous()
b = b.contiguous()
num_decodes = attn_metadata_i.num_decodes
num_decode_tokens = attn_metadata_i.num_decode_tokens
num_prefills = attn_metadata_i.num_prefills
num_prefill_tokens = attn_metadata_i.num_prefill_tokens
if num_decodes > 0:
decode_mixed_qkv = mixed_qkv[:num_decode_tokens]
decode_b = b[:num_decode_tokens]
decode_a = a[:num_decode_tokens]
decode_state_indices = state_indices_tensor[:num_decodes]
decode_mixed_qkv = ops.causal_conv1d_update_cpu(
x=decode_mixed_qkv,
conv_states=conv_state_t,
weight=layer.conv1d.weight,
bias=layer.conv1d.bias,
silu_activation=layer.activation == "silu",
conv_state_indices=decode_state_indices,
is_vnni=True,
)
query, key, value = layer.rearrange_mixed_qkv(decode_mixed_qkv)
attn_out = ops.fused_sigmoid_gating_delta_rule_update_cpu(
A_log=layer.A_log,
dt_bias=layer.dt_bias,
q=query,
k=key,
v=value,
a=decode_a,
b=decode_b,
initial_state_source=ssm_state,
initial_state_indices=decode_state_indices,
cu_seqlens=query_start_loc[: num_decodes + 1],
use_qk_l2norm_in_kernel=True,
)
core_attn_out[:num_decode_tokens] = attn_out.squeeze(1)
if num_prefills > 0:
has_initial_state = attn_metadata_i.has_initial_state
assert has_initial_state is not None
prefill_token_start = num_decode_tokens
prefill_token_end = prefill_token_start + num_prefill_tokens
prefill_mixed_qkv = mixed_qkv[prefill_token_start:prefill_token_end]
prefill_b = b[prefill_token_start:prefill_token_end]
prefill_a = a[prefill_token_start:prefill_token_end]
prefill_state_indices = state_indices_tensor[
num_decodes : num_decodes + num_prefills
]
prefill_query_start_loc = (
query_start_loc[num_decodes : num_decodes + num_prefills + 1]
- num_decode_tokens
)
prefill_has_initial_state = has_initial_state[
num_decodes : num_decodes + num_prefills
]
prefill_mixed_qkv = ops.causal_conv1d_fwd_cpu(
x=prefill_mixed_qkv.transpose(0, 1),
weight=layer.conv1d.weight,
bias=layer.conv1d.bias,
conv_states=conv_state_t,
query_start_loc=prefill_query_start_loc,
cache_indices=prefill_state_indices,
has_initial_state=prefill_has_initial_state,
silu_activation=layer.activation == "silu",
is_vnni=True,
).transpose(0, 1)
if is_amx:
prefill_mixed_qkv = ops.causal_conv1d_fwd_cpu(
x=prefill_mixed_qkv.transpose(0, 1),
weight=layer.conv1d.weight,
bias=layer.conv1d.bias,
conv_states=conv_state,
query_start_loc=prefill_query_start_loc,
cache_indices=prefill_state_indices,
has_initial_state=prefill_has_initial_state,
silu_activation=layer.activation == "silu",
is_vnni=True,
).transpose(0, 1)
else:
prefill_mixed_qkv = causal_conv1d_torch(
x=prefill_mixed_qkv.transpose(0, 1),
weight=conv_weights,
bias=layer.conv1d.bias,
conv_states=conv_state,
query_start_loc=prefill_query_start_loc,
cache_indices=prefill_state_indices,
has_initial_state=prefill_has_initial_state,
activation=layer.activation,
).transpose(0, 1)
query, key, value = layer.rearrange_mixed_qkv(prefill_mixed_qkv)
g, beta = ops.fused_gdn_gating_cpu(
@@ -334,6 +207,17 @@ def cpu_gdn_attention_core_amx(
core_attn_out[prefill_token_start:prefill_token_end] = attn_out.squeeze(0)
def cpu_gdn_attention_core_fake(
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: LayerNameType,
) -> None:
    """Fake (no-op) implementation for torch.compile.

    Accepts the arguments listed in the signature, touches none of them,
    and returns ``None``.
    """
    return None
def register_cpu_gdn_attention_ops() -> None:
global _CPU_GDN_ATTENTION_OPS_REGISTERED
if _CPU_GDN_ATTENTION_OPS_REGISTERED:
@@ -1,223 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
def l2norm(
    x: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Rescale ``x`` to (approximately) unit L2 length along ``dim``.

    Args:
        x: Tensor to normalize.
        dim: Dimension whose entries are normalized together.
        eps: Added to the sum of squares before the reciprocal square root,
            so an all-zero slice yields zeros instead of NaN/inf.

    Returns:
        A tensor with the same shape as ``x``.
    """
    sum_of_squares = torch.sum(x * x, dim=dim, keepdim=True)
    return x * torch.rsqrt(sum_of_squares + eps)
def recurrent_gated_delta_rule(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
initial_state: torch.Tensor,
scale: float | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
if query.shape[2] != value.shape[2]:
repeat_factor = value.shape[2] // query.shape[2]
query = query.repeat_interleave(repeat_factor, dim=2)
key = key.repeat_interleave(repeat_factor, dim=2)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (query, key, value, beta, g)
]
batch_size, num_heads, sequence_length, _ = key.shape
v_head_dim = value.shape[-1]
if scale is None:
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale
core_attn_out = torch.empty(
batch_size,
num_heads,
sequence_length,
v_head_dim,
dtype=value.dtype,
)
last_recurrent_state = initial_state.to(value)
for token_idx in range(sequence_length):
q_t = query[:, :, token_idx]
k_t = key[:, :, token_idx]
v_t = value[:, :, token_idx]
g_t = g[:, :, token_idx].exp().unsqueeze(-1).unsqueeze(-1)
beta_t = beta[:, :, token_idx].unsqueeze(-1)
last_recurrent_state = last_recurrent_state * g_t
kv_mem = (last_recurrent_state * k_t.unsqueeze(-2)).sum(dim=-1)
delta = (v_t - kv_mem) * beta_t
last_recurrent_state = last_recurrent_state + delta.unsqueeze(
-1
) * k_t.unsqueeze(-2)
core_attn_out[:, :, token_idx] = (last_recurrent_state * q_t.unsqueeze(-2)).sum(
dim=-1
)
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state
def gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the log-decay ``g`` and the sigmoid gate for the delta rule.

    ``g = -exp(A_log) * softplus(a + dt_bias)``, evaluated in float32.
    The gate is ``sigmoid(b)``, evaluated in float32 and cast back to
    ``b``'s dtype.

    Args:
        A_log: Log of the decay magnitude.
        a: Pre-activation input of the decay term.
        b: Pre-activation input of the gate.
        dt_bias: Bias added to ``a`` before the softplus.
        beta: ``beta`` argument forwarded to ``F.softplus``.
        threshold: ``threshold`` argument forwarded to ``F.softplus``.

    Returns:
        ``(g, beta_output)`` with ``g`` in float32 and ``beta_output`` in
        ``b.dtype``.
    """
    pre_activation = a.float() + dt_bias.float()
    decay_magnitude = torch.exp(A_log.float())
    g = -decay_magnitude * F.softplus(
        pre_activation, beta=beta, threshold=threshold
    )
    gate = torch.sigmoid(b.float())
    return g, gate.to(dtype=b.dtype)
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
*,
initial_state: torch.Tensor,
scale: float | None = None,
cu_seqlens: torch.Tensor,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(v)
state_dtype = initial_state.dtype
chunk_size = 128
sequence_bounds = [
(
seq_idx,
int(cu_seqlens[seq_idx].item()),
int(cu_seqlens[seq_idx + 1].item()),
)
for seq_idx in range(len(cu_seqlens) - 1)
]
chunk_eye = torch.eye(chunk_size, dtype=torch.float32)
num_sequences = len(sequence_bounds)
num_value_heads = v.shape[2]
value_head_dim = v.shape[3]
key_head_dim = k.shape[3]
final_state = torch.empty(
(num_sequences, num_value_heads, value_head_dim, key_head_dim),
dtype=state_dtype,
)
for seq_idx, begin, end in sequence_bounds:
q_seq = q[:, begin:end]
k_seq = k[:, begin:end]
v_seq = v[:, begin:end]
g_seq = g[:, begin:end]
beta_seq = beta[:, begin:end]
initial_dtype = q_seq.dtype
if use_qk_l2norm_in_kernel:
q_seq = l2norm(q_seq, dim=-1, eps=1e-6)
k_seq = l2norm(k_seq, dim=-1, eps=1e-6)
num_qk_heads = q_seq.shape[2]
num_value_heads = v_seq.shape[2]
if num_qk_heads != num_value_heads:
repeat_factor = num_value_heads // num_qk_heads
q_seq = q_seq.repeat_interleave(repeat_factor, dim=2)
k_seq = k_seq.repeat_interleave(repeat_factor, dim=2)
q_seq, k_seq, v_seq, beta_seq, g_seq = [
x.transpose(1, 2).contiguous().to(torch.float32)
for x in (q_seq, k_seq, v_seq, beta_seq, g_seq)
]
seq_batch_size, num_heads, seq_len, qk_head_dim = q_seq.shape
value_head_dim = v_seq.shape[-1]
if scale is None:
scale = 1 / (qk_head_dim**0.5)
q_seq = q_seq * scale
seq_state = initial_state[seq_idx : seq_idx + 1].to(v_seq)
seq_output = torch.empty(
seq_batch_size,
num_heads,
seq_len,
value_head_dim,
dtype=v_seq.dtype,
)
for chunk_start in range(0, seq_len, chunk_size):
chunk_end = min(chunk_start + chunk_size, seq_len)
q_chunk = q_seq[:, :, chunk_start:chunk_end]
k_chunk = k_seq[:, :, chunk_start:chunk_end]
v_chunk = v_seq[:, :, chunk_start:chunk_end]
beta_chunk = beta_seq[:, :, chunk_start:chunk_end]
g_chunk = g_seq[:, :, chunk_start:chunk_end]
chunk_len = chunk_end - chunk_start
cum_g = g_chunk.cumsum(dim=-1)
exp_cum_g = cum_g.exp()
decay = (cum_g.unsqueeze(-1) - cum_g.unsqueeze(-2)).exp()
interaction = (k_chunk * beta_chunk.unsqueeze(-1)) @ k_chunk.transpose(
-1, -2
)
interaction = torch.tril(interaction * decay, diagonal=-1)
system = interaction + chunk_eye[:chunk_len, :chunk_len]
solved_values = torch.linalg.solve_triangular(
system,
v_chunk * beta_chunk.unsqueeze(-1),
upper=False,
)
solved_keys = torch.linalg.solve_triangular(
system,
(k_chunk * beta_chunk.unsqueeze(-1)) * exp_cum_g.unsqueeze(-1),
upper=False,
)
incoming_memory = torch.einsum("bhvk,bhck->bhcv", seq_state, solved_keys)
transformed_values = solved_values - incoming_memory
# Each chunk contributes both from the incoming recurrent state and
# from its own in-chunk interactions.
inter_chunk = torch.einsum(
"bhvk,bhck->bhcv",
seq_state,
q_chunk * exp_cum_g.unsqueeze(-1),
)
intra_chunk = torch.tril((q_chunk @ k_chunk.transpose(-1, -2)) * decay)
seq_output[:, :, chunk_start:chunk_end] = (
inter_chunk + intra_chunk @ transformed_values
)
# Carry the recurrent state forward to the next chunk boundary.
end_decay = (cum_g[:, :, -1:] - cum_g).exp().unsqueeze(-1)
decayed_keys = k_chunk * end_decay
seq_state = seq_state * exp_cum_g[:, :, -1, None, None] + torch.einsum(
"bhcv,bhck->bhvk", transformed_values, decayed_keys
)
output[0, begin:end].copy_(
seq_output.transpose(1, 2).contiguous().to(initial_dtype).squeeze(0)
)
final_state[seq_idx].copy_(seq_state.squeeze(0).to(state_dtype).contiguous())
return output, final_state