mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[CPU] Enable non-divisible GQA for decode workitems in mixed batches (#43032)
Signed-off-by: zhejiangxiaomai <zhenhui.zhao@intel.com>
This commit is contained in:
+81
-48
@@ -408,9 +408,19 @@ class AttentionScheduler {
|
||||
const int64_t cache_size = cpu_utils::get_available_l2_size();
|
||||
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
|
||||
const int32_t kv_len_alignment = input.kv_block_alignment;
|
||||
bool has_decode_request = false;
|
||||
bool decode_only_batch = true;
|
||||
for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) {
|
||||
const int32_t q_token_num =
|
||||
input.query_start_loc[req_id + 1] - input.query_start_loc[req_id];
|
||||
has_decode_request = has_decode_request || (q_token_num == 1);
|
||||
decode_only_batch = decode_only_batch && (q_token_num == 1);
|
||||
}
|
||||
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
|
||||
const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0);
|
||||
if (!use_gqa) {
|
||||
const bool supports_gqa = q_head_per_kv <= max_num_q_per_iter;
|
||||
const bool use_gqa_fast_path = supports_gqa && decode_only_batch;
|
||||
const bool use_gqa_scratchpad = supports_gqa && has_decode_request;
|
||||
if (!use_gqa_scratchpad) {
|
||||
q_head_per_kv = 1; // fallback to MHA
|
||||
}
|
||||
const int32_t min_split_kv_len =
|
||||
@@ -680,7 +690,7 @@ class AttentionScheduler {
|
||||
metadata_ptr->attention_scratchpad_size_per_thread *
|
||||
metadata_ptr->thread_num +
|
||||
metadata_ptr->reduction_scratchpad_size_per_kv_head *
|
||||
(use_gqa ? input.num_heads_kv : input.num_heads_q);
|
||||
(use_gqa_fast_path ? input.num_heads_kv : input.num_heads_q);
|
||||
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
|
||||
scratchpad_size);
|
||||
|
||||
@@ -1409,13 +1419,24 @@ class AttentionMainLoop {
|
||||
const int32_t q_head_num = input->num_heads;
|
||||
const int32_t kv_head_num = input->num_kv_heads;
|
||||
const int32_t q_heads_per_kv = q_head_num / kv_head_num;
|
||||
const bool use_gqa =
|
||||
(max_q_head_num_per_iter % q_heads_per_kv == 0) ? true : false;
|
||||
const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num;
|
||||
const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1;
|
||||
AttentionWorkItemGroup* const workitem_groups =
|
||||
metadata.workitem_groups_ptr;
|
||||
const int32_t* cu_workitem_num_per_thread =
|
||||
metadata.cu_workitem_num_per_thread;
|
||||
ReductionWorkItemGroup* const reduction_items =
|
||||
metadata.reduction_items_ptr;
|
||||
const bool supports_gqa = q_heads_per_kv <= max_q_head_num_per_iter;
|
||||
bool decode_only_batch = true;
|
||||
for (int32_t i = 0; i < metadata.workitem_group_num; ++i) {
|
||||
decode_only_batch =
|
||||
decode_only_batch && (workitem_groups[i].q_token_num == 1);
|
||||
}
|
||||
const bool use_gqa_fast_path = supports_gqa && decode_only_batch;
|
||||
const int32_t actual_kv_head_num =
|
||||
use_gqa_fast_path ? kv_head_num : q_head_num;
|
||||
const int32_t actual_q_heads_per_kv =
|
||||
use_gqa_fast_path ? q_heads_per_kv : 1;
|
||||
TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter);
|
||||
const int32_t max_q_token_num_per_iter =
|
||||
max_q_head_num_per_iter / actual_q_heads_per_kv;
|
||||
const int64_t q_token_num_stride = input->query_num_tokens_stride;
|
||||
const int64_t q_head_num_stride = input->query_num_heads_stride;
|
||||
const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride;
|
||||
@@ -1461,15 +1482,6 @@ class AttentionMainLoop {
|
||||
sizeof(q_buffer_t), sizeof(logits_buffer_t),
|
||||
sizeof(partial_output_buffer_t), max_q_head_num_per_iter,
|
||||
max_q_head_num_per_iter);
|
||||
const int32_t default_q_tile_token_num =
|
||||
default_tile_size / actual_q_heads_per_kv;
|
||||
|
||||
AttentionWorkItemGroup* const workitem_groups =
|
||||
metadata.workitem_groups_ptr;
|
||||
const int32_t* cu_workitem_num_per_thread =
|
||||
metadata.cu_workitem_num_per_thread;
|
||||
ReductionWorkItemGroup* const reduction_items =
|
||||
metadata.reduction_items_ptr;
|
||||
|
||||
const int32_t effective_thread_num = metadata.effective_thread_num;
|
||||
const int32_t reduction_item_num = metadata.reduction_item_num;
|
||||
@@ -1513,8 +1525,6 @@ class AttentionMainLoop {
|
||||
cu_workitem_num_per_thread[thread_offset + 1] -
|
||||
cu_workitem_num_per_thread[thread_offset];
|
||||
|
||||
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
|
||||
|
||||
for (int32_t workitem_group_idx = 0;
|
||||
workitem_group_idx < curr_workitem_groups_num;
|
||||
++workitem_group_idx) {
|
||||
@@ -1529,6 +1539,21 @@ class AttentionMainLoop {
|
||||
const int32_t q_token_id_start =
|
||||
current_workitem_group->q_token_id_start;
|
||||
const int32_t q_token_num = current_workitem_group->q_token_num;
|
||||
const bool curr_use_gqa =
|
||||
use_gqa_fast_path || (supports_gqa && q_token_num == 1);
|
||||
if (!use_gqa_fast_path && curr_use_gqa &&
|
||||
kv_head_idx % q_heads_per_kv != 0) {
|
||||
continue;
|
||||
}
|
||||
const int32_t curr_q_heads_per_kv =
|
||||
curr_use_gqa ? q_heads_per_kv : 1;
|
||||
const int32_t curr_max_q_token_num_per_iter =
|
||||
max_q_head_num_per_iter / curr_q_heads_per_kv;
|
||||
const int32_t curr_default_q_tile_token_num =
|
||||
default_tile_size / curr_q_heads_per_kv;
|
||||
const int32_t q_head_start_idx =
|
||||
use_gqa_fast_path ? (kv_head_idx * q_heads_per_kv)
|
||||
: kv_head_idx;
|
||||
|
||||
// taskgroup general information
|
||||
const int32_t q_end = input->query_start_loc[current_group_idx + 1];
|
||||
@@ -1542,7 +1567,7 @@ class AttentionMainLoop {
|
||||
current_workitem_group->local_split_id == 0);
|
||||
|
||||
for (int32_t q_token_offset = 0; q_token_offset < q_token_num;
|
||||
q_token_offset += default_q_tile_token_num) {
|
||||
q_token_offset += curr_default_q_tile_token_num) {
|
||||
bool first_iter_flag[AttentionScheduler::MaxQTileIterNum];
|
||||
for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum;
|
||||
++i) {
|
||||
@@ -1552,9 +1577,9 @@ class AttentionMainLoop {
|
||||
const int32_t q_token_start_idx =
|
||||
q_start + q_token_offset + q_token_id_start;
|
||||
const int32_t actual_q_token_num = std::min(
|
||||
default_q_tile_token_num, q_token_num - q_token_offset);
|
||||
curr_default_q_tile_token_num, q_token_num - q_token_offset);
|
||||
const int32_t q_head_tile_size =
|
||||
actual_q_token_num * actual_q_heads_per_kv;
|
||||
actual_q_token_num * curr_q_heads_per_kv;
|
||||
const int32_t rounded_q_head_tile_size =
|
||||
((q_head_tile_size + max_q_head_num_per_iter - 1) /
|
||||
max_q_head_num_per_iter) *
|
||||
@@ -1591,10 +1616,9 @@ class AttentionMainLoop {
|
||||
AttentionScheduler::align_kv_tile_pos(
|
||||
kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment);
|
||||
|
||||
int32_t curr_kv_head_idx =
|
||||
use_gqa ? kv_head_idx
|
||||
: (kv_head_idx /
|
||||
q_heads_per_kv); // for GQA disabled case
|
||||
const int32_t curr_kv_head_idx =
|
||||
use_gqa_fast_path ? kv_head_idx
|
||||
: (kv_head_idx / q_heads_per_kv);
|
||||
|
||||
// std::printf("thread_id: %d, req_id: %d, q_token_start: %d,
|
||||
// q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx:
|
||||
@@ -1629,12 +1653,12 @@ class AttentionMainLoop {
|
||||
(s_aux != nullptr ? s_aux + q_head_start_idx : nullptr);
|
||||
|
||||
// copy the Q tile to q_buffer, the logical layout of q_buffer is
|
||||
// [actual_q_token_num, actual_q_heads_per_kv, head_dim]
|
||||
// [actual_q_token_num, curr_q_heads_per_kv, head_dim]
|
||||
{
|
||||
attn_impl.copy_q_heads_tile(
|
||||
q_tile_ptr, q_buffer, actual_q_token_num,
|
||||
actual_q_heads_per_kv, q_token_num_stride,
|
||||
q_head_num_stride, scale);
|
||||
curr_q_heads_per_kv, q_token_num_stride, q_head_num_stride,
|
||||
scale);
|
||||
}
|
||||
|
||||
if (use_sink) {
|
||||
@@ -1648,29 +1672,29 @@ class AttentionMainLoop {
|
||||
float* __restrict__ curr_max_buffer = max_buffer;
|
||||
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
|
||||
++token_idx) {
|
||||
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
|
||||
for (int32_t head_idx = 0; head_idx < curr_q_heads_per_kv;
|
||||
++head_idx) {
|
||||
curr_sum_buffer[head_idx] = 1.0f;
|
||||
curr_max_buffer[head_idx] = s_aux_fp32[head_idx];
|
||||
}
|
||||
|
||||
curr_sum_buffer += actual_q_heads_per_kv;
|
||||
curr_max_buffer += actual_q_heads_per_kv;
|
||||
curr_sum_buffer += curr_q_heads_per_kv;
|
||||
curr_max_buffer += curr_q_heads_per_kv;
|
||||
}
|
||||
} else {
|
||||
float* __restrict__ curr_sum_buffer = sum_buffer;
|
||||
float* __restrict__ curr_max_buffer = max_buffer;
|
||||
for (int32_t token_idx = 0; token_idx < actual_q_token_num;
|
||||
++token_idx) {
|
||||
for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv;
|
||||
for (int32_t head_idx = 0; head_idx < curr_q_heads_per_kv;
|
||||
++head_idx) {
|
||||
curr_sum_buffer[head_idx] = 0.0f;
|
||||
curr_max_buffer[head_idx] =
|
||||
std::numeric_limits<float>::lowest();
|
||||
}
|
||||
|
||||
curr_sum_buffer += actual_q_heads_per_kv;
|
||||
curr_max_buffer += actual_q_heads_per_kv;
|
||||
curr_sum_buffer += curr_q_heads_per_kv;
|
||||
curr_max_buffer += curr_q_heads_per_kv;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1683,16 +1707,17 @@ class AttentionMainLoop {
|
||||
kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos);
|
||||
for (int32_t q_head_tile_token_offset = 0;
|
||||
q_head_tile_token_offset < actual_q_token_num;
|
||||
q_head_tile_token_offset += max_q_token_num_per_iter) {
|
||||
q_head_tile_token_offset +=
|
||||
curr_max_q_token_num_per_iter) {
|
||||
const int32_t q_tile_pos_left =
|
||||
q_tile_start_pos + q_head_tile_token_offset;
|
||||
const int32_t q_tile_token_num =
|
||||
std::min(max_q_token_num_per_iter,
|
||||
std::min(curr_max_q_token_num_per_iter,
|
||||
actual_q_token_num - q_head_tile_token_offset);
|
||||
const int32_t q_tile_head_offset =
|
||||
q_head_tile_token_offset * actual_q_heads_per_kv;
|
||||
q_head_tile_token_offset * curr_q_heads_per_kv;
|
||||
const int32_t q_tile_head_num =
|
||||
q_tile_token_num * actual_q_heads_per_kv;
|
||||
q_tile_token_num * curr_q_heads_per_kv;
|
||||
const int32_t q_tile_pos_right =
|
||||
q_tile_pos_left + q_tile_token_num;
|
||||
const auto [actual_kv_tile_pos_left,
|
||||
@@ -1702,7 +1727,7 @@ class AttentionMainLoop {
|
||||
q_tile_pos_right, sliding_window_left,
|
||||
sliding_window_right);
|
||||
const int32_t q_iter_idx =
|
||||
q_head_tile_token_offset / max_q_token_num_per_iter;
|
||||
q_head_tile_token_offset / curr_max_q_token_num_per_iter;
|
||||
|
||||
if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) {
|
||||
continue;
|
||||
@@ -1768,7 +1793,7 @@ class AttentionMainLoop {
|
||||
aligned_actual_kv_tile_pos_left,
|
||||
aligned_actual_kv_tile_pos_right, actual_kv_token_num,
|
||||
kv_cache_block_num_stride, q_tile_head_num,
|
||||
q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv,
|
||||
q_tile_token_num, q_tile_pos_left, curr_q_heads_per_kv,
|
||||
block_size, sliding_window_left, sliding_window_right,
|
||||
scale, softcap_scale, curr_alibi_slopes,
|
||||
first_iter_flag[q_iter_idx], use_sink, debug_info);
|
||||
@@ -1782,11 +1807,11 @@ class AttentionMainLoop {
|
||||
final_output(partial_q_buffer,
|
||||
reinterpret_cast<query_t*>(input->output) +
|
||||
output_buffer_offset,
|
||||
sum_buffer, actual_q_heads_per_kv,
|
||||
sum_buffer, curr_q_heads_per_kv,
|
||||
actual_q_token_num, q_head_num, output_v_scale);
|
||||
} else {
|
||||
const int32_t stride =
|
||||
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
|
||||
curr_q_heads_per_kv * split_kv_q_token_num_threshold;
|
||||
buffer_manager.update(kv_head_idx, total_reduction_split_num,
|
||||
head_dim, stride, sizeof(float));
|
||||
volatile bool* split_flag_buffer =
|
||||
@@ -1822,18 +1847,26 @@ class AttentionMainLoop {
|
||||
const int32_t curr_split_id = curr_workitem_groups->split_start_id;
|
||||
const int32_t curr_split_num = curr_workitem_groups->split_num;
|
||||
const int32_t current_group_idx = curr_workitem_groups->req_id;
|
||||
const bool curr_use_gqa =
|
||||
use_gqa_fast_path || (supports_gqa && curr_output_token_num == 1);
|
||||
if (!use_gqa_fast_path && curr_use_gqa &&
|
||||
kv_head_idx % q_heads_per_kv != 0) {
|
||||
continue;
|
||||
}
|
||||
const int32_t curr_q_heads_per_kv = curr_use_gqa ? q_heads_per_kv : 1;
|
||||
const int32_t curr_output_head_num =
|
||||
curr_output_token_num * actual_q_heads_per_kv;
|
||||
curr_output_token_num * curr_q_heads_per_kv;
|
||||
|
||||
const int32_t q_start = input->query_start_loc[current_group_idx];
|
||||
const int32_t q_token_start_idx = q_start + curr_output_token_idx;
|
||||
const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv;
|
||||
const int32_t q_head_start_idx =
|
||||
use_gqa_fast_path ? (kv_head_idx * q_heads_per_kv) : kv_head_idx;
|
||||
size_t output_buffer_offset =
|
||||
q_token_start_idx * q_head_num * head_dim +
|
||||
q_head_start_idx * head_dim;
|
||||
|
||||
const int32_t stride =
|
||||
actual_q_heads_per_kv * split_kv_q_token_num_threshold;
|
||||
curr_q_heads_per_kv * split_kv_q_token_num_threshold;
|
||||
buffer_manager.update(kv_head_idx, total_reduction_split_num,
|
||||
head_dim, stride, sizeof(float));
|
||||
volatile bool* split_flag_buffer =
|
||||
@@ -1852,7 +1885,7 @@ class AttentionMainLoop {
|
||||
final_output(
|
||||
split_output_buffer,
|
||||
reinterpret_cast<query_t*>(input->output) + output_buffer_offset,
|
||||
split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num,
|
||||
split_sum_buffer, curr_q_heads_per_kv, curr_output_token_num,
|
||||
q_head_num, output_v_scale);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user