diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 2d0859a13db..70081b36ee5 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -408,9 +408,19 @@ class AttentionScheduler { const int64_t cache_size = cpu_utils::get_available_l2_size(); const int32_t max_num_q_per_iter = input.max_num_q_per_iter; const int32_t kv_len_alignment = input.kv_block_alignment; + bool has_decode_request = false; + bool decode_only_batch = true; + for (int32_t req_id = 0; req_id < input.num_reqs; ++req_id) { + const int32_t q_token_num = + input.query_start_loc[req_id + 1] - input.query_start_loc[req_id]; + has_decode_request = has_decode_request || (q_token_num == 1); + decode_only_batch = decode_only_batch && (q_token_num == 1); + } int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv; - const bool use_gqa = (max_num_q_per_iter % q_head_per_kv == 0); - if (!use_gqa) { + const bool supports_gqa = q_head_per_kv <= max_num_q_per_iter; + const bool use_gqa_fast_path = supports_gqa && decode_only_batch; + const bool use_gqa_scratchpad = supports_gqa && has_decode_request; + if (!use_gqa_scratchpad) { q_head_per_kv = 1; // fallback to MHA } const int32_t min_split_kv_len = @@ -680,7 +690,7 @@ class AttentionScheduler { metadata_ptr->attention_scratchpad_size_per_thread * metadata_ptr->thread_num + metadata_ptr->reduction_scratchpad_size_per_kv_head * - (use_gqa ? input.num_heads_kv : input.num_heads_q); + (use_gqa_fast_path ? input.num_heads_kv : input.num_heads_q); cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc( scratchpad_size); @@ -1409,13 +1419,24 @@ class AttentionMainLoop { const int32_t q_head_num = input->num_heads; const int32_t kv_head_num = input->num_kv_heads; const int32_t q_heads_per_kv = q_head_num / kv_head_num; - const bool use_gqa = - (max_q_head_num_per_iter % q_heads_per_kv == 0) ? true : false; - const int32_t actual_kv_head_num = use_gqa ? kv_head_num : q_head_num; - const int32_t actual_q_heads_per_kv = use_gqa ? q_heads_per_kv : 1; + AttentionWorkItemGroup* const workitem_groups = + metadata.workitem_groups_ptr; + const int32_t* cu_workitem_num_per_thread = + metadata.cu_workitem_num_per_thread; + ReductionWorkItemGroup* const reduction_items = + metadata.reduction_items_ptr; + const bool supports_gqa = q_heads_per_kv <= max_q_head_num_per_iter; + bool decode_only_batch = true; + for (int32_t i = 0; i < metadata.workitem_group_num; ++i) { + decode_only_batch = + decode_only_batch && (workitem_groups[i].q_token_num == 1); + } + const bool use_gqa_fast_path = supports_gqa && decode_only_batch; + const int32_t actual_kv_head_num = + use_gqa_fast_path ? kv_head_num : q_head_num; + const int32_t actual_q_heads_per_kv = + use_gqa_fast_path ? q_heads_per_kv : 1; TORCH_CHECK_LE(actual_q_heads_per_kv, max_q_head_num_per_iter); - const int32_t max_q_token_num_per_iter = - max_q_head_num_per_iter / actual_q_heads_per_kv; const int64_t q_token_num_stride = input->query_num_tokens_stride; const int64_t q_head_num_stride = input->query_num_heads_stride; const int64_t kv_cache_head_num_stride = input->cache_num_kv_heads_stride; @@ -1461,15 +1482,6 @@ class AttentionMainLoop { sizeof(q_buffer_t), sizeof(logits_buffer_t), sizeof(partial_output_buffer_t), max_q_head_num_per_iter, max_q_head_num_per_iter); - const int32_t default_q_tile_token_num = - default_tile_size / actual_q_heads_per_kv; - - AttentionWorkItemGroup* const workitem_groups = - metadata.workitem_groups_ptr; - const int32_t* cu_workitem_num_per_thread = - metadata.cu_workitem_num_per_thread; - ReductionWorkItemGroup* const reduction_items = - metadata.reduction_items_ptr; const int32_t effective_thread_num = metadata.effective_thread_num; const int32_t reduction_item_num = metadata.reduction_item_num; @@ -1513,8 +1525,6 @@ class AttentionMainLoop { cu_workitem_num_per_thread[thread_offset + 1] - cu_workitem_num_per_thread[thread_offset]; - const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv; - for (int32_t workitem_group_idx = 0; workitem_group_idx < curr_workitem_groups_num; ++workitem_group_idx) { @@ -1529,6 +1539,21 @@ class AttentionMainLoop { const int32_t q_token_id_start = current_workitem_group->q_token_id_start; const int32_t q_token_num = current_workitem_group->q_token_num; + const bool curr_use_gqa = + use_gqa_fast_path || (supports_gqa && q_token_num == 1); + if (!use_gqa_fast_path && curr_use_gqa && + kv_head_idx % q_heads_per_kv != 0) { + continue; + } + const int32_t curr_q_heads_per_kv = + curr_use_gqa ? q_heads_per_kv : 1; + const int32_t curr_max_q_token_num_per_iter = + max_q_head_num_per_iter / curr_q_heads_per_kv; + const int32_t curr_default_q_tile_token_num = + default_tile_size / curr_q_heads_per_kv; + const int32_t q_head_start_idx = + use_gqa_fast_path ? (kv_head_idx * q_heads_per_kv) + : kv_head_idx; // taskgroup general information const int32_t q_end = input->query_start_loc[current_group_idx + 1]; @@ -1542,7 +1567,7 @@ class AttentionMainLoop { current_workitem_group->local_split_id == 0); for (int32_t q_token_offset = 0; q_token_offset < q_token_num; - q_token_offset += default_q_tile_token_num) { + q_token_offset += curr_default_q_tile_token_num) { bool first_iter_flag[AttentionScheduler::MaxQTileIterNum]; for (int32_t i = 0; i < AttentionScheduler::MaxQTileIterNum; ++i) { @@ -1552,9 +1577,9 @@ class AttentionMainLoop { const int32_t q_token_start_idx = q_start + q_token_offset + q_token_id_start; const int32_t actual_q_token_num = std::min( - default_q_tile_token_num, q_token_num - q_token_offset); + curr_default_q_tile_token_num, q_token_num - q_token_offset); const int32_t q_head_tile_size = - actual_q_token_num * actual_q_heads_per_kv; + actual_q_token_num * curr_q_heads_per_kv; const int32_t rounded_q_head_tile_size = ((q_head_tile_size + max_q_head_num_per_iter - 1) / max_q_head_num_per_iter) * @@ -1591,10 +1616,9 @@ class AttentionMainLoop { AttentionScheduler::align_kv_tile_pos( kv_tile_start_pos, kv_tile_end_pos, blocksize_alignment); - int32_t curr_kv_head_idx = - use_gqa ? kv_head_idx - : (kv_head_idx / - q_heads_per_kv); // for GQA disabled case + const int32_t curr_kv_head_idx = + use_gqa_fast_path ? kv_head_idx + : (kv_head_idx / q_heads_per_kv); // std::printf("thread_id: %d, req_id: %d, q_token_start: %d, // q_token_end: %d, q_head_start: %d, q_head_end: %d, kv_head_idx: @@ -1629,12 +1653,12 @@ class AttentionMainLoop { (s_aux != nullptr ? s_aux + q_head_start_idx : nullptr); // copy the Q tile to q_buffer, the logical layout of q_buffer is - // [actual_q_token_num, actual_q_heads_per_kv, head_dim] + // [actual_q_token_num, curr_q_heads_per_kv, head_dim] { attn_impl.copy_q_heads_tile( q_tile_ptr, q_buffer, actual_q_token_num, - actual_q_heads_per_kv, q_token_num_stride, - q_head_num_stride, scale); + curr_q_heads_per_kv, q_token_num_stride, q_head_num_stride, + scale); } if (use_sink) { @@ -1648,29 +1672,29 @@ class AttentionMainLoop { float* __restrict__ curr_max_buffer = max_buffer; for (int32_t token_idx = 0; token_idx < actual_q_token_num; ++token_idx) { - for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv; + for (int32_t head_idx = 0; head_idx < curr_q_heads_per_kv; ++head_idx) { curr_sum_buffer[head_idx] = 1.0f; curr_max_buffer[head_idx] = s_aux_fp32[head_idx]; } - curr_sum_buffer += actual_q_heads_per_kv; - curr_max_buffer += actual_q_heads_per_kv; + curr_sum_buffer += curr_q_heads_per_kv; + curr_max_buffer += curr_q_heads_per_kv; } } else { float* __restrict__ curr_sum_buffer = sum_buffer; float* __restrict__ curr_max_buffer = max_buffer; for (int32_t token_idx = 0; token_idx < actual_q_token_num; ++token_idx) { - for (int32_t head_idx = 0; head_idx < actual_q_heads_per_kv; + for (int32_t head_idx = 0; head_idx < curr_q_heads_per_kv; ++head_idx) { curr_sum_buffer[head_idx] = 0.0f; curr_max_buffer[head_idx] = std::numeric_limits::lowest(); } - curr_sum_buffer += actual_q_heads_per_kv; - curr_max_buffer += actual_q_heads_per_kv; + curr_sum_buffer += curr_q_heads_per_kv; + curr_max_buffer += curr_q_heads_per_kv; } } @@ -1683,16 +1707,17 @@ class AttentionMainLoop { kv_tile_pos_left + kv_tile_size, rounded_kv_tile_end_pos); for (int32_t q_head_tile_token_offset = 0; q_head_tile_token_offset < actual_q_token_num; - q_head_tile_token_offset += max_q_token_num_per_iter) { + q_head_tile_token_offset += + curr_max_q_token_num_per_iter) { const int32_t q_tile_pos_left = q_tile_start_pos + q_head_tile_token_offset; const int32_t q_tile_token_num = - std::min(max_q_token_num_per_iter, + std::min(curr_max_q_token_num_per_iter, actual_q_token_num - q_head_tile_token_offset); const int32_t q_tile_head_offset = - q_head_tile_token_offset * actual_q_heads_per_kv; + q_head_tile_token_offset * curr_q_heads_per_kv; const int32_t q_tile_head_num = - q_tile_token_num * actual_q_heads_per_kv; + q_tile_token_num * curr_q_heads_per_kv; const int32_t q_tile_pos_right = q_tile_pos_left + q_tile_token_num; const auto [actual_kv_tile_pos_left, @@ -1702,7 +1727,7 @@ class AttentionMainLoop { q_tile_pos_right, sliding_window_left, sliding_window_right); const int32_t q_iter_idx = - q_head_tile_token_offset / max_q_token_num_per_iter; + q_head_tile_token_offset / curr_max_q_token_num_per_iter; if (actual_kv_tile_pos_right <= actual_kv_tile_pos_left) { continue; @@ -1768,7 +1793,7 @@ class AttentionMainLoop { aligned_actual_kv_tile_pos_left, aligned_actual_kv_tile_pos_right, actual_kv_token_num, kv_cache_block_num_stride, q_tile_head_num, - q_tile_token_num, q_tile_pos_left, actual_q_heads_per_kv, + q_tile_token_num, q_tile_pos_left, curr_q_heads_per_kv, block_size, sliding_window_left, sliding_window_right, scale, softcap_scale, curr_alibi_slopes, first_iter_flag[q_iter_idx], use_sink, debug_info); @@ -1782,11 +1807,11 @@ class AttentionMainLoop { final_output(partial_q_buffer, reinterpret_cast(input->output) + output_buffer_offset, - sum_buffer, actual_q_heads_per_kv, + sum_buffer, curr_q_heads_per_kv, actual_q_token_num, q_head_num, output_v_scale); } else { const int32_t stride = - actual_q_heads_per_kv * split_kv_q_token_num_threshold; + curr_q_heads_per_kv * split_kv_q_token_num_threshold; buffer_manager.update(kv_head_idx, total_reduction_split_num, head_dim, stride, sizeof(float)); volatile bool* split_flag_buffer = @@ -1822,18 +1847,26 @@ class AttentionMainLoop { const int32_t curr_split_id = curr_workitem_groups->split_start_id; const int32_t curr_split_num = curr_workitem_groups->split_num; const int32_t current_group_idx = curr_workitem_groups->req_id; + const bool curr_use_gqa = + use_gqa_fast_path || (supports_gqa && curr_output_token_num == 1); + if (!use_gqa_fast_path && curr_use_gqa && + kv_head_idx % q_heads_per_kv != 0) { + continue; + } + const int32_t curr_q_heads_per_kv = curr_use_gqa ? q_heads_per_kv : 1; const int32_t curr_output_head_num = - curr_output_token_num * actual_q_heads_per_kv; + curr_output_token_num * curr_q_heads_per_kv; const int32_t q_start = input->query_start_loc[current_group_idx]; const int32_t q_token_start_idx = q_start + curr_output_token_idx; - const int32_t q_head_start_idx = kv_head_idx * actual_q_heads_per_kv; + const int32_t q_head_start_idx = + use_gqa_fast_path ? (kv_head_idx * q_heads_per_kv) : kv_head_idx; size_t output_buffer_offset = q_token_start_idx * q_head_num * head_dim + q_head_start_idx * head_dim; const int32_t stride = - actual_q_heads_per_kv * split_kv_q_token_num_threshold; + curr_q_heads_per_kv * split_kv_q_token_num_threshold; buffer_manager.update(kv_head_idx, total_reduction_split_num, head_dim, stride, sizeof(float)); volatile bool* split_flag_buffer = @@ -1852,7 +1885,7 @@ class AttentionMainLoop { final_output( split_output_buffer, reinterpret_cast(input->output) + output_buffer_offset, - split_sum_buffer, actual_q_heads_per_kv, curr_output_token_num, + split_sum_buffer, curr_q_heads_per_kv, curr_output_token_num, q_head_num, output_v_scale); } }