mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feat] CPU fp8 attn for AMX/AVX-512 (#39445)
Signed-off-by: Li, Tianmu <tianmu.li@intel.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com>
This commit is contained in:
+78
-25
@@ -1,5 +1,16 @@
|
||||
#include "cpu_attn_dispatch_generated.h"
|
||||
|
||||
// Maps kv_cache_dtype string to Fp8KVCacheDataType enum.
|
||||
// "auto" -> kAuto(0); "fp8"/"fp8_e4m3" -> kFp8E4M3; "fp8_e5m2" -> kFp8E5M2.
|
||||
static inline cpu_attention::Fp8KVCacheDataType parse_fp8_kv_dtype(
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "fp8_e5m2")
|
||||
return cpu_attention::Fp8KVCacheDataType::kFp8E5M2;
|
||||
if (kv_cache_dtype == "fp8_e4m3" || kv_cache_dtype == "fp8")
|
||||
return cpu_attention::Fp8KVCacheDataType::kFp8E4M3;
|
||||
return cpu_attention::Fp8KVCacheDataType::kAuto;
|
||||
}
|
||||
|
||||
torch::Tensor get_scheduler_metadata(
|
||||
const int64_t num_req, const int64_t num_heads_q,
|
||||
const int64_t num_heads_kv, const int64_t head_dim,
|
||||
@@ -49,7 +60,7 @@ torch::Tensor get_scheduler_metadata(
|
||||
input.enable_kv_split = enable_kv_split;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
||||
CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
|
||||
CPU_ATTN_DISPATCH(head_dim, isa, 0, [&]() {
|
||||
input.elem_size = sizeof(scalar_t);
|
||||
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
|
||||
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
|
||||
@@ -72,7 +83,9 @@ void cpu_attn_reshape_and_cache(
|
||||
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
||||
const torch::Tensor& slot_mapping, const std::string& isa) {
|
||||
const torch::Tensor& slot_mapping, const std::string& isa,
|
||||
const double k_scale = 1.0, const double v_scale = 1.0,
|
||||
const std::string& kv_cache_dtype = "auto") {
|
||||
TORCH_CHECK_EQ(key.dim(), 3);
|
||||
TORCH_CHECK_EQ(value.dim(), 3);
|
||||
TORCH_CHECK_EQ(key_cache.dim(), 4);
|
||||
@@ -80,18 +93,30 @@ void cpu_attn_reshape_and_cache(
|
||||
TORCH_CHECK_EQ(key.stride(2), 1);
|
||||
TORCH_CHECK_EQ(value.stride(2), 1);
|
||||
|
||||
const int64_t kv_cache_idx =
|
||||
static_cast<int64_t>(parse_fp8_kv_dtype(kv_cache_dtype));
|
||||
const bool is_fp8 = (kv_cache_idx != 0);
|
||||
|
||||
if (is_fp8) {
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte,
|
||||
"key_cache must be uint8 for FP8 path");
|
||||
TORCH_CHECK(value_cache.scalar_type() == at::ScalarType::Byte,
|
||||
"value_cache must be uint8 for FP8 path");
|
||||
TORCH_CHECK(k_scale > 0, "k_scale must be positive for FP8 path");
|
||||
TORCH_CHECK(v_scale > 0, "v_scale must be positive for FP8 path");
|
||||
}
|
||||
|
||||
const float k_inv = is_fp8 ? 1.0f / static_cast<float>(k_scale) : 0.0f;
|
||||
const float v_inv = is_fp8 ? 1.0f / static_cast<float>(v_scale) : 0.0f;
|
||||
|
||||
const int64_t token_num = key.size(0);
|
||||
const int64_t key_token_num_stride = key.stride(0);
|
||||
const int64_t value_token_num_stride = value.stride(0);
|
||||
const int64_t head_num = value.size(1);
|
||||
const int64_t key_head_num_stride = key.stride(1);
|
||||
const int64_t value_head_num_stride = value.stride(1);
|
||||
const int64_t head_num = key.size(1);
|
||||
const int64_t head_dim = key.size(2);
|
||||
const int64_t num_blocks = key_cache.size(0);
|
||||
const int64_t num_blocks_stride = key_cache.stride(0);
|
||||
const int64_t cache_head_num_stride = key_cache.stride(1);
|
||||
const int64_t block_size = key_cache.size(2);
|
||||
const int64_t block_size_stride = key_cache.stride(2);
|
||||
const int64_t head_dim = key.size(-1);
|
||||
|
||||
cpu_attention::ISA isa_tag = [&]() {
|
||||
if (isa == "amx") {
|
||||
@@ -109,16 +134,24 @@ void cpu_attn_reshape_and_cache(
|
||||
}
|
||||
}();
|
||||
|
||||
if (is_fp8) {
|
||||
TORCH_CHECK(isa_tag == cpu_attention::ISA::AMX ||
|
||||
isa_tag == cpu_attention::ISA::VEC,
|
||||
"FP8 KV cache is only supported on x86 (AMX/VEC) ISA");
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
|
||||
CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
|
||||
CPU_ATTN_DISPATCH(head_dim, isa_tag, kv_cache_idx, [&]() {
|
||||
using kv_t = typename attn_impl::kv_cache_t;
|
||||
attn_impl::reshape_and_cache(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
|
||||
value_token_num_stride, head_num, key_head_num_stride,
|
||||
value_head_num_stride, num_blocks, num_blocks_stride,
|
||||
cache_head_num_stride, block_size, block_size_stride);
|
||||
reinterpret_cast<kv_t*>(key_cache.data_ptr()),
|
||||
reinterpret_cast<kv_t*>(value_cache.data_ptr()),
|
||||
slot_mapping.data_ptr<int64_t>(), token_num, key.stride(0),
|
||||
value.stride(0), head_num, key.stride(1), value.stride(1),
|
||||
num_blocks, num_blocks_stride, cache_head_num_stride, block_size,
|
||||
block_size_stride, k_inv, v_inv);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -137,13 +170,26 @@ void cpu_attention_with_kv_cache(
|
||||
const int64_t sliding_window_left, const int64_t sliding_window_right,
|
||||
const torch::Tensor& block_table, // [num_tokens, max_block_num]
|
||||
const double softcap, const torch::Tensor& scheduler_metadata,
|
||||
const std::optional<torch::Tensor>& s_aux // [num_heads]
|
||||
) {
|
||||
const std::optional<torch::Tensor>& s_aux, // [num_heads]
|
||||
const double k_scale = 1.0, const double v_scale = 1.0,
|
||||
const std::string& kv_cache_dtype = "auto") {
|
||||
TORCH_CHECK_EQ(query.dim(), 3);
|
||||
TORCH_CHECK_EQ(query.stride(2), 1);
|
||||
TORCH_CHECK_EQ(key_cache.dim(), 4);
|
||||
TORCH_CHECK_EQ(value_cache.dim(), 4);
|
||||
|
||||
const int64_t kv_cache_idx =
|
||||
static_cast<int64_t>(parse_fp8_kv_dtype(kv_cache_dtype));
|
||||
const bool is_fp8 = (kv_cache_idx != 0);
|
||||
if (is_fp8) {
|
||||
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte,
|
||||
"key_cache must be uint8 for FP8 path");
|
||||
TORCH_CHECK(value_cache.scalar_type() == at::ScalarType::Byte,
|
||||
"value_cache must be uint8 for FP8 path");
|
||||
TORCH_CHECK(k_scale > 0, "k_scale must be positive for FP8 path");
|
||||
TORCH_CHECK(v_scale > 0, "v_scale must be positive for FP8 path");
|
||||
}
|
||||
|
||||
cpu_attention::AttentionInput input;
|
||||
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
|
||||
scheduler_metadata.data_ptr());
|
||||
@@ -165,25 +211,32 @@ void cpu_attention_with_kv_cache(
|
||||
input.block_table = block_table.data_ptr<int32_t>();
|
||||
input.alibi_slopes =
|
||||
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
|
||||
// For now sink must be bf16
|
||||
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
|
||||
input.scale = scale;
|
||||
input.causal = causal;
|
||||
input.sliding_window_left = sliding_window_left;
|
||||
input.sliding_window_right = sliding_window_right;
|
||||
if (input.causal) {
|
||||
// to make boundary calculation easier
|
||||
input.sliding_window_right = 0;
|
||||
}
|
||||
float softcap_fp32 = softcap;
|
||||
input.softcap = softcap_fp32;
|
||||
input.softcap = static_cast<float>(softcap);
|
||||
|
||||
if (is_fp8) {
|
||||
input.k_scale_fp8 = static_cast<float>(k_scale);
|
||||
input.v_scale_fp8 = static_cast<float>(v_scale);
|
||||
TORCH_CHECK(input.metadata->isa == cpu_attention::ISA::AMX ||
|
||||
input.metadata->isa == cpu_attention::ISA::VEC,
|
||||
"FP8 KV cache is only supported on x86 (AMX/VEC) ISA");
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
|
||||
CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
|
||||
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
|
||||
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
|
||||
mainloop(&input);
|
||||
});
|
||||
CPU_ATTN_DISPATCH(
|
||||
query.size(2), input.metadata->isa, kv_cache_idx, [&]() {
|
||||
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment,
|
||||
0);
|
||||
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
|
||||
mainloop(&input);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
+171
-46
@@ -1,6 +1,7 @@
|
||||
#ifndef CPU_ATTN_AMX_HPP
|
||||
#define CPU_ATTN_AMX_HPP
|
||||
|
||||
#include "cpu_attn_fp8.hpp"
|
||||
#include "cpu_attn_impl.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
@@ -21,9 +22,10 @@ typedef struct __tile_config {
|
||||
// 2-2-4 pattern, for 16 < m <= 32
|
||||
// TILE 0, 1: load A matrix, row num should be 16, m - 16
|
||||
// TILE 2, 3: load B matrix, row num should be 16
|
||||
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
|
||||
// - 16
|
||||
template <typename kv_cache_t>
|
||||
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16,
|
||||
// m - 16, m - 16
|
||||
// q_buffer_t: A (Q/P) tile type; kv_cache_t: B (K/V cache) tile type.
|
||||
template <typename q_buffer_t, typename kv_cache_t>
|
||||
class TileGemm224 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
@@ -42,13 +44,56 @@ class TileGemm224 {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm224<c10::BFloat16> {
|
||||
// Dequantize one FP8 tile (AMX_TILE_ROW_NUM rows x 32 cols) to BF16.
|
||||
template <typename kv_cache_t>
|
||||
FORCE_INLINE void deq_tile_amx(const uint8_t* src, c10::BFloat16* dst) {
|
||||
for (int r = 0; r < AMX_TILE_ROW_NUM; ++r) {
|
||||
if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e4m3fn>) {
|
||||
vec_op::BF16Vec32(src + r * 32, vec_op::fp8_bf16_e4m3_tag{})
|
||||
.save(dst + r * 32);
|
||||
} else {
|
||||
vec_op::BF16Vec32(src + r * 32, vec_op::fp8_bf16_e5m2_tag{})
|
||||
.save(dst + r * 32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For FP8: dequant src into scratch and return scratch.
|
||||
// For BF16: return src directly (scratch is unused; the compiler elides it).
|
||||
template <typename kv_cache_t>
|
||||
FORCE_INLINE const c10::BFloat16* prepare_b_tile(const kv_cache_t* src,
|
||||
c10::BFloat16* scratch) {
|
||||
if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2>) {
|
||||
deq_tile_amx<kv_cache_t>(reinterpret_cast<const uint8_t*>(src), scratch);
|
||||
return scratch;
|
||||
} else {
|
||||
return reinterpret_cast<const c10::BFloat16*>(src);
|
||||
}
|
||||
}
|
||||
|
||||
// Handles both BF16 and FP8 KV cache (2-2-4 pattern).
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm224<c10::BFloat16, kv_cache_t> {
|
||||
static_assert(std::is_same_v<kv_cache_t, c10::BFloat16> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2>,
|
||||
"kv_cache_t must be BFloat16, Float8_e4m3fn, or Float8_e5m2");
|
||||
|
||||
static constexpr bool fp8_kv =
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2>;
|
||||
|
||||
static constexpr int64_t tile_elems = AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
// BF16 path: scratch_elems=1 so the scratch array is eliminated by the
|
||||
// compiler.
|
||||
static constexpr int64_t scratch_elems = fp8_kv ? tile_elems : 1;
|
||||
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
c10::BFloat16* __restrict__ a_tile,
|
||||
c10::BFloat16* __restrict__ b_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
@@ -56,6 +101,7 @@ class TileGemm224<c10::BFloat16> {
|
||||
const bool accum_c) {
|
||||
const int32_t k_times =
|
||||
dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
|
||||
|
||||
c10::BFloat16* __restrict__ a_tile_0 = a_tile;
|
||||
c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
|
||||
const int64_t a_tile_stride = [&]() {
|
||||
@@ -70,8 +116,8 @@ class TileGemm224<c10::BFloat16> {
|
||||
}
|
||||
}();
|
||||
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
|
||||
kv_cache_t* __restrict__ b_tile_2 = b_tile;
|
||||
kv_cache_t* __restrict__ b_tile_3 = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// k_cache is prepacked
|
||||
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
|
||||
@@ -106,11 +152,16 @@ class TileGemm224<c10::BFloat16> {
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
alignas(64) c10::BFloat16 scratch_2[scratch_elems];
|
||||
alignas(64) c10::BFloat16 scratch_3[scratch_elems];
|
||||
for (int32_t k = 0; k < k_times; ++k) {
|
||||
const c10::BFloat16* load_2 = prepare_b_tile(b_tile_2, scratch_2);
|
||||
const c10::BFloat16* load_3 = prepare_b_tile(b_tile_3, scratch_3);
|
||||
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_tile_stride);
|
||||
_tile_stream_loadd(2, const_cast<c10::BFloat16*>(load_2), b_tile_stride);
|
||||
_tile_dpbf16ps(4, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_tile_stride);
|
||||
_tile_stream_loadd(3, const_cast<c10::BFloat16*>(load_3), b_tile_stride);
|
||||
_tile_dpbf16ps(5, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_dpbf16ps(6, 1, 2);
|
||||
@@ -154,13 +205,13 @@ class TileGemm224<c10::BFloat16> {
|
||||
};
|
||||
|
||||
// 1-2-2 pattern, for 0 < m <= 16
|
||||
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
|
||||
// m, m
|
||||
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
|
||||
// num should be 16
|
||||
// TILE 6, 7, (6, 7): store results C matrix, row num should be
|
||||
// m
|
||||
template <typename kv_cache_t>
|
||||
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should
|
||||
// be m, m
|
||||
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row num
|
||||
// should be 16
|
||||
// TILE 6, 7: store results C matrix, row num should be m
|
||||
// q_buffer_t: A (Q/P) tile type; kv_cache_t: B (K/V cache) tile type.
|
||||
template <typename q_buffer_t, typename kv_cache_t>
|
||||
class TileGemm122 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
@@ -179,13 +230,26 @@ class TileGemm122 {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TileGemm122<c10::BFloat16> {
|
||||
// Handles both BF16 and FP8 KV cache (1-2-2 pattern).
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm122<c10::BFloat16, kv_cache_t> {
|
||||
static_assert(std::is_same_v<kv_cache_t, c10::BFloat16> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2>,
|
||||
"kv_cache_t must be BFloat16, Float8_e4m3fn, or Float8_e5m2");
|
||||
|
||||
static constexpr bool fp8_kv =
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2>;
|
||||
|
||||
static constexpr int64_t tile_elems = AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
static constexpr int64_t scratch_elems = fp8_kv ? tile_elems : 1;
|
||||
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
c10::BFloat16* __restrict__ a_tile,
|
||||
c10::BFloat16* __restrict__ b_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
@@ -215,21 +279,19 @@ class TileGemm122<c10::BFloat16> {
|
||||
}
|
||||
}();
|
||||
|
||||
c10::BFloat16* __restrict__ b_tile_2 = b_tile;
|
||||
c10::BFloat16* __restrict__ b_tile_3 = [&]() {
|
||||
kv_cache_t* __restrict__ b_tile_2 = b_tile;
|
||||
kv_cache_t* __restrict__ b_tile_3 = [&]() {
|
||||
if constexpr (phase == AttentionGemmPhase::QK) {
|
||||
// k_cache is prepacked
|
||||
return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
|
||||
} else if constexpr (phase == AttentionGemmPhase::PV) {
|
||||
// v_cache is prepacked
|
||||
return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unreachable");
|
||||
}
|
||||
}();
|
||||
c10::BFloat16* __restrict__ b_tile_4 =
|
||||
kv_cache_t* __restrict__ b_tile_4 =
|
||||
b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
c10::BFloat16* __restrict__ b_tile_5 =
|
||||
kv_cache_t* __restrict__ b_tile_5 =
|
||||
b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
|
||||
int64_t b_stride = AMX_TILE_ROW_BYTES;
|
||||
|
||||
@@ -250,16 +312,25 @@ class TileGemm122<c10::BFloat16> {
|
||||
_tile_zero(7);
|
||||
}
|
||||
|
||||
alignas(64) c10::BFloat16 scratch_2[scratch_elems];
|
||||
alignas(64) c10::BFloat16 scratch_3[scratch_elems];
|
||||
alignas(64) c10::BFloat16 scratch_4[scratch_elems];
|
||||
alignas(64) c10::BFloat16 scratch_5[scratch_elems];
|
||||
for (int32_t k = 0; k < k_group_times; ++k) {
|
||||
const c10::BFloat16* load_2 = prepare_b_tile(b_tile_2, scratch_2);
|
||||
const c10::BFloat16* load_3 = prepare_b_tile(b_tile_3, scratch_3);
|
||||
const c10::BFloat16* load_4 = prepare_b_tile(b_tile_4, scratch_4);
|
||||
const c10::BFloat16* load_5 = prepare_b_tile(b_tile_5, scratch_5);
|
||||
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_stream_loadd(2, const_cast<c10::BFloat16*>(load_2), b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_stream_loadd(3, const_cast<c10::BFloat16*>(load_3), b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
_tile_loadd(1, a_tile_1, a_tile_stride);
|
||||
_tile_stream_loadd(4, b_tile_4, b_stride);
|
||||
_tile_stream_loadd(4, const_cast<c10::BFloat16*>(load_4), b_stride);
|
||||
_tile_dpbf16ps(6, 1, 4);
|
||||
_tile_stream_loadd(5, b_tile_5, b_stride);
|
||||
_tile_stream_loadd(5, const_cast<c10::BFloat16*>(load_5), b_stride);
|
||||
_tile_dpbf16ps(7, 1, 5);
|
||||
|
||||
// update ptrs
|
||||
@@ -279,10 +350,13 @@ class TileGemm122<c10::BFloat16> {
|
||||
}
|
||||
|
||||
if (has_tail) {
|
||||
const c10::BFloat16* load_2 = prepare_b_tile(b_tile_2, scratch_2);
|
||||
const c10::BFloat16* load_3 = prepare_b_tile(b_tile_3, scratch_3);
|
||||
|
||||
_tile_loadd(0, a_tile_0, a_tile_stride);
|
||||
_tile_stream_loadd(2, b_tile_2, b_stride);
|
||||
_tile_stream_loadd(2, const_cast<c10::BFloat16*>(load_2), b_stride);
|
||||
_tile_dpbf16ps(6, 0, 2);
|
||||
_tile_stream_loadd(3, b_tile_3, b_stride);
|
||||
_tile_stream_loadd(3, const_cast<c10::BFloat16*>(load_3), b_stride);
|
||||
_tile_dpbf16ps(7, 0, 3);
|
||||
}
|
||||
|
||||
@@ -302,21 +376,25 @@ class TileGemm122<c10::BFloat16> {
|
||||
_tile_loadconfig(&config);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
template <typename scalar_t, int64_t head_dim, typename kv_cache_scalar_t>
|
||||
class AttentionImpl<ISA::AMX, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
static constexpr bool fp8_kv =
|
||||
std::is_same_v<kv_cache_scalar_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_scalar_t, c10::Float8_e5m2>;
|
||||
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = scalar_t;
|
||||
using kv_cache_t = scalar_t;
|
||||
using kv_cache_t = kv_cache_scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = scalar_t;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
AMX_TILE_ROW_BYTES /
|
||||
sizeof(kv_cache_t); // KV token num unit of QK and PV phases
|
||||
32; // AMX_TILE_ROW_NUM = 16 tokens/tile; 32 = 2 tiles
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 32;
|
||||
@@ -324,6 +402,9 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
constexpr static ISA ISAType = ISA::AMX;
|
||||
constexpr static bool scale_on_logits = true;
|
||||
|
||||
float k_scale = 1.0f;
|
||||
float v_scale = 1.0f;
|
||||
|
||||
public:
|
||||
AttentionImpl() : current_q_head_num_(0) {
|
||||
// Use all columns in AMX tiles
|
||||
@@ -332,21 +413,50 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
|
||||
~AttentionImpl() { _tile_release(); }
|
||||
|
||||
void init_from_input(const AttentionInput* input) {
|
||||
if constexpr (fp8_kv) {
|
||||
k_scale = input->k_scale_fp8;
|
||||
v_scale = input->v_scale_fp8;
|
||||
}
|
||||
}
|
||||
|
||||
float get_output_v_scale() const noexcept {
|
||||
if constexpr (fp8_kv) {
|
||||
// AMX dequant places FP8 payload into a BF16 field (exponent bias 127).
|
||||
// Correction = 2^(127 - FP8_bias): E4M3 bias=7 → 2^120, E5M2 bias=15 →
|
||||
// 2^112.
|
||||
constexpr float bias =
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2> ? 0x1p112f : 0x1p120f;
|
||||
return v_scale * bias;
|
||||
}
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
if constexpr (fp8_kv) {
|
||||
// Same bias correction as get_output_v_scale: AMX FP8→BF16 dequant
|
||||
// shifts the exponent bias from FP8 to BF16 (127), so we multiply by
|
||||
// 2^(127-FP8_bias) to recover the true value. E4M3: 2^120, E5M2: 2^112.
|
||||
const float bias =
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2> ? 0x1p112f : 0x1p120f;
|
||||
scale *= k_scale * bias;
|
||||
}
|
||||
if (q_head_num > AMX_TILE_ROW_NUM) {
|
||||
if (q_head_num != current_q_head_num_) {
|
||||
current_q_head_num_ = q_head_num;
|
||||
TileGemm224<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
|
||||
TileGemm224<q_buffer_t, kv_cache_t>::init_tile_config(q_head_num,
|
||||
amx_tile_config_);
|
||||
}
|
||||
attention<TileGemm224<kv_cache_t>> attention_iteration;
|
||||
attention<TileGemm224<q_buffer_t, kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
} else {
|
||||
if (q_head_num != current_q_head_num_) {
|
||||
current_q_head_num_ = q_head_num;
|
||||
TileGemm122<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
|
||||
TileGemm122<q_buffer_t, kv_cache_t>::init_tile_config(q_head_num,
|
||||
amx_tile_config_);
|
||||
}
|
||||
attention<TileGemm122<kv_cache_t>> attention_iteration;
|
||||
attention<TileGemm122<q_buffer_t, kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
}
|
||||
@@ -411,13 +521,26 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
// reshape KV to AMX friendly layout
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
kv_cache_t* __restrict__ key_cache, kv_cache_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
const int64_t block_size, const int64_t block_size_stride,
|
||||
const float k_inv = 0.0f, const float v_inv = 0.0f) {
|
||||
if constexpr (fp8_kv) {
|
||||
constexpr auto qfn = select_fp8_quant_fn<kv_cache_t>();
|
||||
reshape_and_cache_fp8_amx_impl<scalar_t, qfn>(
|
||||
key, value, reinterpret_cast<uint8_t*>(key_cache),
|
||||
reinterpret_cast<uint8_t*>(value_cache), slot_mapping, token_num,
|
||||
head_num, head_dim, block_size, key_token_num_stride,
|
||||
key_head_num_stride, value_token_num_stride, value_head_num_stride,
|
||||
num_blocks_stride, cache_head_num_stride, num_blocks_stride,
|
||||
cache_head_num_stride, k_inv, v_inv);
|
||||
return;
|
||||
}
|
||||
|
||||
// For AMX 2D tiles, size of each line is 64 bytes
|
||||
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
|
||||
// For AMX B matrix, N always is 16
|
||||
@@ -426,6 +549,9 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
// For now suppose block_size is divisible by amx_tile_column_num
|
||||
TORCH_CHECK_EQ(block_size % amx_b_tile_k_size, 0);
|
||||
|
||||
scalar_t* __restrict__ kc = reinterpret_cast<scalar_t*>(key_cache);
|
||||
scalar_t* __restrict__ vc = reinterpret_cast<scalar_t*>(value_cache);
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
@@ -453,8 +579,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
constexpr int64_t quadword_num_per_group =
|
||||
token_num_per_group * quadword_num;
|
||||
int32_t* key_cache_start_ptr =
|
||||
reinterpret_cast<int32_t*>(key_cache +
|
||||
block_idx * num_blocks_stride +
|
||||
reinterpret_cast<int32_t*>(kc + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride) +
|
||||
group_idx * quadword_num_per_group + group_offset;
|
||||
|
||||
@@ -483,7 +608,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* value_cache_start_ptr =
|
||||
value_cache + block_idx * num_blocks_stride +
|
||||
vc + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride +
|
||||
sub_group_idx * token_num_per_sub_group * amx_b_tile_n_size +
|
||||
sub_group_offset;
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
#include "cpu/utils.hpp"
|
||||
|
||||
// May-alias views used for bit-level type punning in the helpers below.
typedef uint32_t __attribute__((__may_alias__)) u32_alias_t;
typedef uint16_t __attribute__((__may_alias__)) u16_alias_t;
typedef float __attribute__((__may_alias__)) f32_alias_t;

// Reference scalar dequant — used to verify vectorized AMX dequant.
//
// Converts one FP8 E4M3 byte to float and multiplies it by `scale`.
//   b     : FP8 E4M3 encoding (sign[7], exponent[6:3], mantissa[2:0]).
//   scale : dequantization scale applied to the decoded value.
// Returns quiet NaN when the low 7 bits are all ones (the E4M3 NaN encoding).
//
// The 7 payload bits are placed at bits [26:20] of an fp32 word, which yields
// the FP8 value with the fp32 exponent bias (127) instead of the E4M3 bias
// (7); multiplying by 2^120 corrects for that.
inline float fp8e4m3_to_float_scalar(uint8_t b, float scale) noexcept {
  // NaN encoding in E4M3
  if ((b & 0x7F) == 0x7F) return std::numeric_limits<float>::quiet_NaN();

  const uint32_t b_u32 = static_cast<uint32_t>(b);
  const uint32_t sign = (b_u32 & 0x80) << 24;
  const uint32_t payload = (b_u32 & 0x7F) << 20;
  const uint32_t bits = sign | payload;
  const float b_f32_unscaled = *reinterpret_cast<const f32_alias_t*>(&bits);

  // Apply the exact power-of-two bias correction BEFORE the scale. The
  // unscaled value sits near the bottom of the fp32 range (about 2^-129 to
  // 2^-112), so scaling it first with scale < 1 would push the intermediate
  // into fp32 subnormals (or to zero) and lose precision.
  const float b_f32_rebiased = b_f32_unscaled * 0x1p120f;
  return b_f32_rebiased * scale;
}
|
||||
|
||||
inline uint8_t float_to_fp8e4m3_scalar(float v, float inv_scale) noexcept {
|
||||
v *= inv_scale;
|
||||
constexpr float fp8_max = 448.0f;
|
||||
v = std::max(-fp8_max, std::min(fp8_max, v));
|
||||
if (v == 0.0f) return 0;
|
||||
|
||||
// Inverse mapping of fp8e4m3_to_float_scalar: shift the effective exponent
|
||||
// bias from fp32 (127) back to fp8 e4m3 (7), then pack sign|payload.
|
||||
float v_f32_unscaled = v * 0x1p-120f;
|
||||
uint32_t bits = *reinterpret_cast<const u32_alias_t*>(&v_f32_unscaled);
|
||||
uint8_t sign = static_cast<uint8_t>((bits >> 24) & 0x80);
|
||||
uint8_t payload = static_cast<uint8_t>((bits >> 20) & 0x7F);
|
||||
if (payload == 0) return sign;
|
||||
payload = std::min<uint8_t>(payload, 0x7E); // keep 0x7F as NaN encoding
|
||||
return static_cast<uint8_t>(sign | payload);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// AMX reshape impl — parameterised on the quantisation function.
|
||||
// Writes key/value into uint8 FP8 KV cache using the AMX tile-friendly layout.
|
||||
// K: halfword-packed (2 FP8 per uint16, token_num_per_group=16).
|
||||
// V: sub-group packing (token_num_per_sub_group=2, head_elems_per_group=16).
|
||||
// block_size must be divisible by 32.
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename scalar_t, uint8_t (*quant_fn)(float, float)>
|
||||
inline void reshape_and_cache_fp8_amx_impl(
|
||||
const scalar_t* key_ptr, const scalar_t* value_ptr, uint8_t* key_cache_ptr,
|
||||
uint8_t* value_cache_ptr, const int64_t* slot_ptr, int64_t token_num,
|
||||
int64_t head_num, int64_t head_dim, int64_t block_size, int64_t k_stride0,
|
||||
int64_t k_stride1, int64_t v_stride0, int64_t v_stride1, int64_t kc_stride0,
|
||||
int64_t kc_stride1, int64_t vc_stride0, int64_t vc_stride1, float k_inv,
|
||||
float v_inv) {
|
||||
constexpr int64_t token_num_per_group = 16; // AMX_TILE_ROW_NUM
|
||||
const int64_t halfword_num = head_dim / 2; // 2 FP8 per uint16
|
||||
const int64_t halfword_num_per_group = token_num_per_group * halfword_num;
|
||||
constexpr int64_t head_elems_per_group = 16;
|
||||
constexpr int64_t token_num_per_sub_group = 2; // = 4 / sizeof(BF16)
|
||||
const int64_t group_num = head_dim / head_elems_per_group;
|
||||
const int64_t group_size = block_size * head_elems_per_group;
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(static)
|
||||
for (int64_t tok = 0; tok < token_num; ++tok) {
|
||||
for (int64_t h = 0; h < head_num; ++h) {
|
||||
const int64_t slot = slot_ptr[tok];
|
||||
if (slot < 0) continue;
|
||||
const int64_t block_idx = slot / block_size;
|
||||
const int64_t block_offset = slot % block_size;
|
||||
|
||||
// Key: halfword-packed, 2 FP8 per uint16
|
||||
{
|
||||
const scalar_t* ksrc = key_ptr + tok * k_stride0 + h * k_stride1;
|
||||
const int64_t group_idx = block_offset / token_num_per_group;
|
||||
const int64_t group_offset = block_offset % token_num_per_group;
|
||||
uint16_t* kdst =
|
||||
reinterpret_cast<uint16_t*>(key_cache_ptr + block_idx * kc_stride0 +
|
||||
h * kc_stride1) +
|
||||
group_idx * halfword_num_per_group + group_offset;
|
||||
for (int64_t j = 0; j < halfword_num; ++j) {
|
||||
uint8_t fp8_0 = quant_fn(static_cast<float>(ksrc[j * 2]), k_inv);
|
||||
uint8_t fp8_1 = quant_fn(static_cast<float>(ksrc[j * 2 + 1]), k_inv);
|
||||
uint8_t bytes[2] = {fp8_0, fp8_1};
|
||||
uint16_t hw = *reinterpret_cast<const u16_alias_t*>(bytes);
|
||||
kdst[j * token_num_per_group] = hw;
|
||||
}
|
||||
}
|
||||
|
||||
// Value: sub-group packing (token_num_per_sub_group = 2)
|
||||
{
|
||||
const scalar_t* vsrc = value_ptr + tok * v_stride0 + h * v_stride1;
|
||||
const int64_t sub_group_idx = block_offset / token_num_per_sub_group;
|
||||
const int64_t sub_group_offset = block_offset % token_num_per_sub_group;
|
||||
uint8_t* vdst =
|
||||
value_cache_ptr + block_idx * vc_stride0 + h * vc_stride1 +
|
||||
sub_group_idx * token_num_per_sub_group * head_elems_per_group +
|
||||
sub_group_offset;
|
||||
for (int64_t i = 0; i < group_num; ++i) {
|
||||
for (int64_t j = 0; j < head_elems_per_group; ++j)
|
||||
vdst[j * token_num_per_sub_group] =
|
||||
quant_fn(static_cast<float>(vsrc[j]), v_inv);
|
||||
vsrc += head_elems_per_group;
|
||||
vdst += group_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FP8 E5M2 scalar helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Reference scalar dequant — used to verify vectorized AMX dequant.
|
||||
// FP8 E5M2: s[7] e[6:2] m[1:0], exponent bias = 15 (same as FP16).
|
||||
// Byte b → FP16 bits = b << 8 (no bias correction needed).
|
||||
inline float fp8e5m2_to_float_scalar(uint8_t b, float scale) noexcept {
|
||||
const uint8_t exp_bits = (b >> 2) & 0x1F;
|
||||
const uint8_t mant_bits = b & 0x03;
|
||||
// NaN: exp=11111, mant!=00
|
||||
if (exp_bits == 0x1F && mant_bits != 0)
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
const uint32_t sign = static_cast<uint32_t>(b & 0x80) << 24;
|
||||
if (exp_bits == 0x1F)
|
||||
return sign ? -std::numeric_limits<float>::infinity()
|
||||
: std::numeric_limits<float>::infinity();
|
||||
if (exp_bits == 0) { // subnormal: (-1)^s * 2^-14 * mant/4
|
||||
if (mant_bits == 0) return 0.0f;
|
||||
float v = mant_bits * 0x1p-16f;
|
||||
return (sign ? -v : v) * scale;
|
||||
}
|
||||
// Normal: FP32 exp = exp5 - 15 + 127, mantissa top 2 bits
|
||||
uint32_t fp32_bits = sign |
|
||||
((static_cast<uint32_t>(exp_bits) - 15 + 127) << 23) |
|
||||
(static_cast<uint32_t>(mant_bits) << 21);
|
||||
float val = *reinterpret_cast<const f32_alias_t*>(&fp32_bits);
|
||||
return val * scale;
|
||||
}
|
||||
|
||||
inline uint8_t float_to_fp8e5m2_scalar(float v, float inv_scale) noexcept {
|
||||
v *= inv_scale;
|
||||
constexpr float fp8_e5m2_max = 57344.0f;
|
||||
v = std::max(-fp8_e5m2_max, std::min(fp8_e5m2_max, v));
|
||||
if (v == 0.0f) return 0;
|
||||
uint32_t bits = *reinterpret_cast<const u32_alias_t*>(&v);
|
||||
const uint8_t sign = static_cast<uint8_t>((bits >> 24) & 0x80);
|
||||
const int32_t exp_fp32 = static_cast<int32_t>((bits >> 23) & 0xFF) - 127;
|
||||
const uint8_t mant2 = static_cast<uint8_t>((bits >> 21) & 0x03);
|
||||
if (exp_fp32 < -14) { // subnormal in E5M2
|
||||
const int shift = -14 - exp_fp32;
|
||||
if (shift + 21 >= 32)
|
||||
return sign; // underflow: too small for E5M2 subnormal
|
||||
const uint32_t m = (0x800000u | (bits & 0x7FFFFFu)) >> (shift + 21);
|
||||
return sign | static_cast<uint8_t>(std::min<uint32_t>(m, 3u));
|
||||
}
|
||||
const uint8_t exp5 = static_cast<uint8_t>(exp_fp32 + 15);
|
||||
return sign | (exp5 << 2) | mant2;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Select the FP8 quant function at compile time based on kv_cache_t.
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename kv_cache_t>
|
||||
constexpr auto select_fp8_quant_fn() {
|
||||
if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e5m2>)
|
||||
return float_to_fp8e5m2_scalar;
|
||||
else
|
||||
return float_to_fp8e4m3_scalar;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// VEC reshape impl — parameterised on the quantisation function.
|
||||
// Writes key (column-major) and value (row-major) into uint8 FP8 KV cache.
|
||||
// The pragma omp must live outside VLLM_DISPATCH_FLOATING_TYPES because
|
||||
// #pragma cannot appear inside variadic macro arguments.
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename scalar_t, uint8_t (*quant_fn)(float, float)>
|
||||
inline void reshape_and_cache_fp8_vec_impl(
|
||||
const scalar_t* key_ptr, const scalar_t* value_ptr, uint8_t* key_cache_ptr,
|
||||
uint8_t* value_cache_ptr, const int64_t* slot_ptr, int64_t token_num,
|
||||
int64_t head_num, int64_t head_dim, int64_t block_size, int64_t k_stride0,
|
||||
int64_t k_stride1, int64_t v_stride0, int64_t v_stride1, int64_t kc_stride0,
|
||||
int64_t kc_stride1, int64_t vc_stride0, int64_t vc_stride1, float k_inv,
|
||||
float v_inv) {
|
||||
#pragma omp parallel for collapse(2) schedule(static)
|
||||
for (int64_t tok = 0; tok < token_num; ++tok) {
|
||||
for (int64_t h = 0; h < head_num; ++h) {
|
||||
const int64_t slot = slot_ptr[tok];
|
||||
if (slot < 0) continue;
|
||||
const int64_t block_idx = slot / block_size;
|
||||
const int64_t block_offset = slot % block_size;
|
||||
|
||||
// Key layout: column-major within block
|
||||
const scalar_t* ksrc = key_ptr + tok * k_stride0 + h * k_stride1;
|
||||
uint8_t* kdst = key_cache_ptr + block_idx * kc_stride0 + h * kc_stride1 +
|
||||
block_offset;
|
||||
for (int64_t i = 0; i < head_dim; ++i)
|
||||
kdst[i * block_size] = quant_fn(static_cast<float>(ksrc[i]), k_inv);
|
||||
|
||||
// Value layout: row-major within block (contiguous head_dim bytes)
|
||||
const scalar_t* vsrc = value_ptr + tok * v_stride0 + h * v_stride1;
|
||||
uint8_t* vdst = value_cache_ptr + block_idx * vc_stride0 +
|
||||
h * vc_stride1 + block_offset * head_dim;
|
||||
for (int64_t i = 0; i < head_dim; ++i)
|
||||
vdst[i] = quant_fn(static_cast<float>(vsrc[i]), v_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,8 +14,22 @@
|
||||
namespace cpu_attention {
|
||||
enum class ISA { AMX, VEC, VEC16, NEON, VXE };
|
||||
|
||||
template <ISA isa, typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl {};
|
||||
// Mirrors csrc/attention/dtype_fp8.cuh Fp8KVCacheDataType exactly.
// The explicit numeric values are part of that contract, so do not renumber
// or reorder the enumerators.
|
||||
enum class Fp8KVCacheDataType {
|
||||
// "auto": the KV cache is not stored as FP8.
kAuto = 0,
|
||||
// KV cache stored as FP8 E4M3.
kFp8E4M3 = 1,
|
||||
// KV cache stored as FP8 E5M2.
kFp8E5M2 = 2,
|
||||
};
|
||||
|
||||
struct AttentionInput;
|
||||
|
||||
// Primary AttentionImpl template; ISA-specific behaviour comes from partial
// specialisations. kv_cache_scalar_t is the element type of the KV cache and
// defaults to scalar_t. The primary template only provides the two FP8 hooks
// as no-ops: init_from_input ignores its argument and get_output_v_scale
// returns 1.0f.
template <ISA isa, typename scalar_t, int64_t head_dim,
|
||||
typename kv_cache_scalar_t = scalar_t>
|
||||
class AttentionImpl {
|
||||
public:
|
||||
// No per-input state to capture in the generic implementation.
void init_from_input(const AttentionInput*) {}
|
||||
// Unit scale: the output needs no extra V-scale correction.
float get_output_v_scale() const noexcept { return 1.0f; }
|
||||
};
|
||||
|
||||
struct AttentionWorkItemGroup {
|
||||
int32_t req_id;
|
||||
@@ -780,6 +794,9 @@ struct AttentionInput {
|
||||
int32_t sliding_window_left;
|
||||
int32_t sliding_window_right;
|
||||
float softcap;
|
||||
// FP8 KV cache scales (used by FP8 attention implementations)
|
||||
float k_scale_fp8 = 1.0f;
|
||||
float v_scale_fp8 = 1.0f;
|
||||
};
|
||||
|
||||
#define DEFINE_CPU_ATTENTION_PARAMS \
|
||||
@@ -1374,6 +1391,13 @@ class AttentionMainLoop {
|
||||
}
|
||||
|
||||
attention_impl_t attn_impl;
|
||||
constexpr bool fp8_kv = std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_t, c10::Float8_e5m2>;
|
||||
float output_v_scale = 1.0f;
|
||||
if constexpr (fp8_kv) {
|
||||
attn_impl.init_from_input(input);
|
||||
output_v_scale = attn_impl.get_output_v_scale();
|
||||
}
|
||||
|
||||
// general information
|
||||
const int32_t q_head_num = input->num_heads;
|
||||
@@ -1753,7 +1777,7 @@ class AttentionMainLoop {
|
||||
reinterpret_cast<query_t*>(input->output) +
|
||||
output_buffer_offset,
|
||||
sum_buffer, actual_q_heads_per_kv,
|
||||
actual_q_token_num, q_head_num);
|
||||
actual_q_token_num, q_head_num, output_v_scale);
|
||||
} else {
|
||||
const int32_t stride =
|
||||
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
|
||||
@@ -1823,7 +1847,7 @@ class AttentionMainLoop {
|
||||
split_output_buffer,
|
||||
reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
|
||||
split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
|
||||
q_head_num);
|
||||
q_head_num, output_v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1947,8 +1971,8 @@ class AttentionMainLoop {
|
||||
query_t* __restrict__ curr_output_buffer,
|
||||
float* __restrict__ sum_buffer,
|
||||
const int32_t q_heads_per_kv,
|
||||
const int32_t actual_q_token_num,
|
||||
const int32_t q_head_num) {
|
||||
const int32_t actual_q_token_num, const int32_t q_head_num,
|
||||
const float v_scale = 1.0f) {
|
||||
// final output
|
||||
using output_vec_t = typename VecTypeTrait<query_t>::vec_t;
|
||||
|
||||
@@ -1962,7 +1986,7 @@ class AttentionMainLoop {
|
||||
curr_partial_output_buffer;
|
||||
query_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
|
||||
for (int32_t head_idx = 0; head_idx < q_heads_per_kv; ++head_idx) {
|
||||
vec_op::FP32Vec16 inv_sum_scale_vec(1.0 / *curr_sum_buffer);
|
||||
vec_op::FP32Vec16 inv_sum_scale_vec(v_scale / *curr_sum_buffer);
|
||||
|
||||
for (int32_t i = 0; i < group_num_per_head; ++i) {
|
||||
vec_op::FP32Vec16 vec(curr_partial_output_buffer_iter);
|
||||
|
||||
@@ -248,8 +248,8 @@ class TileGemmNeonFMLA {
|
||||
} // namespace
|
||||
|
||||
// this is similar to "ISA::VEC" at the moment
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
|
||||
template <typename scalar_t, int64_t head_dim, typename kv_cache_scalar_t>
|
||||
class AttentionImpl<ISA::NEON, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
@@ -343,7 +343,8 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
const int64_t block_size, const int64_t block_size_stride,
|
||||
const float /*k_inv*/ = 0.0f, const float /*v_inv*/ = 0.0f) {
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
@@ -388,7 +389,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
|
||||
template <int64_t head_dim>
|
||||
class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim>
|
||||
class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim, c10::BFloat16>
|
||||
: public AttentionImplNEONBFMMLA<BLOCK_SIZE_ALIGNMENT, ISA::NEON,
|
||||
head_dim> {};
|
||||
#endif
|
||||
|
||||
@@ -602,7 +602,8 @@ class AttentionImplNEONBFMMLA {
|
||||
[[maybe_unused]] const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size,
|
||||
[[maybe_unused]] const int64_t block_size_stride) {
|
||||
[[maybe_unused]] const int64_t block_size_stride,
|
||||
const float /*k_inv*/ = 0.0f, const float /*v_inv*/ = 0.0f) {
|
||||
const int64_t k_block_stride = (head_dim / TILE_K) * K_INNER_STRIDE;
|
||||
const int64_t v_pair_stride =
|
||||
(block_size / V_TOKENS_PER_ROW_BLOCK) * V_INNER_STRIDE;
|
||||
|
||||
+105
-28
@@ -1,11 +1,37 @@
|
||||
#ifndef CPU_ATTN_VEC_HPP
|
||||
#define CPU_ATTN_VEC_HPP
|
||||
|
||||
#include "cpu_attn_fp8.hpp"
|
||||
#include "cpu_attn_impl.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
|
||||
// Load 32 kv_cache_t elements starting at ptr and return them as two FP32Vec16s
|
||||
// covering the lower 16 and upper 16 positions.
|
||||
// For FP8: both halves come from a single BF16Vec32 dequant of 32 bytes.
|
||||
// For BF16/FP16/FP32: two separate vector loads at ptr and ptr+16.
|
||||
template <typename kv_cache_t>
|
||||
FORCE_INLINE std::pair<vec_op::FP32Vec16, vec_op::FP32Vec16> load_b_pair_vec(
|
||||
const kv_cache_t* ptr) {
|
||||
if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e4m3fn>) {
|
||||
// BF16 container, but values are in the FP16 exponent range (bias 15 not
|
||||
// 127).
|
||||
vec_op::BF16Vec32 bf16_b_reg(reinterpret_cast<const uint8_t*>(ptr),
|
||||
vec_op::fp8_e4m3_tag{});
|
||||
return {vec_op::FP32Vec16(bf16_b_reg, 0), vec_op::FP32Vec16(bf16_b_reg, 1)};
|
||||
} else if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e5m2>) {
|
||||
vec_op::BF16Vec32 bf16_b_reg(reinterpret_cast<const uint8_t*>(ptr),
|
||||
vec_op::fp8_e5m2_tag{});
|
||||
return {vec_op::FP32Vec16(bf16_b_reg, 0), vec_op::FP32Vec16(bf16_b_reg, 1)};
|
||||
} else {
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
return {vec_op::FP32Vec16(load_vec_t(ptr)),
|
||||
vec_op::FP32Vec16(load_vec_t(ptr + 16))};
|
||||
}
|
||||
}
|
||||
|
||||
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm82 {
|
||||
@@ -54,10 +80,7 @@ class TileGemm82 {
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M && M <= 8);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
|
||||
float* __restrict__ curr_c_0 = c_tile;
|
||||
float* __restrict__ curr_c_1 = c_tile + 16;
|
||||
|
||||
@@ -76,16 +99,14 @@ class TileGemm82 {
|
||||
}
|
||||
|
||||
float* __restrict__ curr_a = a_tile;
|
||||
kv_cache_t* __restrict__ curr_b = b_tile;
|
||||
|
||||
for (int32_t k = 0; k < dynamic_k_size; ++k) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
load_vec_t b_1_reg(curr_b_1);
|
||||
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
|
||||
auto [fp32_b_0_reg, fp32_b_1_reg] = load_b_pair_vec(curr_b);
|
||||
|
||||
float* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
float v = *curr_m_a;
|
||||
vec_op::FP32Vec16 a_reg(v);
|
||||
vec_op::FP32Vec16 a_reg(*curr_m_a);
|
||||
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
|
||||
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
|
||||
|
||||
@@ -95,8 +116,7 @@ class TileGemm82 {
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += ldb;
|
||||
curr_b_1 += ldb;
|
||||
curr_b += ldb;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
@@ -109,15 +129,20 @@ class TileGemm82 {
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// This is a general but naive implementation based on vector instructions
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
template <typename scalar_t, int64_t head_dim, typename kv_cache_scalar_t>
|
||||
class AttentionImpl<ISA::VEC, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
static constexpr bool fp8_kv =
|
||||
std::is_same_v<kv_cache_scalar_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<kv_cache_scalar_t, c10::Float8_e5m2>;
|
||||
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
using kv_cache_t = scalar_t;
|
||||
using kv_cache_t = kv_cache_scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = float;
|
||||
@@ -129,11 +154,45 @@ class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 8;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::VEC;
|
||||
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
||||
constexpr static bool scale_on_logits = fp8_kv;
|
||||
|
||||
float k_scale = 1.0f;
|
||||
float v_scale = 1.0f;
|
||||
|
||||
public:
|
||||
void init_from_input(const AttentionInput* input) {
|
||||
if constexpr (fp8_kv) {
|
||||
k_scale = input->k_scale_fp8;
|
||||
v_scale = input->v_scale_fp8;
|
||||
}
|
||||
}
|
||||
|
||||
float get_output_v_scale() const noexcept {
|
||||
if constexpr (fp8_kv) {
|
||||
// VEC dequant unpacks FP8 into a pseudo-FP16 layout (exponent bias 15).
|
||||
// E4M3 (bias=7) needs correction 2^(15-7) = 2^8; E5M2 bias matches FP16
|
||||
// so no correction.
|
||||
if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e5m2>) {
|
||||
return v_scale;
|
||||
} else {
|
||||
return v_scale * 0x1p8f;
|
||||
}
|
||||
}
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
if constexpr (fp8_kv) {
|
||||
// Same bias correction as get_output_v_scale: VEC FP8→pseudo-FP16 dequant
|
||||
// uses bias 15; E4M3 (bias=7) needs ×2^8, E5M2 (bias=15) needs no
|
||||
// correction.
|
||||
if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e5m2>) {
|
||||
scale *= k_scale;
|
||||
} else {
|
||||
scale *= k_scale * 0x1p8f;
|
||||
}
|
||||
}
|
||||
attention<TileGemm82<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
@@ -161,17 +220,19 @@ class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
// row-major
|
||||
}
|
||||
|
||||
// Copy q to q_buffer and cast it to fp32
|
||||
static void copy_q_heads_tile(
|
||||
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
|
||||
float* __restrict__ q_buffer, const int32_t q_num,
|
||||
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, float scale) {
|
||||
// Copy q to q_buffer and cast it to fp32.
|
||||
// FP8: QK scale is folded into execute_attention; copy Q unscaled here.
|
||||
void copy_q_heads_tile(scalar_t* __restrict__ src,
|
||||
float* __restrict__ q_buffer, const int32_t q_num,
|
||||
const int32_t q_heads_per_kv,
|
||||
const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, float scale) {
|
||||
static_assert(head_dim % 16 == 0);
|
||||
constexpr int32_t unroll_size = head_dim / 16;
|
||||
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
vec_op::FP32Vec16 scale_vec(scale);
|
||||
const float effective_scale = fp8_kv ? 1.0f : scale;
|
||||
vec_op::FP32Vec16 scale_vec(effective_scale);
|
||||
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
|
||||
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
|
||||
scalar_t* __restrict__ curr_q =
|
||||
@@ -196,13 +257,26 @@ class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
// reshape K as column-major and V as row-major
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
kv_cache_t* __restrict__ key_cache, kv_cache_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
const int64_t block_size, const int64_t block_size_stride,
|
||||
const float k_inv = 0.0f, const float v_inv = 0.0f) {
|
||||
if constexpr (fp8_kv) {
|
||||
constexpr auto qfn = select_fp8_quant_fn<kv_cache_t>();
|
||||
reshape_and_cache_fp8_vec_impl<scalar_t, qfn>(
|
||||
key, value, reinterpret_cast<uint8_t*>(key_cache),
|
||||
reinterpret_cast<uint8_t*>(value_cache), slot_mapping, token_num,
|
||||
head_num, head_dim, block_size, key_token_num_stride,
|
||||
key_head_num_stride, value_token_num_stride, value_head_num_stride,
|
||||
num_blocks_stride, cache_head_num_stride, num_blocks_stride,
|
||||
cache_head_num_stride, k_inv, v_inv);
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
@@ -220,8 +294,9 @@ class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride;
|
||||
scalar_t* key_cache_start_ptr =
|
||||
key_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset;
|
||||
reinterpret_cast<scalar_t*>(key_cache) +
|
||||
block_idx * num_blocks_stride + head_idx * cache_head_num_stride +
|
||||
block_offset;
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
|
||||
@@ -234,8 +309,9 @@ class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* value_cache_start_ptr =
|
||||
value_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset * head_dim;
|
||||
reinterpret_cast<scalar_t*>(value_cache) +
|
||||
block_idx * num_blocks_stride + head_idx * cache_head_num_stride +
|
||||
block_offset * head_dim;
|
||||
std::memcpy(value_cache_start_ptr, value_start_ptr,
|
||||
sizeof(scalar_t) * head_dim);
|
||||
}
|
||||
@@ -243,6 +319,7 @@ class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
|
||||
@@ -116,9 +116,9 @@ class TileGemm161 {
|
||||
} // namespace
|
||||
|
||||
// This is a general but naive implementation based on vector instructions
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
|
||||
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
template <typename scalar_t, int64_t head_dim, typename kv_cache_scalar_t>
|
||||
class AttentionImpl<ISA::VEC16, scalar_t, head_dim, kv_cache_scalar_t>
|
||||
: public AttentionImpl<ISA::VEC, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
|
||||
@@ -244,8 +244,8 @@ class TileGemmS390X {
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VXE, scalar_t, head_dim> {
|
||||
template <typename scalar_t, int64_t head_dim, typename kv_cache_scalar_t>
|
||||
class AttentionImpl<ISA::VXE, scalar_t, head_dim, kv_cache_scalar_t> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
@@ -342,7 +342,8 @@ class AttentionImpl<ISA::VXE, scalar_t, head_dim> {
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
const int64_t block_size, const int64_t block_size_stride,
|
||||
const float /*k_inv*/ = 0.0f, const float /*v_inv*/ = 0.0f) {
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
|
||||
@@ -15,6 +15,9 @@ using namespace at::vec;
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
struct fp8_e4m3_tag {};
|
||||
struct fp8_e5m2_tag {};
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
@@ -322,6 +325,9 @@ struct BF16Vec32 : public VectorizedRegWrapper<BF16Vec32, 4, c10::BFloat16> {
|
||||
reg.val[2] = vec8_data.reg.val[0];
|
||||
reg.val[3] = vec8_data.reg.val[0];
|
||||
};
|
||||
|
||||
// FP8 tag constructors for this backend: the source pointer is ignored and no
// FP8 decoding is performed; the object is just constructed via Base().
// NOTE(review): presumably these exist only so code shared with the x86 FP8
// path still compiles here - confirm they are never reached at runtime.
explicit BF16Vec32(const uint8_t*, fp8_e4m3_tag) : Base() {}
|
||||
explicit BF16Vec32(const uint8_t*, fp8_e5m2_tag) : Base() {}
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public VectorizedRegWrapper<FP32Vec4, 1, float> {
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
#include <torch/all.h>
|
||||
namespace vec_op {
|
||||
|
||||
struct fp8_e4m3_tag {};
|
||||
struct fp8_e5m2_tag {};
|
||||
|
||||
#define vec_neg(a) (-(a))
|
||||
#define vec_add(a, b) ((a) + (b))
|
||||
#define vec_sub(a, b) ((a) - (b))
|
||||
@@ -241,6 +244,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
|
||||
|
||||
// FP8 tag constructors for this backend: the source pointer is ignored and no
// FP8 decoding is performed; `reg` is simply value-initialised.
// NOTE(review): presumably these exist only so code shared with the x86 FP8
// path still compiles here - confirm they are never reached at runtime.
explicit BF16Vec32(const uint8_t*, fp8_e4m3_tag) : reg{} {}
|
||||
explicit BF16Vec32(const uint8_t*, fp8_e5m2_tag) : reg{} {}
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
|
||||
@@ -11,6 +11,17 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
// Tag types selecting the FP8 decoding constructors of BF16Vec32 (tags avoid
|
||||
// an overload collision with BF16Vec32(void*)).
|
||||
// VEC path: lanes hold FP16-layout bits; scale correction is applied later.
|
||||
struct fp8_e4m3_tag {}; // E4M3 -> pseudo-FP16 bits; read as FP16 the lane is true_E4M3 * 2^-8
|
||||
struct fp8_e5m2_tag {}; // E5M2 -> FP16 bits directly (same exponent bias = 15)
|
||||
// AMX path: lanes hold unscaled BF16 bits, produced without an FP32 round-trip.
|
||||
// BF16 value = true_E4M3 * 2^-120 (E4M3) or true_E5M2 * 2^-112 (E5M2).
|
||||
// The exponent rebiasing is folded into the k/v scales by the caller.
|
||||
struct fp8_bf16_e4m3_tag {};
|
||||
struct fp8_bf16_e5m2_tag {};
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
@@ -176,6 +187,50 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
(__m128i)vec8_data.reg, 2),
|
||||
(__m128i)vec8_data.reg, 3)) {}
|
||||
|
||||
// Decode 32 FP8-E4M3 bytes to pseudo-FP16 layout (stored in the BF16
|
||||
// register). Result = true_E4M3 * 2^-8; caller applies scale * 2^8.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_e4m3_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m512i b16 = _mm512_cvtepu8_epi16(b8);
|
||||
__m512i sign =
|
||||
_mm512_slli_epi16(_mm512_and_si512(b16, _mm512_set1_epi16(0x80)), 8);
|
||||
__m512i payload =
|
||||
_mm512_slli_epi16(_mm512_and_si512(b16, _mm512_set1_epi16(0x7F)), 7);
|
||||
reg = _mm512_or_si512(sign, payload);
|
||||
}
|
||||
|
||||
// Decode 32 FP8-E5M2 bytes to FP16 layout.
|
||||
// E5M2 and FP16 share the same 5-bit exponent bias (15), so FP8 byte b maps
|
||||
// directly to FP16 bits by shifting left 8 — no sign/payload reconstruction.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_e5m2_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
reg = _mm512_slli_epi16(_mm512_cvtepu8_epi16(b8), 8);
|
||||
}
|
||||
|
||||
// Direct FP8-E4M3 → unscaled BF16 for AMX (no FP32 round-trip).
|
||||
// BF16 value = true_E4M3 * 2^-120; exponent rebiasing folded into k/v scales.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_bf16_e4m3_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m512i b16 = _mm512_cvtepu8_epi16(b8);
|
||||
__m512i sign =
|
||||
_mm512_slli_epi16(_mm512_and_si512(b16, _mm512_set1_epi16(0x80)), 8);
|
||||
__m512i payload =
|
||||
_mm512_slli_epi16(_mm512_and_si512(b16, _mm512_set1_epi16(0x7F)), 4);
|
||||
reg = _mm512_or_si512(sign, payload);
|
||||
}
|
||||
|
||||
// Direct FP8-E5M2 → unscaled BF16 for AMX (no FP32 round-trip).
|
||||
// BF16 value = true_E5M2 * 2^-112; exponent rebiasing folded into k/v scales.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_bf16_e5m2_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m512i b16 = _mm512_cvtepu8_epi16(b8);
|
||||
__m512i sign =
|
||||
_mm512_slli_epi16(_mm512_and_si512(b16, _mm512_set1_epi16(0x80)), 8);
|
||||
__m512i payload =
|
||||
_mm512_slli_epi16(_mm512_and_si512(b16, _mm512_set1_epi16(0x7F)), 5);
|
||||
reg = _mm512_or_si512(sign, payload);
|
||||
}
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
|
||||
};
|
||||
#else
|
||||
@@ -200,6 +255,77 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
// E4M3 decode (AVX2 path) — same bit-layout trick as the AVX512 variant
|
||||
// above. Result = true_E4M3 * 2^-8; caller applies scale * 2^8.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_e4m3_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m128i b8_low = _mm256_extracti128_si256(b8, 0);
|
||||
__m128i b8_high = _mm256_extracti128_si256(b8, 1);
|
||||
__m256i b16_low = _mm256_cvtepu8_epi16(b8_low);
|
||||
__m256i b16_high = _mm256_cvtepu8_epi16(b8_high);
|
||||
|
||||
__m256i sign_low = _mm256_slli_epi16(
|
||||
_mm256_and_si256(b16_low, _mm256_set1_epi16(0x80)), 8);
|
||||
__m256i payload_low = _mm256_slli_epi16(
|
||||
_mm256_and_si256(b16_low, _mm256_set1_epi16(0x7F)), 7);
|
||||
__m256i sign_high = _mm256_slli_epi16(
|
||||
_mm256_and_si256(b16_high, _mm256_set1_epi16(0x80)), 8);
|
||||
__m256i payload_high = _mm256_slli_epi16(
|
||||
_mm256_and_si256(b16_high, _mm256_set1_epi16(0x7F)), 7);
|
||||
reg_low = _mm256_or_si256(sign_low, payload_low);
|
||||
reg_high = _mm256_or_si256(sign_high, payload_high);
|
||||
}
|
||||
|
||||
// E5M2 decode (AVX2 path) — b << 8 maps to FP16 bits; see AVX512 variant
|
||||
// above.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_e5m2_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m128i b8_low = _mm256_extracti128_si256(b8, 0);
|
||||
__m128i b8_high = _mm256_extracti128_si256(b8, 1);
|
||||
reg_low = _mm256_slli_epi16(_mm256_cvtepu8_epi16(b8_low), 8);
|
||||
reg_high = _mm256_slli_epi16(_mm256_cvtepu8_epi16(b8_high), 8);
|
||||
}
|
||||
|
||||
// Direct FP8-E4M3 → unscaled BF16 for AMX (AVX2 path, no FP32 round-trip).
|
||||
// BF16 value = true_E4M3 * 2^-120; exponent rebiasing folded into k/v scales.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_bf16_e4m3_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m128i b8_low = _mm256_extracti128_si256(b8, 0);
|
||||
__m128i b8_high = _mm256_extracti128_si256(b8, 1);
|
||||
__m256i b16_low = _mm256_cvtepu8_epi16(b8_low);
|
||||
__m256i b16_high = _mm256_cvtepu8_epi16(b8_high);
|
||||
reg_low = _mm256_or_si256(
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_low, _mm256_set1_epi16(0x80)),
|
||||
8),
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_low, _mm256_set1_epi16(0x7F)),
|
||||
4));
|
||||
reg_high = _mm256_or_si256(
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_high, _mm256_set1_epi16(0x80)),
|
||||
8),
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_high, _mm256_set1_epi16(0x7F)),
|
||||
4));
|
||||
}
|
||||
|
||||
// Direct FP8-E5M2 → unscaled BF16 for AMX (AVX2 path, no FP32 round-trip).
|
||||
// BF16 value = true_E5M2 * 2^-112; exponent rebiasing folded into k/v scales.
|
||||
explicit BF16Vec32(const uint8_t* ptr, fp8_bf16_e5m2_tag) {
|
||||
__m256i b8 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
|
||||
__m128i b8_low = _mm256_extracti128_si256(b8, 0);
|
||||
__m128i b8_high = _mm256_extracti128_si256(b8, 1);
|
||||
__m256i b16_low = _mm256_cvtepu8_epi16(b8_low);
|
||||
__m256i b16_high = _mm256_cvtepu8_epi16(b8_high);
|
||||
reg_low = _mm256_or_si256(
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_low, _mm256_set1_epi16(0x80)),
|
||||
8),
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_low, _mm256_set1_epi16(0x7F)),
|
||||
5));
|
||||
reg_high = _mm256_or_si256(
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_high, _mm256_set1_epi16(0x80)),
|
||||
8),
|
||||
_mm256_slli_epi16(_mm256_and_si256(b16_high, _mm256_set1_epi16(0x7F)),
|
||||
5));
|
||||
}
|
||||
|
||||
void save(void* ptr) const {
|
||||
_mm256_storeu_si256((__m256i*)ptr, reg_low);
|
||||
_mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
|
||||
@@ -390,6 +516,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
: reg(_mm512_castsi512_ps(
|
||||
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec32& v, int upper) {
|
||||
__m256i v_half_i = _mm512_extracti32x8_epi32(v.reg, upper);
|
||||
reg = _mm512_cvtph_ps(v_half_i);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
@@ -494,6 +625,14 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg_low(data.reg), reg_high(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec32& v, int upper) {
|
||||
const __m256i& half = upper ? v.reg_high : v.reg_low;
|
||||
__m128i lo = _mm256_extractf128_si256(half, 0);
|
||||
__m128i hi = _mm256_extractf128_si256(half, 1);
|
||||
reg_low = _mm256_cvtph_ps(lo);
|
||||
reg_high = _mm256_cvtph_ps(hi);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@@ -22,71 +22,95 @@ ISA_TYPES = {
|
||||
"VXE": 4,
|
||||
}
|
||||
|
||||
# KV cache index: 0 = auto (same as scalar_t), 1 = fp8_e4m3, 2 = fp8_e5m2
|
||||
KV_CACHE_IDX = {
|
||||
"auto": 0,
|
||||
"fp8_e4m3": 1,
|
||||
"fp8_e5m2": 2,
|
||||
}
|
||||
|
||||
# C++ type for each kv_cache index
|
||||
KV_CACHE_CPP_TYPES = {
|
||||
"auto": "scalar_t",
|
||||
"fp8_e4m3": "c10::Float8_e4m3fn",
|
||||
"fp8_e5m2": "c10::Float8_e5m2",
|
||||
}
|
||||
|
||||
# ISAs supported for head_dims divisible by 32
|
||||
ISA_FOR_32 = ["AMX", "NEON", "VEC", "VEC16", "VXE"]
|
||||
|
||||
# ISAs supported for head_dims divisible by 16 only
|
||||
ISA_FOR_16 = ["VEC16"]
|
||||
|
||||
# ISAs eligible for FP8 KV cache cases (only emitted for the x86 AMX and AVX-512 macro blocks, which pass fp8=True)
|
||||
ISA_FOR_FP8 = ["AMX", "VEC"]
|
||||
|
||||
def encode_params(head_dim: int, isa_type: str) -> int:
|
||||
"""Encode head_dim and ISA type into a single int64_t."""
|
||||
|
||||
def encode_params(head_dim: int, isa_type: str, kv_cache: str = "auto") -> int:
|
||||
"""Encode head_dim, ISA type, and KV cache type into a single int64_t."""
|
||||
isa_val = ISA_TYPES[isa_type]
|
||||
# Encoding: (head_dim << 8) | isa_type
|
||||
# This allows head_dim up to 2^56 - 1 and 256 ISA types
|
||||
return (head_dim << 8) | isa_val
|
||||
kv_val = KV_CACHE_IDX[kv_cache]
|
||||
# Encoding: (head_dim << 16) | (kv_cache_idx << 8) | isa_type
|
||||
# This allows head_dim up to 2^48 - 1, 256 KV cache types, and 256 ISA types
|
||||
return (head_dim << 16) | (kv_val << 8) | isa_val
|
||||
|
||||
|
||||
def generate_cases_for_isa_group(isa_list: list[str]) -> str:
|
||||
def _make_case(
|
||||
head_dim: int, isa: str, kv_cache: str = "auto", isa_override: str | None = None
|
||||
) -> str:
|
||||
"""Generate a single switch case line."""
|
||||
encoded = encode_params(head_dim, isa, kv_cache)
|
||||
actual_isa = isa_override if isa_override else isa
|
||||
cpp_type = KV_CACHE_CPP_TYPES[kv_cache]
|
||||
attn_impl = (
|
||||
f"cpu_attention::AttentionImpl<"
|
||||
f"cpu_attention::ISA::{actual_isa}, \\\n"
|
||||
f" "
|
||||
f"scalar_t, head_dim, {cpp_type}>"
|
||||
)
|
||||
comment = (
|
||||
f"head_dim={head_dim}, isa={isa}"
|
||||
if kv_cache == "auto"
|
||||
else f"head_dim={head_dim}, isa={isa}, kv_cache={kv_cache}"
|
||||
)
|
||||
return (
|
||||
f""" case {encoded}LL: {{ """
|
||||
f"""/* {comment} */ \\"""
|
||||
f"""
|
||||
constexpr size_t head_dim = {head_dim}; \\"""
|
||||
f"""
|
||||
using attn_impl = {attn_impl}; \\"""
|
||||
f"""
|
||||
return __VA_ARGS__(); \\"""
|
||||
f"""
|
||||
}} \\"""
|
||||
)
|
||||
|
||||
|
||||
def generate_cases_for_isa_group(isa_list: list[str], include_fp8: bool = False) -> str:
|
||||
"""Generate switch cases for a specific ISA group."""
|
||||
cases = []
|
||||
|
||||
# Generate cases for head_dims divisible by 32
|
||||
# Non-FP8 cases for head_dims divisible by 32
|
||||
for head_dim in HEAD_DIMS_32:
|
||||
for isa in isa_list:
|
||||
if isa not in ISA_FOR_32:
|
||||
continue
|
||||
encoded = encode_params(head_dim, isa)
|
||||
case_str = (
|
||||
f""" case {encoded}LL: {{ """
|
||||
f"""/* head_dim={head_dim}, isa={isa} */ \\"""
|
||||
f"""
|
||||
constexpr size_t head_dim = {head_dim}; \\"""
|
||||
f"""
|
||||
using attn_impl = cpu_attention::AttentionImpl<"""
|
||||
f"""cpu_attention::ISA::{isa}, \\"""
|
||||
f"""
|
||||
"""
|
||||
f"""scalar_t, head_dim>; \\"""
|
||||
f"""
|
||||
return __VA_ARGS__(); \\"""
|
||||
f"""
|
||||
}} \\"""
|
||||
)
|
||||
cases.append(case_str)
|
||||
cases.append(_make_case(head_dim, isa, "auto"))
|
||||
|
||||
# Generate cases for head_dims divisible by 16 only
|
||||
# Non-FP8 cases for head_dims divisible by 16 only
|
||||
for head_dim in HEAD_DIMS_16:
|
||||
for isa in isa_list:
|
||||
encoded = encode_params(head_dim, isa)
|
||||
case_str = (
|
||||
f""" case {encoded}LL: {{ """
|
||||
f"""/* head_dim={head_dim}, isa={isa} """
|
||||
f"""(using VEC16) */ \\"""
|
||||
f"""
|
||||
constexpr size_t head_dim = {head_dim}; \\"""
|
||||
f"""
|
||||
using attn_impl = cpu_attention::AttentionImpl<"""
|
||||
f"""cpu_attention::ISA::VEC16, \\"""
|
||||
f"""
|
||||
"""
|
||||
f"""scalar_t, head_dim>; \\"""
|
||||
f"""
|
||||
return __VA_ARGS__(); \\"""
|
||||
f"""
|
||||
}} \\"""
|
||||
)
|
||||
cases.append(case_str)
|
||||
cases.append(_make_case(head_dim, isa, "auto", isa_override="VEC16"))
|
||||
|
||||
# FP8 cases: only AMX and VEC, only head_dims divisible by 32
|
||||
if include_fp8:
|
||||
for fp8_type in ("fp8_e4m3", "fp8_e5m2"):
|
||||
for head_dim in HEAD_DIMS_32:
|
||||
for isa in isa_list:
|
||||
if isa not in ISA_FOR_FP8:
|
||||
continue
|
||||
cases.append(_make_case(head_dim, isa, fp8_type))
|
||||
|
||||
return "\n".join(cases)
|
||||
|
||||
@@ -94,8 +118,9 @@ def generate_cases_for_isa_group(isa_list: list[str]) -> str:
|
||||
def generate_helper_function() -> str:
|
||||
"""Generate helper function to encode parameters."""
|
||||
return """
|
||||
inline int64_t encode_cpu_attn_params(int64_t head_dim, cpu_attention::ISA isa) {
|
||||
return (head_dim << 8) | static_cast<int64_t>(isa);
|
||||
inline int64_t encode_cpu_attn_params(int64_t head_dim, cpu_attention::ISA isa,
|
||||
int64_t kv_cache_idx = 0) {
|
||||
return (head_dim << 16) | (kv_cache_idx << 8) | static_cast<int64_t>(isa);
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -129,87 +154,78 @@ def generate_header_file() -> str:
|
||||
|
||||
# Generate dispatch macro with conditional compilation for different ISA sets
|
||||
header += """
|
||||
// Dispatch macro using encoded parameters
|
||||
// Dispatch macro using encoded parameters.
|
||||
// KV_CACHE_IDX: Fp8KVCacheDataType enum value (kAuto=0, kFp8E4M3=1, kFp8E5M2=2).
|
||||
// FP8 cases (kv_cache_idx != 0) are generated on x86 platforms with AVX2 or
|
||||
// AVX-512: BF16Vec32 FP8 constructors have both AVX-512 and AVX2 implementations
|
||||
// in cpu_types_x86.hpp. Non-x86 platforms (#else fallback) have fp8=False.
|
||||
"""
|
||||
|
||||
# x86_64 with AMX
|
||||
header += """#if defined(CPU_CAPABILITY_AMXBF16)
|
||||
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
|
||||
[&] { \\
|
||||
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
|
||||
switch (encoded_params) { \\
|
||||
"""
|
||||
header += generate_cases_for_isa_group(["AMX", "VEC", "VEC16"])
|
||||
header += """
|
||||
default: { \\
|
||||
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
|
||||
std::to_string(HEAD_DIM) + " isa=" + \\
|
||||
std::to_string(static_cast<int>(ISA_TYPE))); \\
|
||||
} \\
|
||||
} \\
|
||||
}()
|
||||
def _macro_block(guard: str, isa_list: list[str], fp8: bool) -> str:
|
||||
"""Return one CPU_ATTN_DISPATCH macro block for a given guard."""
|
||||
enc = (
|
||||
" int64_t encoded_params = encode_cpu_attn_params("
|
||||
"HEAD_DIM, ISA_TYPE, KV_CACHE_IDX); \\"
|
||||
)
|
||||
cases = generate_cases_for_isa_group(isa_list, include_fp8=fp8)
|
||||
tail = (
|
||||
"\n"
|
||||
" default: { \\\n"
|
||||
" TORCH_CHECK(false, "
|
||||
'"Unsupported CPU attention configuration: head_dim=" + \\\n'
|
||||
' std::to_string(HEAD_DIM) + " isa=" + \\\n'
|
||||
" std::to_string(static_cast<int>(ISA_TYPE))"
|
||||
" + \\\n"
|
||||
' " kv_cache_idx=" + '
|
||||
"std::to_string(KV_CACHE_IDX)); \\\n"
|
||||
" } \\\n"
|
||||
" } \\\n"
|
||||
" }()\n\n"
|
||||
)
|
||||
return (
|
||||
f"{guard}\n"
|
||||
"#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, KV_CACHE_IDX, ...) \\\n"
|
||||
" [&] { \\\n"
|
||||
f"{enc}\n"
|
||||
" switch (encoded_params) { \\\n"
|
||||
f"{cases}"
|
||||
f"{tail}"
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
# ARM64 with NEON
|
||||
header += """#elif defined(__aarch64__)
|
||||
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
|
||||
[&] { \\
|
||||
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
|
||||
switch (encoded_params) { \\
|
||||
"""
|
||||
header += generate_cases_for_isa_group(["NEON", "VEC", "VEC16"])
|
||||
header += """
|
||||
default: { \\
|
||||
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
|
||||
std::to_string(HEAD_DIM) + " isa=" + \\
|
||||
std::to_string(static_cast<int>(ISA_TYPE))); \\
|
||||
} \\
|
||||
} \\
|
||||
}()
|
||||
|
||||
"""
|
||||
|
||||
# s390x with VXE
|
||||
header += """#elif defined(__s390x__)
|
||||
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
|
||||
[&] { \\
|
||||
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
|
||||
switch (encoded_params) { \\
|
||||
"""
|
||||
header += generate_cases_for_isa_group(["VXE", "VEC", "VEC16"])
|
||||
header += """
|
||||
default: { \\
|
||||
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
|
||||
std::to_string(HEAD_DIM) + " isa=" + \\
|
||||
std::to_string(static_cast<int>(ISA_TYPE))); \\
|
||||
} \\
|
||||
} \\
|
||||
}()
|
||||
|
||||
"""
|
||||
|
||||
# Fallback: VEC and VEC16 only
|
||||
header += """#else
|
||||
#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
|
||||
[&] { \\
|
||||
int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
|
||||
switch (encoded_params) { \\
|
||||
"""
|
||||
header += generate_cases_for_isa_group(["VEC", "VEC16"])
|
||||
header += """
|
||||
default: { \\
|
||||
TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
|
||||
std::to_string(HEAD_DIM) + " isa=" + \\
|
||||
std::to_string(static_cast<int>(ISA_TYPE))); \\
|
||||
} \\
|
||||
} \\
|
||||
}()
|
||||
|
||||
#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */
|
||||
|
||||
#endif // CPU_ATTN_DISPATCH_GENERATED_H
|
||||
"""
|
||||
header += _macro_block(
|
||||
"#if defined(CPU_CAPABILITY_AMXBF16)",
|
||||
["AMX", "VEC", "VEC16"],
|
||||
fp8=True,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#elif defined(__aarch64__)",
|
||||
["NEON", "VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#elif defined(__s390x__)",
|
||||
["VXE", "VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#elif defined(__AVX512F__)",
|
||||
["VEC", "VEC16"],
|
||||
fp8=True,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#elif defined(__AVX2__)",
|
||||
["VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
header += _macro_block(
|
||||
"#else",
|
||||
["VEC", "VEC16"],
|
||||
fp8=False,
|
||||
)
|
||||
header += (
|
||||
"#endif /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */\n\n"
|
||||
"#endif // CPU_ATTN_DISPATCH_GENERATED_H\n"
|
||||
)
|
||||
|
||||
return header
|
||||
|
||||
|
||||
@@ -101,7 +101,9 @@ void cpu_attn_reshape_and_cache(const torch::Tensor& key,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
const torch::Tensor& slot_mapping,
|
||||
const std::string& isa);
|
||||
const std::string& isa, const double k_scale,
|
||||
const double v_scale,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void cpu_attention_with_kv_cache(
|
||||
const torch::Tensor& query, const torch::Tensor& key_cache,
|
||||
@@ -112,7 +114,8 @@ void cpu_attention_with_kv_cache(
|
||||
const int64_t sliding_window_left, const int64_t sliding_window_right,
|
||||
const torch::Tensor& block_table, const double softcap,
|
||||
const torch::Tensor& scheduler_metadata,
|
||||
const std::optional<torch::Tensor>& s_aux);
|
||||
const std::optional<torch::Tensor>& s_aux, const double k_scale,
|
||||
const double v_scale, const std::string& kv_cache_dtype);
|
||||
|
||||
// Note: just for avoiding importing errors
|
||||
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }
|
||||
@@ -384,15 +387,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
&get_scheduler_metadata);
|
||||
ops.def(
|
||||
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
|
||||
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
|
||||
"isa) -> ()",
|
||||
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str isa, "
|
||||
"float k_scale=1.0, float v_scale=1.0, str kv_cache_dtype=\"auto\") -> "
|
||||
"()",
|
||||
&cpu_attn_reshape_and_cache);
|
||||
ops.def(
|
||||
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
|
||||
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
|
||||
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
|
||||
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
|
||||
"float softcap, Tensor scheduler_metadata, Tensor? s_aux) -> ()",
|
||||
"float softcap, Tensor scheduler_metadata, Tensor? s_aux, "
|
||||
"float k_scale=1.0, float v_scale=1.0, str kv_cache_dtype=\"auto\") -> "
|
||||
"()",
|
||||
&cpu_attention_with_kv_cache);
|
||||
|
||||
// placeholders
|
||||
|
||||
@@ -167,7 +167,7 @@ Priority is **1 = highest** (tried first).
|
||||
|
||||
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
|
||||
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
|
||||
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
|
||||
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
|
||||
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
|
||||
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
|
||||
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
|
||||
|
||||
@@ -20,6 +20,12 @@ from vllm._custom_ops import (
|
||||
cpu_attn_reshape_and_cache,
|
||||
)
|
||||
|
||||
# Enable AMX tile data registers so isolated runs (e.g. -k fp8_amx) don't rely
|
||||
# on ref_paged_attn's einsum to trigger oneDNN's _init_amx() first.
|
||||
if torch.cpu._is_amx_tile_supported():
|
||||
torch.cpu._init_amx()
|
||||
|
||||
|
||||
NUM_HEADS = [
|
||||
(4, 4),
|
||||
(8, 2),
|
||||
@@ -178,6 +184,10 @@ def ref_paged_attn(
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
_FP8_ATOL = {"fp8_e4m3": 0.2, "fp8_e5m2": 0.3}
|
||||
_FP8_RTOL = 0.1
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def varlen_with_paged_kv(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
@@ -191,6 +201,9 @@ def varlen_with_paged_kv(
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
kv_cache_dtype: str = "auto",
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
) -> None:
|
||||
set_random_seed(0)
|
||||
num_seqs = len(seq_lens)
|
||||
@@ -212,6 +225,10 @@ def varlen_with_paged_kv(
|
||||
15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None
|
||||
)
|
||||
|
||||
is_fp8 = kv_cache_dtype != "auto"
|
||||
if is_fp8 and current_platform.get_cpu_architecture() != CpuArchEnum.X86:
|
||||
pytest.skip("FP8 KV cache only supported on x86")
|
||||
|
||||
query = tensor_cache(
|
||||
elem_num=token_num * num_query_heads * head_size,
|
||||
dtype=dtype,
|
||||
@@ -233,11 +250,17 @@ def varlen_with_paged_kv(
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
)
|
||||
if is_fp8:
|
||||
# Clamp KV to [-1, 1] so FP8 quantization error (<=12.5% for E4M3,
|
||||
# <=25% for E5M2) stays within the test tolerances regardless of
|
||||
# which tensor_cache values happen to be in use.
|
||||
key_value = key_value.clamp(-1, 1)
|
||||
key_cache, value_cache = key_value.unbind(0)
|
||||
|
||||
# KV cache for CPU attention
|
||||
cache_dtype = torch.uint8 if is_fp8 else dtype
|
||||
packed_key_cache = torch.empty(
|
||||
num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
|
||||
num_blocks, num_kv_heads, block_size, head_size, dtype=cache_dtype
|
||||
)
|
||||
packed_value_cache = torch.empty_like(packed_key_cache)
|
||||
|
||||
@@ -252,6 +275,11 @@ def varlen_with_paged_kv(
|
||||
|
||||
# use reshape_and_cache to pack key_cache and value_cache
|
||||
slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64)
|
||||
fp8_kwargs: dict = (
|
||||
dict(k_scale=k_scale, v_scale=v_scale, kv_cache_dtype=kv_cache_dtype)
|
||||
if is_fp8
|
||||
else {}
|
||||
)
|
||||
cpu_attn_reshape_and_cache(
|
||||
key=key_cache.view(-1, num_kv_heads, head_size),
|
||||
value=value_cache.view(-1, num_kv_heads, head_size),
|
||||
@@ -259,6 +287,7 @@ def varlen_with_paged_kv(
|
||||
value_cache=packed_value_cache,
|
||||
slot_mapping=slot_mapping,
|
||||
isa=isa,
|
||||
**fp8_kwargs,
|
||||
)
|
||||
|
||||
metadata = cpu_attn_get_scheduler_metadata(
|
||||
@@ -291,6 +320,7 @@ def varlen_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
**fp8_kwargs,
|
||||
)
|
||||
|
||||
metadata = cpu_attn_get_scheduler_metadata(
|
||||
@@ -323,23 +353,59 @@ def varlen_with_paged_kv(
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
**fp8_kwargs,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
if is_fp8:
|
||||
# Build a float KV cache via the non-FP8 path and run float attention
|
||||
# to use as the reference.
|
||||
ref_key_cache = torch.empty(
|
||||
num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
|
||||
)
|
||||
ref_value_cache = torch.empty_like(ref_key_cache)
|
||||
cpu_attn_reshape_and_cache(
|
||||
key=key_cache.view(-1, num_kv_heads, head_size),
|
||||
value=value_cache.view(-1, num_kv_heads, head_size),
|
||||
key_cache=ref_key_cache,
|
||||
value_cache=ref_value_cache,
|
||||
slot_mapping=slot_mapping,
|
||||
isa=isa,
|
||||
)
|
||||
ref_output = torch.empty_like(query)
|
||||
cpu_attention_with_kv_cache(
|
||||
query=query,
|
||||
key_cache=ref_key_cache,
|
||||
value_cache=ref_value_cache,
|
||||
output=ref_output,
|
||||
query_start_loc=cu_query_lens,
|
||||
seq_lens=kv_lens_tensor,
|
||||
scale=scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
sliding_window=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
scheduler_metadata=metadata,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
atol = _FP8_ATOL[kv_cache_dtype]
|
||||
rtol = _FP8_RTOL
|
||||
else:
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
s_aux=s_aux,
|
||||
)
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
(
|
||||
torch.testing.assert_close(out_with_split, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(out_with_split - ref_output))}",
|
||||
@@ -350,6 +416,7 @@ def varlen_with_paged_kv(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8_e4m3", "fp8_e5m2"])
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -373,6 +440,7 @@ def test_varlen_with_paged_kv_normal_vec(
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
@@ -386,9 +454,11 @@ def test_varlen_with_paged_kv_normal_vec(
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8_e4m3", "fp8_e5m2"])
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -413,6 +483,7 @@ def test_varlen_with_paged_kv_normal_amx(
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
@@ -426,6 +497,7 @@ def test_varlen_with_paged_kv_normal_amx(
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
@@ -511,6 +583,7 @@ def test_varlen_with_paged_kv_normal_neon(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8_e4m3"])
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@@ -534,6 +607,7 @@ def test_varlen_with_paged_kv_softcap(
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
@@ -547,9 +621,11 @@ def test_varlen_with_paged_kv_softcap(
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8_e4m3"])
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@@ -573,6 +649,7 @@ def test_varlen_with_paged_kv_alibi(
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
@@ -586,9 +663,11 @@ def test_varlen_with_paged_kv_alibi(
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8_e4m3"])
|
||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [96])
|
||||
@@ -612,6 +691,7 @@ def test_varlen_with_paged_kv_sink(
|
||||
use_alibi: bool,
|
||||
use_sink: bool,
|
||||
isa: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
varlen_with_paged_kv(
|
||||
seq_lens=seq_lens,
|
||||
@@ -625,4 +705,5 @@ def test_varlen_with_paged_kv_sink(
|
||||
use_alibi=use_alibi,
|
||||
use_sink=use_sink,
|
||||
isa=isa,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
@@ -3403,6 +3403,9 @@ def cpu_attn_reshape_and_cache(
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
isa: str,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
kv_cache_dtype: str = "auto",
|
||||
) -> None:
|
||||
torch.ops._C.cpu_attn_reshape_and_cache(
|
||||
key,
|
||||
@@ -3411,6 +3414,9 @@ def cpu_attn_reshape_and_cache(
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
isa,
|
||||
k_scale,
|
||||
v_scale,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
@@ -3429,6 +3435,9 @@ def cpu_attention_with_kv_cache(
|
||||
softcap: float,
|
||||
scheduler_metadata: torch.Tensor,
|
||||
s_aux: torch.Tensor | None,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
kv_cache_dtype: str = "auto",
|
||||
) -> None:
|
||||
torch.ops._C.cpu_attention_with_kv_cache(
|
||||
query,
|
||||
@@ -3446,6 +3455,9 @@ def cpu_attention_with_kv_cache(
|
||||
softcap,
|
||||
scheduler_metadata,
|
||||
s_aux,
|
||||
k_scale,
|
||||
v_scale,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
|
||||
|
||||
|
||||
+4
-15
@@ -16,7 +16,6 @@ from vllm.utils.cpu_resource_utils import (
|
||||
get_memory_node_info,
|
||||
)
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
@@ -134,20 +133,6 @@ class CpuPlatform(Platform):
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
# async scheduling is not required on CPU
|
||||
scheduler_config.async_scheduling = False
|
||||
if (
|
||||
scheduler_config.enable_chunked_prefill
|
||||
or cache_config.enable_prefix_caching
|
||||
) and is_quantized_kv_cache(cache_config.cache_dtype):
|
||||
raise RuntimeError(
|
||||
"Chunked-prefill and prefix-cache on the CPU "
|
||||
"backend is not compatible with FP8 KV cache."
|
||||
)
|
||||
|
||||
if is_quantized_kv_cache(cache_config.cache_dtype):
|
||||
logger.warning(
|
||||
"CPU backend doesn't support KV cache quantization fallback to auto."
|
||||
)
|
||||
cache_config.cache_dtype = "auto"
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# OMP requires the MP executor to function correctly, UniProc is not
|
||||
@@ -458,6 +443,10 @@ class CpuPlatform(Platform):
|
||||
block_offsets.reshape(1, block_size)
|
||||
+ indices.reshape(num_blocks, 1) * block_size
|
||||
).flatten()
|
||||
if key_cache.dtype == torch.uint8:
|
||||
raise NotImplementedError(
|
||||
"FP8 KV cache is not yet supported with KV transfer on CPU"
|
||||
)
|
||||
cpu_attn_reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.cache import CacheDType
|
||||
|
||||
import torch
|
||||
|
||||
@@ -35,6 +38,12 @@ class CPUAttentionBackend(AttentionBackend):
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
|
||||
"auto",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
@@ -133,7 +142,13 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
|
||||
if self.window_size is None:
|
||||
self.window_size = -1
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)
|
||||
kv_cache_dtype_str = vllm_config.cache_config.cache_dtype
|
||||
self.isa = _get_attn_isa(
|
||||
self.dtype,
|
||||
self.block_size,
|
||||
self.head_dim,
|
||||
kv_cache_dtype_str,
|
||||
)
|
||||
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
|
||||
|
||||
def build(
|
||||
@@ -247,8 +262,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
|
||||
self.is_fp8_kv_cache = is_quantized_kv_cache(kv_cache_dtype)
|
||||
self.attn_type = attn_type
|
||||
|
||||
self.sinks = sinks
|
||||
@@ -325,6 +339,9 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
attn_metadata.isa,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
|
||||
if attn_metadata.use_sdpa_prefill:
|
||||
@@ -356,6 +373,9 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
s_aux=self.sinks,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -477,13 +497,26 @@ def _make_sliding_window_bias(
|
||||
|
||||
|
||||
def _get_attn_isa(
|
||||
dtype: torch.dtype, block_size: int, head_size: int | None = None
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
head_size: int | None = None,
|
||||
kv_cache_dtype: str | None = None,
|
||||
) -> str:
|
||||
fp8_kv = is_quantized_kv_cache(kv_cache_dtype) if kv_cache_dtype else False
|
||||
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
|
||||
if fp8_kv:
|
||||
raise NotImplementedError(
|
||||
"FP8 KV cache requires head_size divisible by 32 on CPU."
|
||||
)
|
||||
return "vec16"
|
||||
supports_amx = torch.cpu._is_amx_tile_supported()
|
||||
supports_arm = current_platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||
supports_vxe = current_platform.get_cpu_architecture() == CpuArchEnum.S390X
|
||||
supports_avx512 = torch.cpu._is_avx512_supported()
|
||||
if fp8_kv and not supports_amx and not supports_avx512:
|
||||
raise NotImplementedError(
|
||||
"FP8 KV cache on CPU requires x86 with AVX-512 or AMX."
|
||||
)
|
||||
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
|
||||
return "amx"
|
||||
elif block_size % 32 == 0:
|
||||
|
||||
Reference in New Issue
Block a user