/*
* SPDX-FileCopyrightText: Copyright (c) 2011-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "fmha/alibi_params.h"
#include "fmha/hopper/fragment.h"
#include "fmha/hopper/utils_warpgroup.h"
#include "fmha/softmax.h"
#include "fmha/warpspec/circular_buffer.h"
#include "fmha/warpspec/dma.h"
#include "fmha/warpspec/epilogue.h"
namespace fmha
{
namespace ws
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// Template instruction traits to specialize structs
template <int, int, int, bool, bool> class Instruction_traits,
// Kernel Traits
typename Kernel_traits>
struct Compute
{
// The shared struct.
using Shared = typename Kernel_traits::Shared;
// The Q and KV tile readers.
using Circular_buffer_q_reader = typename Kernel_traits::Circular_buffer_q_reader;
using Circular_buffer_kv_reader = typename Kernel_traits::Circular_buffer_kv_reader;
// The instruction traits for BMM1.
using Traits_p = typename Kernel_traits::Traits_p;
// The instruction traits for BMM2.
using Traits_o = typename Kernel_traits::Traits_o;
// The CTA description for BMM1.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The CTA description for BMM2.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The Q shared memory tile.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The K shared memory tile.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The V shared memory tile.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The GMMA compute tile for BMM1.
using Compute_tile_p = typename Kernel_traits::Compute_tile_p;
// The GMMA compute tile for BMM2.
using Compute_tile_o = typename Kernel_traits::Compute_tile_o;
// The MMA tile for the BMM1.
using Mma_tile_p = typename Kernel_traits::Mma_tile_p;
// The MMA tile for the BMM2.
using Mma_tile_o = typename Kernel_traits::Mma_tile_o;
// The fragment of BMM1 output.
using Fragment_p = typename Compute_tile_o::Fragment;
// The global memory tile for storing BMM2 output.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// Softmax
using Softmax = Softmax<Instruction_traits, Kernel_traits>;
// BMM2 epilogue
using Tile_o_epilogue = Tile_o_epilogue<Instruction_traits, Kernel_traits>;
// The step size of Q loop.
enum
{
STEP_Q = Kernel_traits::STEP_Q
};
// The step size of KV loop.
enum
{
STEP_KV = Kernel_traits::STEP_KV
};
// The number of compute groups (currently fixed at 2).
enum
{
NUM_COMPUTE_GROUPS = Kernel_traits::NUM_COMPUTE_GROUPS
};
// Whether we skip fully masked tiles when the causal mask is enabled.
enum
{
SKIP_CAUSAL_MASK_TILES = Kernel_traits::CAUSAL_MASK && !Kernel_traits::USE_CUSTOM_MASK
};
// Whether we attend only to a specific sliding window or chunk.
enum
{
SLIDING_OR_CHUNKED_ATTENTION = Kernel_traits::SLIDING_OR_CHUNKED_ATTENTION
};
// Whether we apply the alibi bias (drops FMA optimizations for accuracy reasons).
enum
{
APPLY_ALIBI = Kernel_traits::APPLY_ALIBI
};
// Do we use a custom mask input?
enum
{
USE_CUSTOM_MASK = Kernel_traits::USE_CUSTOM_MASK
};
// Do we always need to apply the mask?
enum
{
ALWAYS_APPLY_MASK = APPLY_ALIBI || USE_CUSTOM_MASK
};
// Enable mutex for overlapping mma and softmax instructions.
enum
{
ENABLE_MUTEX = Kernel_traits::ENABLE_MUTEX
};
// The head_dimension groups.
enum
{
D_GROUPS = Kernel_traits::D_GROUPS
};
// The MMA_K groups (corresponding to head_dimension groups).
enum
{
BMM1_MMAS_K_GROUPS = Kernel_traits::D_GROUPS
};
// The number of MMAS_K for each head_dimension group.
enum
{
BMM1_MMAS_K_PER_GROUP = Mma_tile_p::MMAS_K / BMM1_MMAS_K_GROUPS
};
// The MMA_K groups (corresponding to kv_step groups).
enum
{
BMM2_MMAS_K_GROUPS = Kernel_traits::BMM2_K_GROUPS
};
// The number of MMAS_K for each BMM2 K group.
enum
{
BMM2_MMAS_K_PER_GROUP = Mma_tile_o::MMAS_K / BMM2_MMAS_K_GROUPS
};
// The tile size of V after head_dimension split.
enum
{
TILE_SIZE_V_PER_D_GROUP = STEP_KV * Kernel_traits::D_PER_GROUP
};
enum
{
TILE_SIZE_V = STEP_KV * Kernel_traits::DV
};
enum
{
TILE_BYTES_V_PER_D_GROUP = STEP_KV * Kernel_traits::D_BYTES_PER_GROUP
};
enum
{
TILE_BYTES_V_PER_K_GROUP = BMM2_MMAS_K_PER_GROUP * Kernel_traits::D_BYTES_PER_GROUP
};
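// Illustrative sizing (hypothetical values): with STEP_KV = 128, DV = 128 and D_PER_GROUP = 64,
// TILE_SIZE_V_PER_D_GROUP = 128 * 64 = 8192 and TILE_SIZE_V = 128 * 128 = 16384 elements per KV step;
// the *_BYTES_* variants are the corresponding sizes in bytes.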
// Named barrier for inter-warpgroup sync
enum
{
SYNC_BARRIER = Kernel_traits::MMA_SYNC_BARRIER_ID
};
// Whether Q and KV are in separate buffers, which means we need to handle different Q and KV lengths.
enum
{
SEPARATE_Q_KV_BUFFER = Kernel_traits::SEPARATE_Q_KV_BUFFER
};
enum
{
SAGE_BLOCK_SIZE_Q = Kernel_traits::SAGE_BLOCK_SIZE_Q
};
// Sanitize 0 to -1 to avoid a division by zero below.
enum
{
SAGE_BLOCK_SIZE_K = Kernel_traits::SAGE_BLOCK_SIZE_K > 0 ? Kernel_traits::SAGE_BLOCK_SIZE_K : -1
};
enum
{
SAGE_BLOCK_SIZE_V = Kernel_traits::SAGE_BLOCK_SIZE_V > 0 ? Kernel_traits::SAGE_BLOCK_SIZE_V : -1
};
// SAGE_BLOCK_SIZE_Q should be a multiple of STEP_Q (usually 64) so that the q scale can be fused into scale_bmm1.
static_assert(SAGE_BLOCK_SIZE_Q < 0 || SAGE_BLOCK_SIZE_Q % STEP_Q == 0);
static_assert(SAGE_BLOCK_SIZE_K < 0 || SAGE_BLOCK_SIZE_K % 8 == 0); // 8 = columns of a gmma CORE
static_assert(SAGE_BLOCK_SIZE_V < 0 || SAGE_BLOCK_SIZE_V % 32 == 0); // 32 = K dimension of a qgmma
// SAGE_BLOCKS_PER_STEP_X is used to declare scale buffers like `float scales_k[SAGE_BLOCKS_PER_STEP_K];`.
// If SAGE_BLOCKS_PER_STEP_X == 0, nvcc emits a `zero-sized variable is not allowed in device code`
// error, so the minimum value has to be 1. Don't worry: unused local variables will be optimized
// out by the compiler.
enum
{
SAGE_BLOCKS_PER_STEP_K = std::max(STEP_KV / SAGE_BLOCK_SIZE_K, 1)
};
enum
{
SAGE_BLOCKS_PER_STEP_V = std::max(STEP_KV / SAGE_BLOCK_SIZE_V, 1)
};
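// Illustrative sizing (hypothetical values): with STEP_KV = 256 and SAGE_BLOCK_SIZE_K = 64,
// SAGE_BLOCKS_PER_STEP_K = 4, i.e. four K scales are consumed per KV step. With SageAttention
// disabled the block size is sanitized to -1, the division goes negative, and the max() clamps the
// count to 1 so the scale arrays below remain declarable.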
#define K_TILE_WAIT() \
int ready_k = cbr_k.peek(); \
if (!ready_k) \
{ \
cbr_k.wait(); \
}
#define KV_TILE_COMPLETE() \
cbr_k.complete(tidx == 0, cbr_k.ptr()); \
cbr_v.complete(tidx == 0, cbr_v.ptr()); \
cbr_k.advance(); \
cbr_v.advance();
#define COMPUTE_SINGLE_TILE(IS_FIRST_COL, APPLY_MASK) \
compute_single_tile<IS_FIRST_COL, APPLY_MASK>(params, ctile_p, softmax, ctile_o, p_max, p_sum, tidx, \
actual_kv_seqlen, alibi_head_scale, \
USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
: (q_step_idx * STEP_Q + head_info.q_tile_offset), \
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
&shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);
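// These helpers implement the consumer side of the DMA/compute circular-buffer handshake: K_TILE_WAIT
// spins only if peek() says the next K tile has not landed in smem yet, KV_TILE_COMPLETE releases the
// current K/V buffers back to the DMA group (complete() is predicated on tidx == 0 so a single thread
// signals) and advances both readers, and COMPUTE_SINGLE_TILE forwards the per-tile state to
// compute_single_tile below.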
////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int div_up(int a, int b)
{
return (a + b - 1) / b;
}
////////////////////////////////////////////////////////////////////////////////////////////////
// Compute the kv_left_mask_end and kv_right_mask_start, where mask is applied when kv_idx < kv_left_mask_end or
// kv_idx >= kv_right_mask_start.
template <typename Params>
inline __device__ std::pair<int, int> compute_kv_mask_start_end(
Params const& params, int const tile_offset_start, int const tile_offset_end, int const kv_idx_end)
{
// The kv_left_mask_end is 0 by default.
int kv_left_mask_end = 0;
// The kv_right_mask_start is kv_idx_end - 1 by default, which means only the last kv tile is masked.
int kv_right_mask_start = kv_idx_end - 1;
// If the mask always needs to be applied (alibi or custom mask), mask every kv tile.
if constexpr (ALWAYS_APPLY_MASK)
{
return std::make_pair(0, 0);
}
// Is chunked attention used?
bool is_chunked_attention = params.log2_chunked_attention_size > 0;
// The left mask is needed when we attend to a specific sliding window or chunk.
if constexpr (SLIDING_OR_CHUNKED_ATTENTION)
{
// The kv_left_mask_end is the first kv tile that starts at or after the chunk start (or the sliding window start).
kv_left_mask_end = div_up(is_chunked_attention
? ((tile_offset_end >> params.log2_chunked_attention_size) << params.log2_chunked_attention_size)
: (tile_offset_end + 1 - params.sliding_window_size),
STEP_KV);
}
// The right mask is needed when causal mask (including sliding_window_attention or chunked attention) is used.
if constexpr (SKIP_CAUSAL_MASK_TILES)
{
kv_right_mask_start = tile_offset_start / STEP_KV;
}
// Return the kv_left_mask_end and kv_right_mask_start.
return std::make_pair(kv_left_mask_end, kv_right_mask_start);
}
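// Worked example (illustrative numbers): with causal masking, STEP_KV = 128 and a Q tile covering
// rows [256, 384), tile_offset_start = 256 gives kv_right_mask_start = 256 / 128 = 2, so kv tiles 0
// and 1 need no mask while tiles >= 2 are masked (or skipped entirely via the kv_idx_end computed in
// the DMA group). If a sliding window of 256 is also enabled, kv_left_mask_end = div_up(384 - 256, 128) = 1,
// so kv tile 0 additionally gets the left mask and only tile 1 runs unmasked.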
////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Params>
inline __device__ void run(int warpgroup_id, int tidx, Shared* shared, Params const& params)
{
auto head_tracker = shared->head_info_tracker[warpgroup_id].createReader();
auto cbr = shared->tma_q_tracker[warpgroup_id].createReader();
auto cbr_k = shared->tma_k_tracker.createReader();
auto cbr_v = shared->tma_v_tracker.createReader();
// Ctile_p initialize (relies on q_stage, kv_stage).
char* smem_q = reinterpret_cast<char*>(&shared->smem_q[warpgroup_id][0]);
char* smem_k = reinterpret_cast<char*>(&shared->smem_k[0]);
Compute_tile_p ctile_p(smem_q, smem_k);
// Softmax
Softmax softmax(params, tidx);
// Ctile_o initialize (relies on kv_stage).
uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]);
Compute_tile_o ctile_o(0, smem_v);
// Mutex between two compute groups.
OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER);
// Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions).
if (ENABLE_MUTEX && warpgroup_id == 1 && Kernel_traits::ELEMENT_BYTES == 2)
{
mutex_accessor.arrive();
}
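// A sketch of the intended ping-pong (based on the arrive/wait pairs below): for 2-byte element kernels,
// warpgroup 1 arrives first, so warpgroup 0's first mutex.wait() in compute_single_tile returns
// immediately and it issues its HGMMA while warpgroup 1 runs softmax math; the two compute groups then
// alternate arrive()/wait() so tensor-core and softmax instructions overlap.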
// While loop for different heads.
while (true)
{
typename Shared::Head_info head_info = head_tracker.pop(true);
if (head_info.kv_steps == -1)
{
break;
}
int const kv_steps = head_info.kv_steps;
int const q_steps = head_info.q_steps;
int const local_q_tile_offset = head_info.local_q_tile_offset;
// The global q tile offset (based on past kv cache).
// Not used by custom mask input.
int const q_tile_offset = SEPARATE_Q_KV_BUFFER ? head_info.q_tile_offset : head_info.local_q_tile_offset;
int const actual_q_seqlen = head_info.actual_seqlen;
// Contiguous QKV FMHA assumes Q and KV have the same sequence length.
int const actual_kv_seqlen = SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;
// Update threshold of Skip-Softmax
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX)
{
softmax.skip_softmax_threshold = params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
}
// Calculate the alibi head_scaling_factor.
float alibi_head_scale
= APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(head_info.bidh, params.alibi_params) : 0.f;
// Pre-compute the row index into the SageAttention scale tensors for reuse.
int sage_scale_row;
if constexpr (Kernel_traits::SAGE_ATTENTION)
{
sage_scale_row = head_info.bidb * params.h + head_info.bidh;
}
// BMM2 epilogue
Tile_o_epilogue tile_o_epilogue(params, head_info);
int q_step_idx = warpgroup_id;
// Compute work.
for (; q_step_idx < q_steps; q_step_idx += NUM_COMPUTE_GROUPS)
{
// Check whether it is a valid run of q steps.
int const q_offset = q_step_idx * STEP_Q + local_q_tile_offset;
bool const valid_run = q_offset < actual_q_seqlen;
// Fuse the scale of Q into scale_bmm1.
if constexpr (SAGE_BLOCK_SIZE_Q > 0)
{
// I tried another implementation here: store the original `scale_bmm1` in a local variable
// to avoid frequent `__ldg` calls. But experiments show that the current one is faster,
// which is a bit counterintuitive.
auto const scale_bmm1 = params.scale_bmm1_d ? __ldg(params.scale_bmm1_d) : params.scale_bmm1;
int const idx = sage_scale_row * params.sage.q.max_nblock + q_offset / SAGE_BLOCK_SIZE_Q;
*(float*) (&softmax.scale_bmm1_)
= reinterpret_cast<float const&>(scale_bmm1) * __ldg(&params.sage.q.scales[idx]);
}
// A KV tile is shared by two Q tiles,
// so we need to take the last compute group's Q tile into account.
int const tile_offset_start = q_step_idx * STEP_Q + q_tile_offset;
int const tile_offset_end = tile_offset_start + STEP_Q - 1;
int const warpgroup_tile_offset_start = tile_offset_start - warpgroup_id * STEP_Q;
int const warpgroup_tile_offset_end
= tile_offset_start + (NUM_COMPUTE_GROUPS - warpgroup_id) * STEP_Q - 1;
// Compute the kv_idx start (inclusive) and end (exclusive).
auto const [kv_idx_start, kv_idx_end] = DMA<Kernel_traits>::Device::compute_kv_tile_idx(
params, warpgroup_tile_offset_start, warpgroup_tile_offset_end, kv_steps);
// Compute the kv_left_mask_end and kv_right_mask_start, where mask is applied when kv_idx <
// kv_left_mask_end or kv_idx >= kv_right_mask_start.
auto const [kv_left_mask_end, kv_right_mask_start]
= compute_kv_mask_start_end(params, tile_offset_start, tile_offset_end, kv_idx_end);
// The gmem O tile.
Gmem_tile_o gmem_o(params, head_info, *shared, tidx, q_step_idx * STEP_Q + local_q_tile_offset);
// Q ready to use in smem.
int ready = cbr.peek();
if (!ready)
{
cbr.wait();
}
static_assert(Mma_tile_p::CORES_M == 2);
float p_max[Mma_tile_p::CORES_M];
float p_sum[Mma_tile_p::CORES_M];
int kv_step_idx = kv_idx_start;
// First K tiles ready to use in smem.
K_TILE_WAIT();
// Apply the mask if the first kv tile falls in the left- or right-masked region (e.g. when only one kv tile exists).
if (kv_idx_start < kv_left_mask_end || kv_idx_start >= kv_right_mask_start)
{
COMPUTE_SINGLE_TILE(true, true);
}
else
{
COMPUTE_SINGLE_TILE(true, false);
}
KV_TILE_COMPLETE();
for (kv_step_idx += 1; kv_step_idx < kv_right_mask_start; ++kv_step_idx)
{
// Current step's K tiles ready to use in smem.
K_TILE_WAIT();
// Move kv tile to next buffer.
if (D_GROUPS > 1)
{
ctile_p.increment_gmma_desc_group();
}
else
{
ctile_p.increment_gmma_desc_b_group();
}
ctile_o.increment_gmma_desc_group();
// Apply the left mask only when sliding-window or chunked attention is enabled.
if (kv_step_idx < kv_left_mask_end)
{
COMPUTE_SINGLE_TILE(false, true);
}
else
{
COMPUTE_SINGLE_TILE(false, false);
}
KV_TILE_COMPLETE();
}
// Always apply the mask to the trailing (right-masked) kv tiles.
for (; kv_step_idx < kv_idx_end; ++kv_step_idx)
{
// Current step's K tiles ready to use in smem.
K_TILE_WAIT();
// Move kv tile to next buffer.
if (D_GROUPS > 1)
{
ctile_p.increment_gmma_desc_group();
}
else
{
ctile_p.increment_gmma_desc_b_group();
}
ctile_o.increment_gmma_desc_group();
COMPUTE_SINGLE_TILE(false, true);
KV_TILE_COMPLETE();
}
if (valid_run)
{
// Final step's update.
tile_o_epilogue.scale(ctile_o, p_max, p_sum);
// Store o_tile to gmem.
gmem_o.store(ctile_o.acc_);
}
// Move q, kv to next buffer.
ctile_p.increment_gmma_desc_a_group();
ctile_p.increment_gmma_desc_b_group();
ctile_o.increment_gmma_desc_group();
if constexpr (Kernel_traits::RETURN_SOFTMAX_STATS)
{
using Mma_tile = typename Traits_p::template Mma_tile<Cta_tile_o>;
fmha::Softmax_saver_tma<Cta_tile_o, Mma_tile> saver(params, head_info);
saver.store(p_sum, p_max, sqrtf(params.d), q_step_idx * STEP_Q, valid_run);
}
}
}
#ifdef SKIP_SOFTMAX_STAT
if (tidx == 0)
{
atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////
template <bool IS_FIRST_COL, bool APPLY_MASK, typename Params>
inline __device__ void compute_single_tile(Params params, Compute_tile_p& ctile_p, Softmax& softmax,
Compute_tile_o& ctile_o, float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M],
int const tidx, int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, Circular_buffer_kv_reader& cbr_v,
OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, bool complete = false)
{
// Skip-softmax vote initialization
if (tidx == 0)
{
// Note that the named_barrier_wait below makes sure this initialization happens before any thread votes.
*skip_softmax_vote = 1;
}
// load the scales of K/V from global memory
#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
if constexpr (block_size > 0) \
{ \
const int _start = col_offset / block_size; \
const float* _src = params.sage.which.scales + sage_scale_row * params.sage.which.max_nblock + _start; \
const int _end = params.sage.which.max_nblock - _start; \
_Pragma("unroll") for (int _i = 0; _i < blocks_per_step; _i++) \
{ \
dst[_i] = _i < _end ? _src[_i] : 1.0f; \
} \
}
#define LOAD_SCALES_K(scales) LOAD_SCALES_KV(scales, k, SAGE_BLOCKS_PER_STEP_K, SAGE_BLOCK_SIZE_K)
#define LOAD_SCALES_V(scales) LOAD_SCALES_KV(scales, v, SAGE_BLOCKS_PER_STEP_V, SAGE_BLOCK_SIZE_V)
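// Illustrative expansion (hypothetical values): with col_offset = 256, SAGE_BLOCK_SIZE_K = 64 and
// SAGE_BLOCKS_PER_STEP_K = 4, LOAD_SCALES_K fetches the K scales for blocks 4..7 of this (batch, head)
// row; blocks past max_nblock fall back to the identity scale 1.0f.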
// Load the needed packed masks.
softmax.load_packed_mask(row_offset, col_offset);
// Experiments show that this is the best place to load the scales of K.
float scales_k[SAGE_BLOCKS_PER_STEP_K];
LOAD_SCALES_K(scales_k)
// Wait until another warpgroup has already executed HGMMA.
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 2)
{
mutex.wait();
}
// Ctile_p is used only once per kv step.
ctile_p.clear();
// If skip_softmax is enabled, make sure there is no race between the initialization and the writing of
// skip_softmax_vote.
named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);
// BMM1 (Q x K').
warpgroup_arrive();
// Only a single K group when sizeof(D) <= 128B.
#pragma unroll
for (int kbi = 0; kbi < BMM1_MMAS_K_GROUPS - 1; kbi++)
{
#pragma unroll
for (int ki = 0; ki < BMM1_MMAS_K_PER_GROUP; ki++)
{
ctile_p.compute(ki, false, ki == BMM1_MMAS_K_PER_GROUP - 1);
}
ctile_p.increment_gmma_desc_group();
}
#pragma unroll
for (int ki = 0; ki < BMM1_MMAS_K_PER_GROUP - 1; ki++)
{
ctile_p.compute(ki);
}
ctile_p.compute(BMM1_MMAS_K_PER_GROUP - 1, true, true);
warpgroup_commit();
warpgroup_wait<0>();
// Arrive when the last tile consumes the q tile.
if (complete)
{
cbr.complete(tidx == 0, cbr.ptr());
cbr.advance();
}
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 2)
{
// Notify another warpgroup to execute HGMMA.
mutex.arrive();
}
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
{
// Wait until another warpgroup has already executed QGMMA.
mutex.named_bar_wait();
}
// Fragment p for BMM2 input
Fragment_p frag_p[Mma_tile_o::MMAS_K];
// Unpack the elements from bmm1 output to floats.
softmax.unpack(ctile_p);
// apply the scales of K before softmax
if constexpr (SAGE_BLOCK_SIZE_K > 0)
{
#pragma unroll
for (int ni = 0; ni < Mma_tile_p::CORES_N; ni++)
{
float const scale_k = scales_k[SAGE_BLOCKS_PER_STEP_K * ni / Mma_tile_p::CORES_N];
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::CORES_M; mi++)
{
softmax.elt_[mi][2 * ni] *= scale_k;
softmax.elt_[mi][2 * ni + 1] *= scale_k;
}
}
}
// Apply the alibi and mask.
softmax.apply_alibi_and_mask<APPLY_MASK>(
ctile_p, params.alibi_params, alibi_head_scale, actual_kv_seqlen, row_offset, col_offset);
// Softmax exp, max/sum, and scale updates. If it returns false, we skip the rest of this tile.
if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote))
{
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
{
// Notify another warpgroup to execute QGMMA.
mutex.named_bar_arrive();
}
// Need to wait for V, otherwise the compute-sanitizer synccheck will fail.
int ready2 = cbr_v.peek();
if (!ready2)
{
cbr_v.wait();
}
return;
}
// Experiments show that this is the best place to load the scales of V.
float scales_v[SAGE_BLOCKS_PER_STEP_V];
LOAD_SCALES_V(scales_v)
// Update the flash attention scales and pack the result for BMM2.
softmax.pack<IS_FIRST_COL>(ctile_o, frag_p);
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1)
{
// Notify another warpgroup to execute QGMMA.
mutex.named_bar_arrive();
}
// Wait until v buffer is ready.
int ready = cbr_v.peek();
if (!ready)
{
cbr_v.wait();
}
warpgroup_arrive();
float last_scale_v;
// Apply the scale of V to the partial result.
// Note two points:
// 1. Because the matrix V is quantized along the inner dimension, it is necessary to interrupt
// the MMA workflow after processing each BLOCK_SIZE_V rows of V and scale the intermediate
// results once. For example, with STEP_KV=256 and qgmma.K=32, 256/32=8 MMAs are needed,
// so mma_ki = [0, 1, 2, ..., 7]. If BLOCK_SIZE_V=64, then after every 2 qgmmas we should scale
// ctile_o.
// 2. The ctile_o is all zero at the beginning. If we directly apply the scale of V after every 2
// qgmmas, let's see what happens:
// ctile_o = [0]
// ctile_o = (ctile_o + P0 x V0) * s0 = P0 x V0 * s0
// ctile_o = (ctile_o + P1 x V1) * s1 = P0 x V0 * s0 * s1 + P1 x V1 * s1
// ctile_o = (ctile_o + P2 x V2) * s2 = P0 x V0 * s0 * s1 * s2 + P1 x V1 * s1 * s2 + P2 x V2 * s2
// ...
// As you can see, the effective scale of a V block is the cumulative product of its own scale and
// the scales of all later blocks. To solve this, we have to preprocess the scale s[i] of block[i]
// into s[i]/s[i+1], with the final block using its actual scale.
// But fetching the next scale in the next STEP leads to bad performance, so we instead apply
// s[i-1]/s[i] to the current partial result BEFORE each V block.
#define APPLY_SCALE_V(mma_ki) \
if constexpr (SAGE_BLOCK_SIZE_V > 0) \
{ \
if (mma_ki % (Mma_tile_o::MMAS_K / SAGE_BLOCKS_PER_STEP_V) == 0) \
{ \
float _scale_v = scales_v[SAGE_BLOCKS_PER_STEP_V * mma_ki / Mma_tile_o::MMAS_K]; \
if (mma_ki != 0) \
{ \
warpgroup_commit(); \
warpgroup_wait<0>(); \
} \
last_scale_v = _scale_v; \
} \
}
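// Firing pattern (illustrative numbers): with Mma_tile_o::MMAS_K = 8 and SAGE_BLOCKS_PER_STEP_V = 4,
// the check in APPLY_SCALE_V above fires every 2 qgmmas (mma_ki = 0, 2, 4, 6); at every boundary except
// the first it commits and drains the in-flight GMMA group before picking up the scale of the next
// V block (tracked in last_scale_v).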
// BMM2 (S * V).
#pragma unroll
for (int kbi = 0; kbi < BMM2_MMAS_K_GROUPS - 1; kbi++)
{
#pragma unroll
for (int ki = 0; ki < BMM2_MMAS_K_PER_GROUP; ++ki)
{
int const mma_ki = kbi * BMM2_MMAS_K_PER_GROUP + ki;
APPLY_SCALE_V(mma_ki)
ctile_o.fill_frag_a(frag_p[mma_ki]);
ctile_o.compute(ki, false, ki == BMM2_MMAS_K_PER_GROUP - 1);
}
ctile_o.increment_gmma_desc_group();
}
#pragma unroll
for (int ki = 0; ki < BMM2_MMAS_K_PER_GROUP - 1; ++ki)
{
int const mma_ki = (BMM2_MMAS_K_GROUPS - 1) * BMM2_MMAS_K_PER_GROUP + ki;
APPLY_SCALE_V(mma_ki)
ctile_o.fill_frag_a(frag_p[mma_ki]);
ctile_o.compute(ki);
}
APPLY_SCALE_V((Mma_tile_o::MMAS_K - 1))
ctile_o.fill_frag_a(frag_p[Mma_tile_o::MMAS_K - 1]);
ctile_o.compute(Mma_tile_o::MMAS_K - 1, true, true);
warpgroup_commit();
warpgroup_wait<0>();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace ws
} // namespace fmha