mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[CPU][Perf] Enable fused kernels for GDN's gated delta rules (#43534)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com>
This commit is contained in:
@@ -16,6 +16,7 @@ steps:
|
||||
- tests/kernels/test_onednn.py
|
||||
- tests/kernels/test_awq_int4_to_int8.py
|
||||
- tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
|
||||
- tests/kernels/mamba/cpu/test_cpu_gdn_ops.py
|
||||
commands:
|
||||
- |
|
||||
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 30m "
|
||||
@@ -24,7 +25,8 @@ steps:
|
||||
pytest -x -v -s tests/kernels/moe/test_cpu_quant_fused_moe.py
|
||||
pytest -x -v -s tests/kernels/test_onednn.py
|
||||
pytest -x -v -s tests/kernels/test_awq_int4_to_int8.py
|
||||
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py"
|
||||
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
|
||||
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
|
||||
|
||||
- label: CPU-Compatibility Tests
|
||||
depends_on: []
|
||||
|
||||
@@ -37,7 +37,8 @@ function cpu_tests() {
|
||||
pytest -x -v -s tests/kernels/test_onednn.py
|
||||
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
|
||||
pytest -x -v -s tests/kernels/core/test_cpu_activation.py
|
||||
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic"
|
||||
pytest -x -v -s tests/kernels/moe/test_moe.py -k test_cpu_fused_moe_basic
|
||||
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
|
||||
|
||||
# skip tests requiring model downloads if HF_TOKEN is not set
|
||||
# due to rate-limits
|
||||
|
||||
@@ -369,6 +369,18 @@ else()
|
||||
add_compile_definitions(-DVLLM_NUMA_DISABLED)
|
||||
endif()
|
||||
|
||||
# check if the pytorch wheel ships libopenblas.so.
|
||||
set(VLLM_OPENBLAS_LIB "")
|
||||
if (NOT ENABLE_X86_ISA)
|
||||
file(GLOB _VLLM_TORCH_OPENBLAS_LIBS
|
||||
"${TORCH_INSTALL_PREFIX}/lib/libopenblas*.so*")
|
||||
# Note: we don't link openblas directly to _C extension, as it's available through libtorch.so
|
||||
if (_VLLM_TORCH_OPENBLAS_LIBS)
|
||||
list(GET _VLLM_TORCH_OPENBLAS_LIBS 0 VLLM_OPENBLAS_LIB)
|
||||
message(STATUS "CPU OpenBLAS library: ${VLLM_OPENBLAS_LIB}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Generate CPU attention dispatch header
|
||||
#
|
||||
@@ -387,6 +399,7 @@ endif()
|
||||
#
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/sgl-kernels/fla.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/spec_decode_utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
@@ -418,7 +431,6 @@ endif()
|
||||
|
||||
if (ENABLE_X86_ISA)
|
||||
set(VLLM_EXT_SRC_SGL
|
||||
"csrc/cpu/sgl-kernels/fla.cpp"
|
||||
"csrc/cpu/sgl-kernels/conv.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
@@ -430,6 +442,7 @@ if (ENABLE_X86_ISA)
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
|
||||
|
||||
set(VLLM_EXT_SRC_AVX512
|
||||
"csrc/cpu/sgl-kernels/fla.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
"csrc/cpu/cpu_fused_moe.cpp"
|
||||
@@ -446,6 +459,7 @@ if (ENABLE_X86_ISA)
|
||||
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
|
||||
|
||||
set(VLLM_EXT_SRC_AVX2
|
||||
"csrc/cpu/sgl-kernels/fla.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/spec_decode_utils.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
@@ -519,6 +533,9 @@ else()
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
if (VLLM_OPENBLAS_LIB)
|
||||
target_compile_definitions(_C PRIVATE VLLM_HAS_OPENBLAS)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
// Unlike brgemm, PyTorch does not publicly expose at::native::cpublas::gemm
|
||||
// If OpenBLS is available in the PyTorch wheel, we rely on it for fast
|
||||
// bf16:bf16->fp32 GEMMs Otherwise, we fall back to PyTorch reference BLAS path.
|
||||
#if defined(VLLM_HAS_OPENBLAS)
|
||||
extern "C" void sbgemm_(char* transa, char* transb, int* m, int* n, int* k,
|
||||
float* alpha, const at::BFloat16* a, int* lda,
|
||||
const at::BFloat16* b, int* ldb, float* beta, float* c,
|
||||
int* ldc);
|
||||
|
||||
extern "C" void sgemm_(char* transa, char* transb, int* m, int* n, int* k,
|
||||
float* alpha, const float* a, int* lda, const float* b,
|
||||
int* ldb, float* beta, float* c, int* ldc);
|
||||
|
||||
inline char blas_transpose(at::native::TransposeType trans) {
|
||||
switch (trans) {
|
||||
case at::native::TransposeType::NoTranspose:
|
||||
return 'n';
|
||||
case at::native::TransposeType::Transpose:
|
||||
return 't';
|
||||
case at::native::TransposeType::ConjTranspose:
|
||||
return 'c';
|
||||
}
|
||||
return 'n';
|
||||
}
|
||||
|
||||
inline void blas_gemm(at::native::TransposeType transa,
|
||||
at::native::TransposeType transb, int64_t m, int64_t n,
|
||||
int64_t k, float alpha, const at::BFloat16* a,
|
||||
int64_t lda, const at::BFloat16* b, int64_t ldb,
|
||||
float beta, float* c, int64_t ldc) {
|
||||
char transa_ = blas_transpose(transa);
|
||||
char transb_ = blas_transpose(transb);
|
||||
int m_ = static_cast<int>(m);
|
||||
int n_ = static_cast<int>(n);
|
||||
int k_ = static_cast<int>(k);
|
||||
int lda_ = static_cast<int>(lda);
|
||||
int ldb_ = static_cast<int>(ldb);
|
||||
int ldc_ = static_cast<int>(ldc);
|
||||
sbgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta,
|
||||
c, &ldc_);
|
||||
}
|
||||
|
||||
inline void blas_gemm(at::native::TransposeType transa,
|
||||
at::native::TransposeType transb, int64_t m, int64_t n,
|
||||
int64_t k, float alpha, const float* a, int64_t lda,
|
||||
const float* b, int64_t ldb, float beta, float* c,
|
||||
int64_t ldc) {
|
||||
char transa_ = blas_transpose(transa);
|
||||
char transb_ = blas_transpose(transb);
|
||||
int m_ = static_cast<int>(m);
|
||||
int n_ = static_cast<int>(n);
|
||||
int k_ = static_cast<int>(k);
|
||||
int lda_ = static_cast<int>(lda);
|
||||
int ldb_ = static_cast<int>(ldb);
|
||||
int ldc_ = static_cast<int>(ldc);
|
||||
sgemm_(&transa_, &transb_, &m_, &n_, &k_, &alpha, a, &lda_, b, &ldb_, &beta,
|
||||
c, &ldc_);
|
||||
}
|
||||
|
||||
inline void blas_gemm(at::native::TransposeType, at::native::TransposeType,
|
||||
int64_t, int64_t, int64_t, float, const at::Half*,
|
||||
int64_t, const at::Half*, int64_t, float, float*,
|
||||
int64_t) {
|
||||
TORCH_CHECK(false, "CPU OpenBLAS hgemm is not available.");
|
||||
}
|
||||
#else
|
||||
template <typename scalar_t>
|
||||
inline void blas_gemm(at::native::TransposeType transa,
|
||||
at::native::TransposeType transb, int64_t m, int64_t n,
|
||||
int64_t k, float alpha, const scalar_t* a, int64_t lda,
|
||||
const scalar_t* b, int64_t ldb, float beta, float* c,
|
||||
int64_t ldc) {
|
||||
auto gemm = at::native::cpublas::gemm_no_downcast_stub.DEFAULT;
|
||||
gemm(c10::CppTypeToScalarType<scalar_t>::value, transa, transb, m, n, k,
|
||||
at::Scalar(alpha), a, lda, b, ldb, at::Scalar(beta), c, ldc);
|
||||
}
|
||||
#endif
|
||||
+278
-141
@@ -301,25 +301,42 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
// attn = k_beta @ key.transpose(-1, -2)
|
||||
// attn: [B, HV, num_chunk, chunk_size, chunk_size]
|
||||
// transpose and pack for key
|
||||
pack_vnni<scalar_t>(
|
||||
/* dst */ k_transpose,
|
||||
/* src */ curr_k_pad,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* ld_src */ qk_head_size,
|
||||
/* ld_dst */ chunk_size);
|
||||
// k_beta @ key.transpose(-1, -2)
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ chunk_size,
|
||||
/* ldc */ chunk_size,
|
||||
/* add_C */ false,
|
||||
/* A */ curr_k_beta,
|
||||
/* B */ k_transpose,
|
||||
/* C */ curr_attn);
|
||||
if constexpr (brgemm_supported()) {
|
||||
pack_vnni<scalar_t>(
|
||||
/* dst */ k_transpose,
|
||||
/* src */ curr_k_pad,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* ld_src */ qk_head_size,
|
||||
/* ld_dst */ chunk_size);
|
||||
// k_beta @ key.transpose(-1, -2)
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ chunk_size,
|
||||
/* ldc */ chunk_size,
|
||||
/* add_C */ false,
|
||||
/* A */ curr_k_beta,
|
||||
/* B */ k_transpose,
|
||||
/* C */ curr_attn);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::Transpose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
chunk_size,
|
||||
chunk_size,
|
||||
qk_head_size,
|
||||
1.0f,
|
||||
curr_k_pad,
|
||||
qk_head_size,
|
||||
curr_k_beta,
|
||||
qk_head_size,
|
||||
0.0f,
|
||||
curr_attn,
|
||||
chunk_size);
|
||||
}
|
||||
// attn = attn * decay_mask
|
||||
for (int64_t m = 0; m < chunk_size; m++) {
|
||||
at::vec::map2<float>(
|
||||
@@ -413,25 +430,42 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
// k_beta_g = k_beta * g: [B, HV, num_chunk, chunk_size, EK]
|
||||
// k_cumdecay: [B, HV, num_chunk, chunk_size, EK]
|
||||
// pack for value
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ v_pack,
|
||||
/* src */ curr_v_beta,
|
||||
/* N */ chunk_size,
|
||||
/* K */ v_head_size,
|
||||
/* ld_src */ v_head_size,
|
||||
/* ld_dst */ v_head_size);
|
||||
// value = attn @ v_beta
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ curr_attn_reduced,
|
||||
/* B */ v_pack,
|
||||
/* C */ curr_value);
|
||||
if constexpr (brgemm_supported()) {
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ v_pack,
|
||||
/* src */ curr_v_beta,
|
||||
/* N */ chunk_size,
|
||||
/* K */ v_head_size,
|
||||
/* ld_src */ v_head_size,
|
||||
/* ld_dst */ v_head_size);
|
||||
// value = attn @ v_beta
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ curr_attn_reduced,
|
||||
/* B */ v_pack,
|
||||
/* C */ curr_value);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::NoTranspose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
v_head_size,
|
||||
chunk_size,
|
||||
chunk_size,
|
||||
1.0f,
|
||||
curr_v_beta,
|
||||
v_head_size,
|
||||
curr_attn_reduced,
|
||||
chunk_size,
|
||||
0.0f,
|
||||
curr_value,
|
||||
v_head_size);
|
||||
}
|
||||
// k_beta_g = k_beta * g.exp().unsqueeze(-1)
|
||||
for (int64_t j = 0; j < chunk_size; j++) {
|
||||
int64_t i = 0;
|
||||
@@ -445,25 +479,42 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
}
|
||||
}
|
||||
// pack for k_beta_g
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ k_beta_g_pack,
|
||||
/* src */ k_beta_g,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* ld_src */ qk_head_size,
|
||||
/* ld_dst */ qk_head_size);
|
||||
// k_cumdecay = attn @ k_beta_g
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ qk_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ qk_head_size,
|
||||
/* ldc */ qk_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ curr_attn_reduced,
|
||||
/* B */ k_beta_g_pack,
|
||||
/* C */ k_cumdecay);
|
||||
if constexpr (brgemm_supported()) {
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ k_beta_g_pack,
|
||||
/* src */ k_beta_g,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* ld_src */ qk_head_size,
|
||||
/* ld_dst */ qk_head_size);
|
||||
// k_cumdecay = attn @ k_beta_g
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ qk_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ qk_head_size,
|
||||
/* ldc */ qk_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ curr_attn_reduced,
|
||||
/* B */ k_beta_g_pack,
|
||||
/* C */ k_cumdecay);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::NoTranspose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
qk_head_size,
|
||||
chunk_size,
|
||||
chunk_size,
|
||||
1.0f,
|
||||
k_beta_g,
|
||||
qk_head_size,
|
||||
curr_attn_reduced,
|
||||
chunk_size,
|
||||
0.0f,
|
||||
k_cumdecay,
|
||||
qk_head_size);
|
||||
}
|
||||
for (int i = 0; i < chunk_size; i++) {
|
||||
at::vec::map<scalar_t>(
|
||||
[](fVec x) { return x; },
|
||||
@@ -551,25 +602,42 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
|
||||
// attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
// k_transpose_i = k_i.transpose(-1, -2)
|
||||
pack_vnni<scalar_t>(
|
||||
/* dst */ k_transpose_i,
|
||||
/* src */ k_i,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* ld_src */ qk_head_size,
|
||||
/* ld_dst */ chunk_size);
|
||||
// attn_i = q_i @ k_transpose_i
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ chunk_size,
|
||||
/* ldc */ chunk_size,
|
||||
/* add_C */ false,
|
||||
/* A */ q_i,
|
||||
/* B */ k_transpose_i,
|
||||
/* C */ attn_i);
|
||||
if constexpr (brgemm_supported()) {
|
||||
pack_vnni<scalar_t>(
|
||||
/* dst */ k_transpose_i,
|
||||
/* src */ k_i,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* ld_src */ qk_head_size,
|
||||
/* ld_dst */ chunk_size);
|
||||
// attn_i = q_i @ k_transpose_i
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ chunk_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ chunk_size,
|
||||
/* ldc */ chunk_size,
|
||||
/* add_C */ false,
|
||||
/* A */ q_i,
|
||||
/* B */ k_transpose_i,
|
||||
/* C */ attn_i);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::Transpose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
chunk_size,
|
||||
chunk_size,
|
||||
qk_head_size,
|
||||
1.0f,
|
||||
k_i,
|
||||
qk_head_size,
|
||||
q_i,
|
||||
qk_head_size,
|
||||
0.0f,
|
||||
attn_i,
|
||||
chunk_size);
|
||||
}
|
||||
// attn_i = attn_i * decay_mask_i
|
||||
for (int64_t m = 0; m < chunk_size; m++) {
|
||||
auto attn_i_m = attn_i + m * chunk_size;
|
||||
@@ -609,28 +677,45 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
}
|
||||
|
||||
// pack for curr_last_recurrent_state
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ curr_last_recurrent_state_pack_reduced,
|
||||
/* src */ curr_last_recurrent_state_reduced,
|
||||
/* N */ qk_head_size,
|
||||
/* K */ v_head_size,
|
||||
/* ld_src */ v_head_size,
|
||||
/* ld_dst */ v_head_size);
|
||||
if constexpr (brgemm_supported()) {
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ curr_last_recurrent_state_pack_reduced,
|
||||
/* src */ curr_last_recurrent_state_reduced,
|
||||
/* N */ qk_head_size,
|
||||
/* K */ v_head_size,
|
||||
/* ld_src */ v_head_size,
|
||||
/* ld_dst */ v_head_size);
|
||||
|
||||
// v_prime = k_cumdecay_i @ curr_last_recurrent_state: [chunk_size, EV]
|
||||
// k_cumdecay_i: [chunk_size, EK]
|
||||
// curr_last_recurrent_state: [EK, EV]
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ k_cumdecay_i_reduced,
|
||||
/* B */ curr_last_recurrent_state_pack_reduced,
|
||||
/* C */ v_prime);
|
||||
// v_prime = k_cumdecay_i @ curr_last_recurrent_state: [chunk_size, EV]
|
||||
// k_cumdecay_i: [chunk_size, EK]
|
||||
// curr_last_recurrent_state: [EK, EV]
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ k_cumdecay_i_reduced,
|
||||
/* B */ curr_last_recurrent_state_pack_reduced,
|
||||
/* C */ v_prime);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::NoTranspose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
v_head_size,
|
||||
chunk_size,
|
||||
qk_head_size,
|
||||
1.0f,
|
||||
curr_last_recurrent_state_reduced,
|
||||
v_head_size,
|
||||
k_cumdecay_i_reduced,
|
||||
qk_head_size,
|
||||
0.0f,
|
||||
v_prime,
|
||||
v_head_size);
|
||||
}
|
||||
|
||||
// v_new = v_prime = v_i - v_prime
|
||||
// v_i: [chunk_size, EV]
|
||||
@@ -663,41 +748,75 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
}
|
||||
// attn_inter = qg @ curr_last_recurrent_state: [chunk_size, EV]
|
||||
// curr_last_recurrent_state: [EK, EV]
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ qg,
|
||||
/* B */ curr_last_recurrent_state_pack_reduced,
|
||||
/* C */ attn_inter);
|
||||
if constexpr (brgemm_supported()) {
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ qk_head_size,
|
||||
/* lda */ qk_head_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ qg,
|
||||
/* B */ curr_last_recurrent_state_pack_reduced,
|
||||
/* C */ attn_inter);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::NoTranspose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
v_head_size,
|
||||
chunk_size,
|
||||
qk_head_size,
|
||||
1.0f,
|
||||
curr_last_recurrent_state_reduced,
|
||||
v_head_size,
|
||||
qg,
|
||||
qk_head_size,
|
||||
0.0f,
|
||||
attn_inter,
|
||||
v_head_size);
|
||||
}
|
||||
|
||||
// core_attn_out[:, :, i] = attn_inter + attn_i @ v_new
|
||||
// pack for v_prime
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ v_prime_pack_reduced,
|
||||
/* src */ v_prime_reduced,
|
||||
/* N */ chunk_size,
|
||||
/* K */ v_head_size,
|
||||
/* ld_src */ v_head_size,
|
||||
/* ld_dst */ v_head_size);
|
||||
// attn_inter = attn_inter + attn_i @ v_new: [chunk_size, EV]
|
||||
// attn_i: [chunk_size, chunk_size]
|
||||
// v_new: [chunk_size, EV]
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ true,
|
||||
/* A */ attn_i_reduced,
|
||||
/* B */ v_prime_pack_reduced,
|
||||
/* C */ attn_inter);
|
||||
if constexpr (brgemm_supported()) {
|
||||
pack_vnni2<scalar_t>(
|
||||
/* dst */ v_prime_pack_reduced,
|
||||
/* src */ v_prime_reduced,
|
||||
/* N */ chunk_size,
|
||||
/* K */ v_head_size,
|
||||
/* ld_src */ v_head_size,
|
||||
/* ld_dst */ v_head_size);
|
||||
// attn_inter = attn_inter + attn_i @ v_new: [chunk_size, EV]
|
||||
// attn_i: [chunk_size, chunk_size]
|
||||
// v_new: [chunk_size, EV]
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ chunk_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ true,
|
||||
/* A */ attn_i_reduced,
|
||||
/* B */ v_prime_pack_reduced,
|
||||
/* C */ attn_inter);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::NoTranspose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
v_head_size,
|
||||
chunk_size,
|
||||
chunk_size,
|
||||
1.0f,
|
||||
v_prime_reduced,
|
||||
v_head_size,
|
||||
attn_i_reduced,
|
||||
chunk_size,
|
||||
1.0f,
|
||||
attn_inter,
|
||||
v_head_size);
|
||||
}
|
||||
|
||||
// core_attn_out[:, :, i] = attn_inter
|
||||
for (int64_t m = 0; m < chunk_size; m++) {
|
||||
@@ -762,17 +881,34 @@ void chunk_gated_delta_rule_kernel_impl(
|
||||
/* ld_dst */ chunk_size);
|
||||
// kgv = kg.transpose(-1, -2) @ v_new
|
||||
// v_new: [chunk_size, EV]
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ qk_head_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ kg_transpose,
|
||||
/* B */ v_prime_pack_reduced,
|
||||
/* C */ kgv);
|
||||
if constexpr (brgemm_supported()) {
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ qk_head_size,
|
||||
/* N */ v_head_size,
|
||||
/* K */ chunk_size,
|
||||
/* lda */ chunk_size,
|
||||
/* ldb */ v_head_size,
|
||||
/* ldc */ v_head_size,
|
||||
/* add_C */ false,
|
||||
/* A */ kg_transpose,
|
||||
/* B */ v_prime_pack_reduced,
|
||||
/* C */ kgv);
|
||||
} else {
|
||||
blas_gemm(
|
||||
at::native::TransposeType::NoTranspose,
|
||||
at::native::TransposeType::NoTranspose,
|
||||
v_head_size,
|
||||
qk_head_size,
|
||||
chunk_size,
|
||||
1.0f,
|
||||
v_prime_reduced,
|
||||
v_head_size,
|
||||
kg_transpose,
|
||||
chunk_size,
|
||||
0.0f,
|
||||
kgv,
|
||||
v_head_size);
|
||||
}
|
||||
// last_recurrent_state = 1) + 2)
|
||||
for (int64_t m = 0; m < qk_head_size; m++) {
|
||||
at::vec::map2<float>(
|
||||
@@ -921,7 +1057,8 @@ void fused_sigmoid_gating_delta_rule_update_kernel_impl(
|
||||
float k_scale = use_qk_l2norm_in_kernel ? qk_scale_buf[k_scale_offset] : 1.0f;
|
||||
int64_t v_offset = si * v_strideS + bi * v_strideB + ni * v_strideH;
|
||||
int64_t o_offset = ((bi * seq_len + si) * v_num_heads + ni) * v_head_dim;
|
||||
float beta_val = 1 / (1 + std::exp(-b_ptr[ni]));
|
||||
// See: https://github.com/sgl-project/sglang/pull/26634
|
||||
float beta_val = 1 / (1 + std::exp(-b_ptr[bi * v_num_heads + ni]));
|
||||
fVec beta_vec = fVec(beta_val);
|
||||
int64_t dvi = 0;
|
||||
for (; dvi <= v_head_dim - VecSize; dvi += VecSize) {
|
||||
|
||||
@@ -4,9 +4,12 @@
|
||||
// clang-format off
|
||||
|
||||
#pragma once
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "blas_gemm.h"
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
@@ -21,31 +24,39 @@ constexpr int block_size_n() {
|
||||
return 2 * TILE_N;
|
||||
}
|
||||
|
||||
constexpr bool brgemm_supported() {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T>
|
||||
inline bool can_use_brgemm(int M);
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::BFloat16>(int M) {
|
||||
return M > 4;
|
||||
return brgemm_supported() && M > 4;
|
||||
}
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::Half>(int M) {
|
||||
return true;
|
||||
return brgemm_supported();
|
||||
}
|
||||
// this requires PyTorch 2.7 or above
|
||||
template <>
|
||||
inline bool can_use_brgemm<int8_t>(int M) {
|
||||
return M > 4;
|
||||
return brgemm_supported() && M > 4;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool can_use_brgemm<uint8_t>(int M) {
|
||||
return M > 4;
|
||||
return brgemm_supported() && M > 4;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
|
||||
return M > 4;
|
||||
return brgemm_supported() && M > 4;
|
||||
}
|
||||
|
||||
// work around compiler internal error
|
||||
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
+19
-19
@@ -447,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"bool is_vnni) -> Tensor");
|
||||
ops.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
|
||||
|
||||
// Adapted from sglang: casual_conv1d kernels
|
||||
ops.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor");
|
||||
ops.impl("causal_conv1d_weight_pack", torch::kCPU,
|
||||
&causal_conv1d_weight_pack);
|
||||
ops.def(
|
||||
"causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? "
|
||||
"conv_states, Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, "
|
||||
"int pad_slot_id, bool is_vnni) -> "
|
||||
"Tensor");
|
||||
ops.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu);
|
||||
ops.def(
|
||||
"causal_conv1d_update_cpu(Tensor x, Tensor(a!) conv_states, Tensor "
|
||||
"weight, Tensor? bias, bool silu_activation,"
|
||||
"Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, "
|
||||
"bool is_vnni) -> Tensor");
|
||||
ops.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu);
|
||||
#endif
|
||||
|
||||
// Adapted from sglang: GDN kernels
|
||||
ops.def(
|
||||
"chunk_gated_delta_rule_cpu(Tensor query, Tensor key, Tensor value, "
|
||||
@@ -470,25 +489,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"-> (Tensor, Tensor)");
|
||||
ops.impl("fused_gdn_gating_cpu", torch::kCPU, &fused_gdn_gating_cpu);
|
||||
|
||||
// Adapted from sglang: casual_conv1d kernels
|
||||
ops.def("causal_conv1d_weight_pack(Tensor weight) -> Tensor");
|
||||
ops.impl("causal_conv1d_weight_pack", torch::kCPU,
|
||||
&causal_conv1d_weight_pack);
|
||||
ops.def(
|
||||
"causal_conv1d_fwd_cpu(Tensor x, Tensor weight, Tensor? bias, Tensor? "
|
||||
"conv_states, Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices, Tensor? has_initial_state, bool silu_activation, "
|
||||
"int pad_slot_id, bool is_vnni) -> "
|
||||
"Tensor");
|
||||
ops.impl("causal_conv1d_fwd_cpu", torch::kCPU, &causal_conv1d_fwd_cpu);
|
||||
ops.def(
|
||||
"causal_conv1d_update_cpu(Tensor x, Tensor(a!) conv_states, Tensor "
|
||||
"weight, Tensor? bias, bool silu_activation,"
|
||||
"Tensor? cache_seqlens, Tensor? conv_state_indices, int pad_slot_id, "
|
||||
"bool is_vnni) -> Tensor");
|
||||
ops.impl("causal_conv1d_update_cpu", torch::kCPU, &causal_conv1d_update_cpu);
|
||||
#endif
|
||||
|
||||
// CPU attention kernels
|
||||
ops.def(
|
||||
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||
|
||||
set_random_seed(12345)
|
||||
|
||||
NUM_HEADS = [
|
||||
(2, 4),
|
||||
(4, 4),
|
||||
]
|
||||
HEAD_DIMS = [
|
||||
(32, 32),
|
||||
(64, 32),
|
||||
]
|
||||
CHUNK_SIZE = 64
|
||||
PREFILL_SEQ_LENS = [
|
||||
[1],
|
||||
[1, 2, 3],
|
||||
[CHUNK_SIZE - 1],
|
||||
[CHUNK_SIZE],
|
||||
[CHUNK_SIZE + 1],
|
||||
[CHUNK_SIZE - 1, CHUNK_SIZE, CHUNK_SIZE + 1],
|
||||
[2 * CHUNK_SIZE - 1, 2 * CHUNK_SIZE, 2 * CHUNK_SIZE + 1],
|
||||
[4 * CHUNK_SIZE + 17],
|
||||
]
|
||||
DECODE_BATCH_SIZES = [1, 3, 5]
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128, typed=False)
|
||||
def tensor_cache(
|
||||
elem_num: int,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
tensor = torch.rand(elem_num, dtype=dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def ref_l2norm(
|
||||
x: torch.Tensor,
|
||||
dim: int = -1,
|
||||
eps: float = 1e-5,
|
||||
) -> torch.Tensor:
|
||||
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||
return x * inv_norm
|
||||
|
||||
|
||||
def ref_gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
softplus_x = F.softplus(a.float() + dt_bias.float(), beta=1.0, threshold=20.0)
|
||||
g = -torch.exp(A_log.float()) * softplus_x
|
||||
beta = torch.sigmoid(b.float()).to(dtype=b.dtype)
|
||||
return g, beta
|
||||
|
||||
|
||||
def ref_gated_delta_rule(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
A_log: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g, beta = ref_gdn_gating(A_log, a, b, dt_bias)
|
||||
out = torch.empty_like(value)
|
||||
final_state = torch.empty_like(initial_state)
|
||||
|
||||
for seq_idx in range(cu_seqlens.numel() - 1):
|
||||
begin = int(cu_seqlens[seq_idx].item())
|
||||
end = int(cu_seqlens[seq_idx + 1].item())
|
||||
q_seq = query[:, begin:end]
|
||||
k_seq = key[:, begin:end]
|
||||
v_seq = value[:, begin:end]
|
||||
g_seq = g[begin:end].unsqueeze(0)
|
||||
beta_seq = beta[begin:end].unsqueeze(0)
|
||||
initial_dtype = q_seq.dtype
|
||||
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q_seq = ref_l2norm(q_seq, dim=-1)
|
||||
k_seq = ref_l2norm(k_seq, dim=-1)
|
||||
|
||||
if q_seq.shape[2] != v_seq.shape[2]:
|
||||
repeat_factor = v_seq.shape[2] // q_seq.shape[2]
|
||||
q_seq = q_seq.repeat_interleave(repeat_factor, dim=2)
|
||||
k_seq = k_seq.repeat_interleave(repeat_factor, dim=2)
|
||||
|
||||
q_seq, k_seq, v_seq, beta_seq, g_seq = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (q_seq, k_seq, v_seq, beta_seq, g_seq)
|
||||
]
|
||||
|
||||
batch_size, num_heads, seq_len, head_dim = q_seq.shape
|
||||
v_head_dim = v_seq.shape[-1]
|
||||
q_seq = q_seq * (1 / (head_dim**0.5))
|
||||
out_seq = torch.empty(
|
||||
batch_size,
|
||||
num_heads,
|
||||
seq_len,
|
||||
v_head_dim,
|
||||
dtype=v_seq.dtype,
|
||||
)
|
||||
state = initial_state[seq_idx : seq_idx + 1].to(v_seq)
|
||||
|
||||
for token_idx in range(seq_len):
|
||||
q_t = q_seq[:, :, token_idx]
|
||||
k_t = k_seq[:, :, token_idx]
|
||||
v_t = v_seq[:, :, token_idx]
|
||||
g_t = g_seq[:, :, token_idx].exp().unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = beta_seq[:, :, token_idx].unsqueeze(-1)
|
||||
|
||||
state = state * g_t
|
||||
kv_mem = (state * k_t.unsqueeze(-2)).sum(dim=-1)
|
||||
delta = (v_t - kv_mem) * beta_t
|
||||
state = state + delta.unsqueeze(-1) * k_t.unsqueeze(-2)
|
||||
out_seq[:, :, token_idx] = (state * q_t.unsqueeze(-2)).sum(dim=-1)
|
||||
|
||||
out[:, begin:end] = out_seq.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
final_state[seq_idx] = state.squeeze(0)
|
||||
|
||||
return out, final_state
|
||||
|
||||
|
||||
def gdn_inputs(
|
||||
num_tokens: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_qk_heads, num_v_heads = num_heads
|
||||
head_dim, v_head_dim = head_dims
|
||||
q_shape = (1, num_tokens, num_qk_heads, head_dim)
|
||||
q_numel = num_tokens * num_qk_heads * head_dim
|
||||
q = tensor_cache(q_numel, torch.bfloat16).view(q_shape)
|
||||
k = tensor_cache(q_numel, torch.bfloat16).view(q_shape)
|
||||
|
||||
v_shape = (1, num_tokens, num_v_heads, v_head_dim)
|
||||
v = tensor_cache(num_tokens * num_v_heads * v_head_dim, torch.bfloat16).view(
|
||||
v_shape
|
||||
)
|
||||
|
||||
gate_shape = (num_tokens, num_v_heads)
|
||||
gate_numel = num_tokens * num_v_heads
|
||||
a = tensor_cache(gate_numel, torch.bfloat16).view(gate_shape)
|
||||
b = tensor_cache(gate_numel, torch.bfloat16).view(gate_shape)
|
||||
A_log = tensor_cache(num_v_heads, torch.float32)
|
||||
dt_bias = tensor_cache(num_v_heads, torch.bfloat16)
|
||||
return q, k, v, a, b, A_log, dt_bias
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 9])
|
||||
@pytest.mark.parametrize("num_v_heads", [4, 8])
|
||||
@torch.inference_mode()
|
||||
def test_fused_gdn_gating_cpu(
|
||||
num_tokens: int,
|
||||
num_v_heads: int,
|
||||
) -> None:
|
||||
gate_shape = (num_tokens, num_v_heads)
|
||||
gate_numel = num_tokens * num_v_heads
|
||||
a = tensor_cache(gate_numel, torch.bfloat16).view(gate_shape)
|
||||
b = tensor_cache(gate_numel, torch.bfloat16).view(gate_shape)
|
||||
A_log = tensor_cache(num_v_heads, torch.float32)
|
||||
dt_bias = tensor_cache(num_v_heads, torch.bfloat16)
|
||||
|
||||
g_ref, beta_ref = ref_gdn_gating(A_log, a, b, dt_bias)
|
||||
g, beta = ops.fused_gdn_gating_cpu(A_log, a, b, dt_bias)
|
||||
|
||||
torch.testing.assert_close(g, g_ref.unsqueeze(0), atol=1e-4, rtol=1e-4)
|
||||
torch.testing.assert_close(
|
||||
beta.float(), beta_ref.unsqueeze(0).float(), atol=5e-3, rtol=5e-3
|
||||
)
|
||||
|
||||
|
||||
# decode path
|
||||
@pytest.mark.parametrize("batch_size", DECODE_BATCH_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_dims", HEAD_DIMS)
|
||||
@torch.inference_mode()
|
||||
def test_fused_sigmoid_gating_delta_rule_update_cpu(
|
||||
batch_size: int,
|
||||
num_heads: tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
) -> None:
|
||||
q, k, v, a, b, A_log, dt_bias = gdn_inputs(
|
||||
num_tokens=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dims=head_dims,
|
||||
)
|
||||
_, num_v_heads = num_heads
|
||||
head_dim, v_head_dim = head_dims
|
||||
state_indices = torch.arange(batch_size, dtype=torch.int32)
|
||||
cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32)
|
||||
state_shape = (batch_size, num_v_heads, head_dim, v_head_dim)
|
||||
state = tensor_cache(
|
||||
batch_size * num_v_heads * head_dim * v_head_dim, torch.float32
|
||||
).view(state_shape)
|
||||
state_ref = state[state_indices].transpose(-1, -2).contiguous()
|
||||
|
||||
out_ref, final_state_ref = ref_gated_delta_rule(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
initial_state=state_ref,
|
||||
cu_seqlens=cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
out_ref = out_ref.transpose(0, 1).contiguous()
|
||||
|
||||
state_out = state.clone()
|
||||
out = ops.fused_sigmoid_gating_delta_rule_update_cpu(
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
a=a,
|
||||
b=b,
|
||||
initial_state_source=state_out,
|
||||
initial_state_indices=state_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
state_out[state_indices].transpose(-1, -2),
|
||||
final_state_ref,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
# prefill path
|
||||
@pytest.mark.parametrize("seq_lens", PREFILL_SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_dims", HEAD_DIMS)
|
||||
@torch.inference_mode()
|
||||
def test_chunk_gated_delta_rule_cpu(
|
||||
seq_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
) -> None:
|
||||
total_tokens = sum(seq_lens)
|
||||
q, k, v, a, b, A_log, dt_bias = gdn_inputs(
|
||||
num_tokens=total_tokens,
|
||||
num_heads=num_heads,
|
||||
head_dims=head_dims,
|
||||
)
|
||||
_, num_v_heads = num_heads
|
||||
head_dim, v_head_dim = head_dims
|
||||
cu_seqlens = torch.tensor(
|
||||
[0, *torch.tensor(seq_lens).cumsum(0).tolist()], dtype=torch.int32
|
||||
)
|
||||
initial_state_shape = (len(seq_lens), num_v_heads, head_dim, v_head_dim)
|
||||
initial_state = tensor_cache(
|
||||
len(seq_lens) * num_v_heads * head_dim * v_head_dim, torch.float32
|
||||
).view(initial_state_shape)
|
||||
initial_state_ref = initial_state.transpose(-1, -2).contiguous()
|
||||
|
||||
out_ref, final_state_ref = ref_gated_delta_rule(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
initial_state=initial_state_ref,
|
||||
cu_seqlens=cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
g, beta = ref_gdn_gating(A_log, a, b, dt_bias)
|
||||
out, final_state = ops.chunk_gated_delta_rule_cpu(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
g=g.unsqueeze(0),
|
||||
beta=beta.unsqueeze(0),
|
||||
initial_state=initial_state,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=False,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(
|
||||
final_state.transpose(-1, -2),
|
||||
final_state_ref,
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
)
|
||||
@@ -12,11 +12,6 @@ from vllm.model_executor.layers.mamba.ops.cpu.causal_conv1d import (
|
||||
causal_conv1d_torch,
|
||||
causal_conv1d_update_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.cpu.recurrent_gated_delta_rule import (
|
||||
chunk_gated_delta_rule,
|
||||
gdn_gating,
|
||||
recurrent_gated_delta_rule,
|
||||
)
|
||||
from vllm.utils.torch_utils import (
|
||||
LayerNameType,
|
||||
_resolve_layer_name,
|
||||
@@ -55,88 +50,91 @@ def cpu_gdn_attention_core(
|
||||
attn_metadata_i.spec_sequence_masks is None
|
||||
and attn_metadata_i.num_accepted_tokens is None
|
||||
), "speculative decode not supported in CPU GDN attention."
|
||||
|
||||
if torch.cpu._is_amx_tile_supported():
|
||||
return cpu_gdn_attention_core_amx(
|
||||
mixed_qkv,
|
||||
b,
|
||||
a,
|
||||
core_attn_out,
|
||||
attn_metadata_i,
|
||||
layer,
|
||||
)
|
||||
assert mixed_qkv.dtype == torch.bfloat16, "CPU GDN attention requires BF16."
|
||||
|
||||
state_indices_tensor = attn_metadata_i.non_spec_state_indices_tensor
|
||||
query_start_loc = attn_metadata_i.non_spec_query_start_loc
|
||||
assert state_indices_tensor is not None
|
||||
assert query_start_loc is not None
|
||||
|
||||
# [num_allocated_slots, conv_dim, kernel - 1]
|
||||
is_amx = torch.cpu._is_amx_tile_supported()
|
||||
|
||||
conv_state = layer.kv_cache[0]
|
||||
if not is_conv_state_dim_first():
|
||||
conv_state = conv_state.transpose(-1, -2)
|
||||
if is_amx:
|
||||
# AMX causal conv requires [num_allocated_slots, kernel - 1, conv_dim].
|
||||
if is_conv_state_dim_first():
|
||||
raise RuntimeError("AMX GDN attention requires `SD` conv_state layout.")
|
||||
conv_state = conv_state.transpose(1, 2)
|
||||
else:
|
||||
if not is_conv_state_dim_first():
|
||||
conv_state = conv_state.transpose(-1, -2)
|
||||
conv_weights = layer.conv1d.weight.view(
|
||||
layer.conv1d.weight.size(0), layer.conv1d.weight.size(2)
|
||||
)
|
||||
|
||||
# [num_allocated_slots, num_v_heads / tp_size, v_dim, k_dim]
|
||||
ssm_state = layer.kv_cache[1]
|
||||
mixed_qkv = mixed_qkv.contiguous()
|
||||
a = a.contiguous()
|
||||
b = b.contiguous()
|
||||
|
||||
num_allocated_slots, head_num, v_dim, k_dim = ssm_state.size()
|
||||
ssm_state = ssm_state.view(
|
||||
num_allocated_slots,
|
||||
head_num,
|
||||
k_dim,
|
||||
v_dim,
|
||||
)
|
||||
|
||||
num_decodes = attn_metadata_i.num_decodes
|
||||
num_decode_tokens = attn_metadata_i.num_decode_tokens
|
||||
num_prefills = attn_metadata_i.num_prefills
|
||||
num_prefill_tokens = attn_metadata_i.num_prefill_tokens
|
||||
|
||||
conv_weights = layer.conv1d.weight.view(
|
||||
layer.conv1d.weight.size(0), layer.conv1d.weight.size(2)
|
||||
)
|
||||
|
||||
# all decode requests (batched)
|
||||
if num_decodes > 0:
|
||||
decode_mixed_qkv = mixed_qkv[:num_decode_tokens]
|
||||
decode_b = b[:num_decode_tokens]
|
||||
decode_a = a[:num_decode_tokens]
|
||||
decode_state_indices = state_indices_tensor[:num_decodes]
|
||||
decode_conv_state = conv_state[decode_state_indices].contiguous()
|
||||
if is_amx:
|
||||
decode_mixed_qkv = ops.causal_conv1d_update_cpu(
|
||||
x=decode_mixed_qkv,
|
||||
conv_states=conv_state,
|
||||
weight=layer.conv1d.weight,
|
||||
bias=layer.conv1d.bias,
|
||||
silu_activation=layer.activation == "silu",
|
||||
conv_state_indices=decode_state_indices,
|
||||
is_vnni=True,
|
||||
)
|
||||
else:
|
||||
decode_conv_state = conv_state[decode_state_indices].contiguous()
|
||||
|
||||
decode_mixed_qkv = causal_conv1d_update_torch(
|
||||
# [B, dim] -> [B, dim, 1]
|
||||
x=decode_mixed_qkv.unsqueeze(-1),
|
||||
conv_state=decode_conv_state,
|
||||
weight=conv_weights,
|
||||
bias=layer.conv1d.bias,
|
||||
activation=layer.activation,
|
||||
).squeeze(-1)
|
||||
conv_state[decode_state_indices] = decode_conv_state
|
||||
decode_mixed_qkv = causal_conv1d_update_torch(
|
||||
# [B, dim] -> [B, dim, 1]
|
||||
x=decode_mixed_qkv.unsqueeze(-1),
|
||||
conv_state=decode_conv_state,
|
||||
weight=conv_weights,
|
||||
bias=layer.conv1d.bias,
|
||||
activation=layer.activation,
|
||||
).squeeze(-1)
|
||||
conv_state[decode_state_indices] = decode_conv_state
|
||||
|
||||
query, key, value = layer.rearrange_mixed_qkv(decode_mixed_qkv)
|
||||
|
||||
# [1, L, H, D] -> [B, 1, H, D] for batched decode
|
||||
query = query.transpose(0, 1).contiguous()
|
||||
key = key.transpose(0, 1).contiguous()
|
||||
value = value.transpose(0, 1).contiguous()
|
||||
|
||||
g, beta_output = gdn_gating(
|
||||
attn_out = ops.fused_sigmoid_gating_delta_rule_update_cpu(
|
||||
A_log=layer.A_log,
|
||||
dt_bias=layer.dt_bias,
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
a=decode_a,
|
||||
b=decode_b,
|
||||
dt_bias=layer.dt_bias,
|
||||
)
|
||||
if g.ndim == 2:
|
||||
g = g.unsqueeze(1)
|
||||
beta_output = beta_output.unsqueeze(1)
|
||||
|
||||
initial_state = ssm_state[decode_state_indices].contiguous()
|
||||
attn_out, last_recurrent_state = recurrent_gated_delta_rule(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
g=g,
|
||||
beta=beta_output,
|
||||
initial_state=initial_state,
|
||||
scale=None,
|
||||
initial_state_source=ssm_state,
|
||||
initial_state_indices=decode_state_indices,
|
||||
cu_seqlens=query_start_loc[: num_decodes + 1],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
ssm_state[decode_state_indices] = last_recurrent_state.to(
|
||||
ssm_state.dtype
|
||||
).contiguous()
|
||||
core_attn_out[:num_decode_tokens] = attn_out.squeeze(1)
|
||||
|
||||
# all prefill requests: (varlen) currently naively loops over sequences
|
||||
@@ -160,154 +158,29 @@ def cpu_gdn_attention_core(
|
||||
num_decodes : num_decodes + num_prefills
|
||||
]
|
||||
|
||||
prefill_mixed_qkv = causal_conv1d_torch(
|
||||
x=prefill_mixed_qkv.transpose(0, 1),
|
||||
weight=conv_weights,
|
||||
bias=layer.conv1d.bias,
|
||||
conv_states=conv_state,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
cache_indices=prefill_state_indices,
|
||||
has_initial_state=prefill_has_initial_state,
|
||||
activation=layer.activation,
|
||||
).transpose(0, 1)
|
||||
|
||||
query, key, value = layer.rearrange_mixed_qkv(prefill_mixed_qkv)
|
||||
g, beta = gdn_gating(layer.A_log, prefill_a, prefill_b, layer.dt_bias)
|
||||
if g.ndim == 2:
|
||||
g = g.unsqueeze(0)
|
||||
beta = beta.unsqueeze(0)
|
||||
|
||||
initial_state = ssm_state[prefill_state_indices].contiguous()
|
||||
initial_state[~prefill_has_initial_state, ...] = 0
|
||||
attn_out, last_recurrent_state = chunk_gated_delta_rule(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=None,
|
||||
initial_state=initial_state,
|
||||
cu_seqlens=prefill_query_start_loc,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
ssm_state[prefill_state_indices] = last_recurrent_state.to(ssm_state.dtype)
|
||||
core_attn_out[prefill_token_start:prefill_token_end] = attn_out.squeeze(0)
|
||||
|
||||
|
||||
def cpu_gdn_attention_core_fake(
|
||||
mixed_qkv: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
layer_name: LayerNameType,
|
||||
) -> None:
|
||||
"""Fake implementation for torch.compile."""
|
||||
return
|
||||
|
||||
|
||||
def cpu_gdn_attention_core_amx(
|
||||
mixed_qkv: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
attn_metadata_i: GDNAttentionMetadata,
|
||||
layer: torch.nn.Module,
|
||||
):
|
||||
state_indices_tensor = attn_metadata_i.non_spec_state_indices_tensor
|
||||
query_start_loc = attn_metadata_i.non_spec_query_start_loc
|
||||
assert state_indices_tensor is not None
|
||||
assert query_start_loc is not None
|
||||
|
||||
# [num_allocated_slots, kernel - 1, conv_dim]
|
||||
conv_state = layer.kv_cache[0]
|
||||
if is_conv_state_dim_first():
|
||||
raise RuntimeError("AMX GDN attention requires `SD` conv_state layout.")
|
||||
# reshape to [num_allocated_slots, conv_dim, kernel - 1]
|
||||
conv_state_t = conv_state.transpose(1, 2)
|
||||
|
||||
# [num_allocated_slots, num_v_heads / tp_size, v_dim, k_dim]
|
||||
ssm_state = layer.kv_cache[1]
|
||||
# rehape to [num_allocated_slots, num_v_heads / tp_size, k_dim, v_dim]
|
||||
num_allocated_slots, head_num, v_dim, k_dim = ssm_state.size()
|
||||
ssm_state = ssm_state.view(
|
||||
num_allocated_slots,
|
||||
head_num,
|
||||
k_dim,
|
||||
v_dim,
|
||||
)
|
||||
|
||||
mixed_qkv = mixed_qkv.contiguous()
|
||||
a = a.contiguous()
|
||||
b = b.contiguous()
|
||||
|
||||
num_decodes = attn_metadata_i.num_decodes
|
||||
num_decode_tokens = attn_metadata_i.num_decode_tokens
|
||||
num_prefills = attn_metadata_i.num_prefills
|
||||
num_prefill_tokens = attn_metadata_i.num_prefill_tokens
|
||||
|
||||
if num_decodes > 0:
|
||||
decode_mixed_qkv = mixed_qkv[:num_decode_tokens]
|
||||
decode_b = b[:num_decode_tokens]
|
||||
decode_a = a[:num_decode_tokens]
|
||||
decode_state_indices = state_indices_tensor[:num_decodes]
|
||||
|
||||
decode_mixed_qkv = ops.causal_conv1d_update_cpu(
|
||||
x=decode_mixed_qkv,
|
||||
conv_states=conv_state_t,
|
||||
weight=layer.conv1d.weight,
|
||||
bias=layer.conv1d.bias,
|
||||
silu_activation=layer.activation == "silu",
|
||||
conv_state_indices=decode_state_indices,
|
||||
is_vnni=True,
|
||||
)
|
||||
|
||||
query, key, value = layer.rearrange_mixed_qkv(decode_mixed_qkv)
|
||||
attn_out = ops.fused_sigmoid_gating_delta_rule_update_cpu(
|
||||
A_log=layer.A_log,
|
||||
dt_bias=layer.dt_bias,
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
a=decode_a,
|
||||
b=decode_b,
|
||||
initial_state_source=ssm_state,
|
||||
initial_state_indices=decode_state_indices,
|
||||
cu_seqlens=query_start_loc[: num_decodes + 1],
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
core_attn_out[:num_decode_tokens] = attn_out.squeeze(1)
|
||||
|
||||
if num_prefills > 0:
|
||||
has_initial_state = attn_metadata_i.has_initial_state
|
||||
assert has_initial_state is not None
|
||||
|
||||
prefill_token_start = num_decode_tokens
|
||||
prefill_token_end = prefill_token_start + num_prefill_tokens
|
||||
prefill_mixed_qkv = mixed_qkv[prefill_token_start:prefill_token_end]
|
||||
prefill_b = b[prefill_token_start:prefill_token_end]
|
||||
prefill_a = a[prefill_token_start:prefill_token_end]
|
||||
prefill_state_indices = state_indices_tensor[
|
||||
num_decodes : num_decodes + num_prefills
|
||||
]
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc[num_decodes : num_decodes + num_prefills + 1]
|
||||
- num_decode_tokens
|
||||
)
|
||||
prefill_has_initial_state = has_initial_state[
|
||||
num_decodes : num_decodes + num_prefills
|
||||
]
|
||||
|
||||
prefill_mixed_qkv = ops.causal_conv1d_fwd_cpu(
|
||||
x=prefill_mixed_qkv.transpose(0, 1),
|
||||
weight=layer.conv1d.weight,
|
||||
bias=layer.conv1d.bias,
|
||||
conv_states=conv_state_t,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
cache_indices=prefill_state_indices,
|
||||
has_initial_state=prefill_has_initial_state,
|
||||
silu_activation=layer.activation == "silu",
|
||||
is_vnni=True,
|
||||
).transpose(0, 1)
|
||||
if is_amx:
|
||||
prefill_mixed_qkv = ops.causal_conv1d_fwd_cpu(
|
||||
x=prefill_mixed_qkv.transpose(0, 1),
|
||||
weight=layer.conv1d.weight,
|
||||
bias=layer.conv1d.bias,
|
||||
conv_states=conv_state,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
cache_indices=prefill_state_indices,
|
||||
has_initial_state=prefill_has_initial_state,
|
||||
silu_activation=layer.activation == "silu",
|
||||
is_vnni=True,
|
||||
).transpose(0, 1)
|
||||
else:
|
||||
prefill_mixed_qkv = causal_conv1d_torch(
|
||||
x=prefill_mixed_qkv.transpose(0, 1),
|
||||
weight=conv_weights,
|
||||
bias=layer.conv1d.bias,
|
||||
conv_states=conv_state,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
cache_indices=prefill_state_indices,
|
||||
has_initial_state=prefill_has_initial_state,
|
||||
activation=layer.activation,
|
||||
).transpose(0, 1)
|
||||
|
||||
query, key, value = layer.rearrange_mixed_qkv(prefill_mixed_qkv)
|
||||
g, beta = ops.fused_gdn_gating_cpu(
|
||||
@@ -334,6 +207,17 @@ def cpu_gdn_attention_core_amx(
|
||||
core_attn_out[prefill_token_start:prefill_token_end] = attn_out.squeeze(0)
|
||||
|
||||
|
||||
def cpu_gdn_attention_core_fake(
|
||||
mixed_qkv: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
layer_name: LayerNameType,
|
||||
) -> None:
|
||||
"""Fake implementation for torch.compile."""
|
||||
return
|
||||
|
||||
|
||||
def register_cpu_gdn_attention_ops() -> None:
|
||||
global _CPU_GDN_ATTENTION_OPS_REGISTERED
|
||||
if _CPU_GDN_ATTENTION_OPS_REGISTERED:
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def l2norm(
|
||||
x: torch.Tensor,
|
||||
dim: int = -1,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
|
||||
return x * inv_norm
|
||||
|
||||
|
||||
def recurrent_gated_delta_rule(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = l2norm(query, dim=-1, eps=1e-6)
|
||||
key = l2norm(key, dim=-1, eps=1e-6)
|
||||
|
||||
if query.shape[2] != value.shape[2]:
|
||||
repeat_factor = value.shape[2] // query.shape[2]
|
||||
query = query.repeat_interleave(repeat_factor, dim=2)
|
||||
key = key.repeat_interleave(repeat_factor, dim=2)
|
||||
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, num_heads, sequence_length, _ = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
if scale is None:
|
||||
scale = 1 / (query.shape[-1] ** 0.5)
|
||||
query = query * scale
|
||||
|
||||
core_attn_out = torch.empty(
|
||||
batch_size,
|
||||
num_heads,
|
||||
sequence_length,
|
||||
v_head_dim,
|
||||
dtype=value.dtype,
|
||||
)
|
||||
last_recurrent_state = initial_state.to(value)
|
||||
|
||||
for token_idx in range(sequence_length):
|
||||
q_t = query[:, :, token_idx]
|
||||
k_t = key[:, :, token_idx]
|
||||
v_t = value[:, :, token_idx]
|
||||
g_t = g[:, :, token_idx].exp().unsqueeze(-1).unsqueeze(-1)
|
||||
beta_t = beta[:, :, token_idx].unsqueeze(-1)
|
||||
|
||||
last_recurrent_state = last_recurrent_state * g_t
|
||||
kv_mem = (last_recurrent_state * k_t.unsqueeze(-2)).sum(dim=-1)
|
||||
delta = (v_t - kv_mem) * beta_t
|
||||
last_recurrent_state = last_recurrent_state + delta.unsqueeze(
|
||||
-1
|
||||
) * k_t.unsqueeze(-2)
|
||||
core_attn_out[:, :, token_idx] = (last_recurrent_state * q_t.unsqueeze(-2)).sum(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
|
||||
|
||||
def gdn_gating(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
softplus_x = F.softplus(a.float() + dt_bias.float(), beta=beta, threshold=threshold)
|
||||
g = -torch.exp(A_log.float()) * softplus_x
|
||||
beta_output = torch.sigmoid(b.float()).to(dtype=b.dtype)
|
||||
return g, beta_output
|
||||
|
||||
|
||||
def chunk_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
*,
|
||||
initial_state: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
output = torch.empty_like(v)
|
||||
state_dtype = initial_state.dtype
|
||||
chunk_size = 128
|
||||
sequence_bounds = [
|
||||
(
|
||||
seq_idx,
|
||||
int(cu_seqlens[seq_idx].item()),
|
||||
int(cu_seqlens[seq_idx + 1].item()),
|
||||
)
|
||||
for seq_idx in range(len(cu_seqlens) - 1)
|
||||
]
|
||||
chunk_eye = torch.eye(chunk_size, dtype=torch.float32)
|
||||
num_sequences = len(sequence_bounds)
|
||||
num_value_heads = v.shape[2]
|
||||
value_head_dim = v.shape[3]
|
||||
key_head_dim = k.shape[3]
|
||||
final_state = torch.empty(
|
||||
(num_sequences, num_value_heads, value_head_dim, key_head_dim),
|
||||
dtype=state_dtype,
|
||||
)
|
||||
|
||||
for seq_idx, begin, end in sequence_bounds:
|
||||
q_seq = q[:, begin:end]
|
||||
k_seq = k[:, begin:end]
|
||||
v_seq = v[:, begin:end]
|
||||
g_seq = g[:, begin:end]
|
||||
beta_seq = beta[:, begin:end]
|
||||
|
||||
initial_dtype = q_seq.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q_seq = l2norm(q_seq, dim=-1, eps=1e-6)
|
||||
k_seq = l2norm(k_seq, dim=-1, eps=1e-6)
|
||||
|
||||
num_qk_heads = q_seq.shape[2]
|
||||
num_value_heads = v_seq.shape[2]
|
||||
if num_qk_heads != num_value_heads:
|
||||
repeat_factor = num_value_heads // num_qk_heads
|
||||
q_seq = q_seq.repeat_interleave(repeat_factor, dim=2)
|
||||
k_seq = k_seq.repeat_interleave(repeat_factor, dim=2)
|
||||
|
||||
q_seq, k_seq, v_seq, beta_seq, g_seq = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (q_seq, k_seq, v_seq, beta_seq, g_seq)
|
||||
]
|
||||
seq_batch_size, num_heads, seq_len, qk_head_dim = q_seq.shape
|
||||
value_head_dim = v_seq.shape[-1]
|
||||
|
||||
if scale is None:
|
||||
scale = 1 / (qk_head_dim**0.5)
|
||||
|
||||
q_seq = q_seq * scale
|
||||
|
||||
seq_state = initial_state[seq_idx : seq_idx + 1].to(v_seq)
|
||||
seq_output = torch.empty(
|
||||
seq_batch_size,
|
||||
num_heads,
|
||||
seq_len,
|
||||
value_head_dim,
|
||||
dtype=v_seq.dtype,
|
||||
)
|
||||
|
||||
for chunk_start in range(0, seq_len, chunk_size):
|
||||
chunk_end = min(chunk_start + chunk_size, seq_len)
|
||||
q_chunk = q_seq[:, :, chunk_start:chunk_end]
|
||||
k_chunk = k_seq[:, :, chunk_start:chunk_end]
|
||||
v_chunk = v_seq[:, :, chunk_start:chunk_end]
|
||||
beta_chunk = beta_seq[:, :, chunk_start:chunk_end]
|
||||
g_chunk = g_seq[:, :, chunk_start:chunk_end]
|
||||
chunk_len = chunk_end - chunk_start
|
||||
|
||||
cum_g = g_chunk.cumsum(dim=-1)
|
||||
exp_cum_g = cum_g.exp()
|
||||
decay = (cum_g.unsqueeze(-1) - cum_g.unsqueeze(-2)).exp()
|
||||
|
||||
interaction = (k_chunk * beta_chunk.unsqueeze(-1)) @ k_chunk.transpose(
|
||||
-1, -2
|
||||
)
|
||||
interaction = torch.tril(interaction * decay, diagonal=-1)
|
||||
system = interaction + chunk_eye[:chunk_len, :chunk_len]
|
||||
|
||||
solved_values = torch.linalg.solve_triangular(
|
||||
system,
|
||||
v_chunk * beta_chunk.unsqueeze(-1),
|
||||
upper=False,
|
||||
)
|
||||
solved_keys = torch.linalg.solve_triangular(
|
||||
system,
|
||||
(k_chunk * beta_chunk.unsqueeze(-1)) * exp_cum_g.unsqueeze(-1),
|
||||
upper=False,
|
||||
)
|
||||
|
||||
incoming_memory = torch.einsum("bhvk,bhck->bhcv", seq_state, solved_keys)
|
||||
transformed_values = solved_values - incoming_memory
|
||||
|
||||
# Each chunk contributes both from the incoming recurrent state and
|
||||
# from its own in-chunk interactions.
|
||||
inter_chunk = torch.einsum(
|
||||
"bhvk,bhck->bhcv",
|
||||
seq_state,
|
||||
q_chunk * exp_cum_g.unsqueeze(-1),
|
||||
)
|
||||
intra_chunk = torch.tril((q_chunk @ k_chunk.transpose(-1, -2)) * decay)
|
||||
seq_output[:, :, chunk_start:chunk_end] = (
|
||||
inter_chunk + intra_chunk @ transformed_values
|
||||
)
|
||||
|
||||
# Carry the recurrent state forward to the next chunk boundary.
|
||||
end_decay = (cum_g[:, :, -1:] - cum_g).exp().unsqueeze(-1)
|
||||
decayed_keys = k_chunk * end_decay
|
||||
seq_state = seq_state * exp_cum_g[:, :, -1, None, None] + torch.einsum(
|
||||
"bhcv,bhck->bhvk", transformed_values, decayed_keys
|
||||
)
|
||||
|
||||
output[0, begin:end].copy_(
|
||||
seq_output.transpose(1, 2).contiguous().to(initial_dtype).squeeze(0)
|
||||
)
|
||||
final_state[seq_idx].copy_(seq_state.squeeze(0).to(state_dtype).contiguous())
|
||||
|
||||
return output, final_state
|
||||
Reference in New Issue
Block a user