[CPU] Enable non-divisible GQA for decode workitems in mixed batches (#43032)

Signed-off-by: zhejiangxiaomai <zhenhui.zhao@intel.com>
This commit is contained in:
zhao, zhenhui
2026-05-26 14:15:47 +08:00
committed by GitHub
parent d56612c621
commit 771e1e48b1
+81 -48
View File
@@ -408,9 +408,19 @@ class AttentionScheduler {
const int64_t cache_size = cpu_utils::get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment;
bool has_decode_request = false;
bool decode_only_batch = true;
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
const int32_t q_token_num =
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
has_decode_request = has_decode_request || (q_token_num == 1);
decode_only_batch = decode_only_batch && (q_token_num == 1);
}
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0);
if (!use_gqa) {
const bool supports_gqa = q_head_per_kv <= max_num_q_per_iter;
const bool use_gqa_fast_path = supports_gqa && decode_only_batch;
const bool use_gqa_scratchpad = supports_gqa && has_decode_request;
if (!use_gqa_scratchpad) {
q_head_per_kv = 1; // fallback to MHA
}
const int32_t min_split_kv_len =
@@ -680,7 +690,7 @@ class AttentionScheduler {
metadata_ptr->attention_scratchpad_size_per_thread *
metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q);
(use_gqa_fast_path ? input.num_heads_kv : input.num_heads_q);
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
scratchpad_size);
@@ -1409,13 +1419,24 @@ class AttentionMainLoop {
const int32_t q_head_num = input->num_heads;
const int32_t kv_head_num = input->num_kv_heads;
const int32_t q_heads_per_kv = q_head_num / kv_head_num;
const bool use_gqa =
(max_q_head_num_per_iter % q_heads_per_kv == 0) ? true : false;
const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num;
const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1;
AttentionWorkItemGroup* const workitem_groups =
metadata.workitem_groups_ptr;
const int32_t* cu_workitem_num_per_thread =
metadata.cu_workitem_num_per_thread;
ReductionWorkItemGroup* const reduction_items =
metadata.reduction_items_ptr;
const bool supports_gqa = q_heads_per_kv <= max_q_head_num_per_iter;
bool decode_only_batch = true;
for (int32_t i = 0; i < metadata.workitem_group_num; ++i) {
decode_only_batch =
decode_only_batch && (workitem_groups[i].q_token_num == 1);
}
const bool use_gqa_fast_path = supports_gqa && decode_only_batch;
const int32_t actual_kv_head_num =
use_gqa_fast_path ? kv_head_num : q_head_num;
const int32_t actual_q_heads_per_kv =
use_gqa_fast_path ? q_heads_per_kv : 1;
TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter);
const int32_t max_q_token_num_per_iter =
max_q_head_num_per_iter / actual_q_heads_per_kv;
const int64_t q_token_num_stride = input->query_num_tokens_stride;
const int64_t q_head_num_stride = input->query_num_heads_stride;
const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride;
@@ -1461,15 +1482,6 @@ class AttentionMainLoop {
sizeof(q_buffer_t), sizeof(logits_buffer_t),
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
max_q_head_num_per_iter);
const int32_t default_q_tile_token_num =
default_tile_size / actual_q_heads_per_kv;
AttentionWorkItemGroup* const workitem_groups =
metadata.workitem_groups_ptr;
const int32_t* cu_workitem_num_per_thread =
metadata.cu_workitem_num_per_thread;
ReductionWorkItemGroup* const reduction_items =
metadata.reduction_items_ptr;
const int32_t effective_thread_num = metadata.effective_thread_num;
const int32_t reduction_item_num = metadata.reduction_item_num;
@@ -1513,8 +1525,6 @@ class AttentionMainLoop {
cu_workitem_num_per_thread[thread_offset + 1] -
cu_workitem_num_per_thread[thread_offset];
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
for (int32_t workitem_group_idx = 0;
workitem_group_idx < curr_workitem_groups_num;
++workitem_group_idx) {
@@ -1529,6 +1539,21 @@ class AttentionMainLoop {
const int32_t q_token_id_start =
current_workitem_group->q_token_id_start;
const int32_t q_token_num = current_workitem_group->q_token_num;
const bool curr_use_gqa =
use_gqa_fast_path || (supports_gqa && q_token_num == 1);
if (!use_gqa_fast_path && curr_use_gqa &&
kv_head_idx % q_heads_per_kv != 0) {
continue;
}
const int32_t curr_q_heads_per_kv =
curr_use_gqa ? q_heads_per_kv : 1;
const int32_t curr_max_q_token_num_per_iter =
max_q_head_num_per_iter / curr_q_heads_per_kv;
const int32_t curr_default_q_tile_token_num =
default_tile_size / curr_q_heads_per_kv;
const int32_t q_head_start_idx =
use_gqa_fast_path ? (kv_head_idx * q_heads_per_kv)
: kv_head_idx;
// taskgroup general information
const int32_t q_end = input->query_start_loc[current_group_idx + 1];
@@ -1542,7 +1567,7 @@ class AttentionMainLoop {
current_workitem_group->local_split_id == 0);
for (int32_t q_token_offset = 0; q_token_offset < q_token_num;
q_token_offset += default_q_tile_token_num) {
q_token_offset += curr_default_q_tile_token_num) {
bool first_iter_flag[AttentionScheduler::MaxQTileIterNum];
for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum;
++i) {
@@ -1552,9 +1577,9 @@ class AttentionMainLoop {
const int32_t q_token_start_idx =
q_start + q_token_offset + q_token_id_start;
const int32_t actual_q_token_num = std::min(
default_q_tile_token_num, q_token_num - q_token_offset);
curr_default_q_tile_token_num, q_token_num - q_token_offset);
const int32_t q_head_tile_size =
actual_q_token_num * actual_q_heads_per_kv;
actual_q_token_num * curr_q_heads_per_kv;
const int32_t rounded_q_head_tile_size =
((q_head_tile_size + max_q_head_num_per_iter - 1) /
max_q_head_num_per_iter) *
@@ -1591,10 +1616,9 @@ class AttentionMainLoop {
AttentionScheduler::align_kv_tile_pos(
kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment);
int32_t curr_kv_head_idx =
use_gqa ? kv_head_idx
: (kv_head_idx /
q_heads_per_kv); // for GQA disabled case
const int32_t curr_kv_head_idx =
use_gqa_fast_path ? kv_head_idx
: (kv_head_idx / q_heads_per_kv);
// std::printf("thread_id: %d, req_id: %d, q_token_start: %d,
// q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx:
@@ -1629,12 +1653,12 @@ class AttentionMainLoop {
(s_aux != nullptr ? s_aux + q_head_start_idx : nullptr);
// copy the Q tile to q_buffer, the logical layout of q_buffer is
// [actual_q_token_num, actual_q_heads_per_kv, head_dim]
// [actual_q_token_num, curr_q_heads_per_kv, head_dim]
{
attn_impl.copy_q_heads_tile(
q_tile_ptr, q_buffer, actual_q_token_num,
actual_q_heads_per_kv, q_token_num_stride,
q_head_num_stride, scale);
curr_q_heads_per_kv, q_token_num_stride, q_head_num_stride,
scale);
}
if (use_sink) {
@@ -1648,29 +1672,29 @@ class AttentionMainLoop {
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
for (int32_t head_idx = 0; head_idx < curr_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 1.0f;
curr_max_buffer[head_idx] = s_aux_fp32[head_idx];
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
curr_sum_buffer += curr_q_heads_per_kv;
curr_max_buffer += curr_q_heads_per_kv;
}
} else {
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
++token_idx) {
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
for (int32_t head_idx = 0; head_idx < curr_q_heads_per_kv;
++head_idx) {
curr_sum_buffer[head_idx] = 0.0f;
curr_max_buffer[head_idx] =
std::numeric_limits<float>::lowest();
}
curr_sum_buffer += actual_q_heads_per_kv;
curr_max_buffer += actual_q_heads_per_kv;
curr_sum_buffer += curr_q_heads_per_kv;
curr_max_buffer += curr_q_heads_per_kv;
}
}
@@ -1683,16 +1707,17 @@ class AttentionMainLoop {
kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos);
for (int32_t q_head_tile_token_offset = 0;
q_head_tile_token_offset < actual_q_token_num;
q_head_tile_token_offset += max_q_token_num_per_iter) {
q_head_tile_token_offset +=
curr_max_q_token_num_per_iter) {
const int32_t q_tile_pos_left =
q_tile_start_pos + q_head_tile_token_offset;
const int32_t q_tile_token_num =
std::min(max_q_token_num_per_iter,
std::min(curr_max_q_token_num_per_iter,
actual_q_token_num - q_head_tile_token_offset);
const int32_t q_tile_head_offset =
q_head_tile_token_offset * actual_q_heads_per_kv;
q_head_tile_token_offset * curr_q_heads_per_kv;
const int32_t q_tile_head_num =
q_tile_token_num * actual_q_heads_per_kv;
q_tile_token_num * curr_q_heads_per_kv;
const int32_t q_tile_pos_right =
q_tile_pos_left + q_tile_token_num;
const auto [actual_kv_tile_pos_left,
@@ -1702,7 +1727,7 @@ class AttentionMainLoop {
q_tile_pos_right, sliding_window_left,
sliding_window_right);
const int32_t q_iter_idx =
q_head_tile_token_offset / max_q_token_num_per_iter;
q_head_tile_token_offset / curr_max_q_token_num_per_iter;
if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) {
continue;
@@ -1768,7 +1793,7 @@ class AttentionMainLoop {
aligned_actual_kv_tile_pos_left,
aligned_actual_kv_tile_pos_right, actual_kv_token_num,
kv_cache_block_num_stride, q_tile_head_num,
q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv,
q_tile_token_num, q_tile_pos_left, curr_q_heads_per_kv,
block_size, sliding_window_left, sliding_window_right,
scale, softcap_scale, curr_alibi_slopes,
first_iter_flag[q_iter_idx], use_sink, debug_info);
@@ -1782,11 +1807,11 @@ class AttentionMainLoop {
final_output(partial_q_buffer,
reinterpret_cast<query_t*>(input->output) +
output_buffer_offset,
sum_buffer, actual_q_heads_per_kv,
sum_buffer, curr_q_heads_per_kv,
actual_q_token_num, q_head_num, output_v_scale);
} else {
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
curr_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
@@ -1822,18 +1847,26 @@ class AttentionMainLoop {
const int32_t curr_split_id = curr_workitem_groups->split_start_id;
const int32_t curr_split_num = curr_workitem_groups->split_num;
const int32_t current_group_idx = curr_workitem_groups->req_id;
const bool curr_use_gqa =
use_gqa_fast_path || (supports_gqa && curr_output_token_num == 1);
if (!use_gqa_fast_path && curr_use_gqa &&
kv_head_idx % q_heads_per_kv != 0) {
continue;
}
const int32_t curr_q_heads_per_kv = curr_use_gqa ? q_heads_per_kv : 1;
const int32_t curr_output_head_num =
curr_output_token_num * actual_q_heads_per_kv;
curr_output_token_num * curr_q_heads_per_kv;
const int32_t q_start = input->query_start_loc[current_group_idx];
const int32_t q_token_start_idx = q_start + curr_output_token_idx;
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
const int32_t q_head_start_idx =
use_gqa_fast_path ? (kv_head_idx * q_heads_per_kv) : kv_head_idx;
size_t output_buffer_offset =
q_token_start_idx * q_head_num * head_dim +
q_head_start_idx * head_dim;
const int32_t stride =
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
curr_q_heads_per_kv * split_kv_q_token_num_threshold;
buffer_manager.update(kv_head_idx, total_reduction_split_num,
head_dim, stride, sizeof(float));
volatile bool* split_flag_buffer =
@@ -1852,7 +1885,7 @@ class AttentionMainLoop {
final_output(
split_output_buffer,
reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
split_sum_buffer, curr_q_heads_per_kv, curr_output_token_num,
q_head_num, output_v_scale);
}
}