update fmha_v2 (#4895)

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
qsang-nv 2025-06-05 22:14:28 +08:00 committed by GitHub
parent 51652b9b2b
commit 180b91f957
7 changed files with 219 additions and 75 deletions

View File

@@ -157,7 +157,7 @@ def test_trtllm_context_mla_attention_fmha(dtype, s):
         epsilon += ' -epsilon 0.03'
     sm_version = getSMVersion()
-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
+    if sm_version != 89:
         pytest.skip("FP8 MLAs only supported on sm89 currently.")
     # Context phase kernels.

View File

@@ -189,8 +189,7 @@ namespace kernels
 ns_close = r"""
 // clang-format on
 } // namespace kernels
-} // namespace tensorrt_llm
-""" if generate_cu_trtllm else ""
+} // namespace tensorrt_llm""" if generate_cu_trtllm else ""

 copyright = '''\
 /***************************************************************************************************
@@ -1344,7 +1343,7 @@ void {sliding_or_chunked_causal_kernel_name}_nl({params_type} params){{
 #endif // sliding_or_chunked_causal_mask
-void {launcher_name}_nl({params_type} &params,
+void {launcher_name}_nl({fused_multihead_attention_params_v2_str} &params,
     const Launch_params& launch_params, cudaStream_t stream){{
     constexpr int loop_iters = {seq_len} / {noloop_step};
     static_assert(loop_iters * {noloop_step} == {seq_len}, "");
@@ -1431,6 +1430,7 @@ using Ktraits = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1453,6 +1453,7 @@ using Ktraits_causal = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1472,6 +1473,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1491,6 +1493,7 @@ using Ktraits_custom_mask = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -2881,6 +2884,7 @@ def get_kernel_traits_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -3213,7 +3217,7 @@ def get_cubin_header(kernel_traits, specs_names):
             return 'nullptr'
         lname = kname.replace('_kernel', '')
         mask_types = [
-            '_sliding_window_causal', '_custom_mask', '_causal'
+            '_sliding_or_chunked_causal', '_custom_mask', '_causal'
         ]
         for mask_type in mask_types:
             lname = lname.replace(mask_type, '')
@@ -3228,6 +3232,12 @@ def get_cubin_header(kernel_traits, specs_names):
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
 {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
+'''.format(**locals()) if 'sage' in kname and 'sm90' in kname else '''\
+{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
+{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
+0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
+{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
 '''.format(**locals())
         else:
             code = '''\
@@ -3332,7 +3342,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
 {metadata_v2}
 }};
-{local_ns_close}
 '''.format(**locals(), copyright=copyright)
     else:
@@ -3540,7 +3549,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):
 # Note this will be used in TRT-LLM.
-def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
+def enumerate_hgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='fp16',
+                                           head_size_v=0):
     scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
@@ -3563,6 +3575,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0, # support any sequence length
             head_size=[32, 40, 48, 64],
+            head_size_v=head_size_v,
             warps_m=4, #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3595,6 +3608,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0, # support any sequence length
             head_size=[72, 80, 96, 104, 128],
+            head_size_v=head_size_v,
             warps_m=4, #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3627,6 +3641,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0, # support any sequence length
             head_size=[160, 192, 256],
+            head_size_v=head_size_v,
             warps_m=4, #4x1 warpgroups
             warps_n=1,
             version=2,
@ -3652,6 +3667,40 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
scheduling_mode=scheduling_mode, scheduling_mode=scheduling_mode,
input_layout=input_layout)) input_layout=input_layout))
# for deepseek context 192/128, kv_step=128
specs.append(
kernel_spec(
sm=sm,
sm_mma=90,
dtype=dtype,
seq_len=0, # support any sequence length
head_size=192,
head_size_v=128,
warps_m=4, #4x1 warpgroups
warps_n=1,
version=2,
interleaved=False,
ldgsts_q=
False, # for Hopper kernels, ldgsts = False signals TMA usage.
ldgsts_k=False,
ldgsts_v=False,
share_smem_k_v=False,
loop_step=64,
q_tile_buffers=1, # only used by warp specialized kernels
has_noloop=0,
noloop_step=64,
kv_loop_step=128,
kv_tile_buffers=2, # only used by warp specialized kernels
unroll_threshold=1,
has_scale_max=False,
flash_attention=True,
warp_specialization=True,
alibi=alibi,
enable_attn_logit_softcapping=enable_attn_logit_softcapping,
return_softmax_stats=return_softmax,
scheduling_mode=scheduling_mode,
input_layout=input_layout))
# Note this will be used in TRT-LLM. # Note this will be used in TRT-LLM.
def enumerate_qgmma_flash_warpspec_kernels(specs, def enumerate_qgmma_flash_warpspec_kernels(specs,
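The new spec caps kv_loop_step at 128 for the 192/128 head sizes. A rough back-of-the-envelope on the K/V tile footprint (a sketch that only mirrors the Buffer_k_t/Buffer_v_t sizing from the kernel_traits change further down, ignoring Q, output and barrier storage; the padding assumption is Next_power_of_two(192) = 256) suggests why a larger KV step would not fit:

// Illustrative sketch, not part of the commit: rough K/V smem footprint for the
// Deepseek context 192/128 bf16 spec (kv_loop_step=128, kv_tile_buffers=2),
// assuming head_size 192 is padded to 256 as in the kernel_traits change.
#include <cstdio>

int main()
{
    constexpr int kElemBytes = 2;  // bf16
    constexpr int kStepKV = 128;   // kv_loop_step
    constexpr int kKVBufs = 2;     // kv_tile_buffers
    constexpr int kD = 256;        // head_size 192 padded to the next power of two
    constexpr int kDV = 128;       // head_size_v

    constexpr int k_bytes = kD * kStepKV * kKVBufs * kElemBytes;  // 128 KiB
    constexpr int v_bytes = kDV * kStepKV * kKVBufs * kElemBytes; // 64 KiB
    std::printf("K tiles: %d KiB, V tiles: %d KiB, K+V: %d KiB\n",
        k_bytes / 1024, v_bytes / 1024, (k_bytes + v_bytes) / 1024);
    // K+V alone is ~192 KiB of the roughly 227 KB of dynamic smem a CTA can use on SM90,
    // so a kv_loop_step of 256 (which would double both tiles) could not fit; hence the
    // explicit kv_loop_step=128 in the spec above.
    return 0;
}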
@@ -6215,7 +6264,21 @@ def enumerate_kernels():
                 and kspec.cross_mha == False
                 and kspec.flash_attention == True
                 and kspec.warp_specialization == False
-                and kspec.tiled == True)
+                and kspec.tiled == True
+                and not (kspec.sm == 90 and (kspec.head_size, kspec.head_size_v) == (192, 128)))
+            # Deepseek MLA (hopper-style context 192/128 packed + paged)
+            or (kspec.sm == 90
+                and kspec.dtype == 'bf16'
+                and kspec.head_size == 192
+                and kspec.head_size_v == 128
+                and kspec.sage_block_sizes is None
+                and kspec.version == 2
+                and kspec.cross_mha == False
+                and kspec.flash_attention == True
+                and kspec.warp_specialization == True
+                and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV]
+                and kspec.alibi == False
+                and kspec.enable_attn_logit_softcapping == False)
             # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
             or (kspec.sm == 90
                 and kspec.head_size in [80, 128]

View File

@@ -173,7 +173,7 @@ struct Compute
     enum
     {
-        TILE_SIZE_V = STEP_KV * Kernel_traits::D
+        TILE_SIZE_V = STEP_KV * Kernel_traits::DV
     };
     enum

View File

@@ -76,7 +76,7 @@ struct DMA
     // The tile size of V.
     enum
     {
-        TILE_SIZE_V = TILE_SIZE_K
+        TILE_SIZE_V = STEP_KV * Kernel_traits::DV
    };
    // The tile size of V after head_dimension split.
@@ -280,6 +280,7 @@ struct DMA
         cudaTmaDesc const* desc_q = &params.tma_desc_q;
         cudaTmaDesc const* desc_kv = &params.tma_desc_kv;
+        cudaTmaDesc const* desc_v = &params.tma_desc_v;
         int actual_seqlen;
         if (params.is_s_padded)
         {
@@ -342,8 +343,8 @@ struct DMA
             // Iterate over the kv tiles for this q step.
             for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
             {
-                int bar_id = load_kv(bidh, params.h, params.h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v,
-                    cbw_v_scratch, cbr_v_scratch);
+                int bar_id = load_kv(bidh, params.h, params.h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k,
+                    cbw_v, cbw_v_scratch, cbr_v_scratch);
                 // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor
                 if (q_step_idx == 0 && kv_step_idx == kv_idx_start)
@@ -511,7 +512,17 @@ struct DMA
         int32_t const* paged_block_offsets
             = params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
         cudaTmaDesc const* desc_kv = &params.tma_desc_kv;
+        // If a separate v_stride_in_bytes is set, we have to use separate tma_desc_v,
+        // otherwise share with tma_desc_kv.
+        // This is for the compatibility that TensorRT-LLM needs no modification if padding V to 192.
+#ifndef GENERATE_CUBIN
+        cudaTmaDesc const* desc_v
+            = (params.v_stride_in_bytes == 0 || params.v_stride_in_bytes == params.kv_stride_in_bytes)
+            ? desc_kv
+            : &params.tma_desc_v;
+#else
+        cudaTmaDesc const* desc_v = desc_kv;
+#endif
         if (SCHEDULING_MODE == 0)
         {
             // split work across M
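The fallback above keeps existing integrations working: a separate V descriptor is only required when V rows are genuinely narrower than K rows. Below is a minimal host-side sketch of the same selection rule (not part of the commit; the concrete strides are hypothetical values chosen to mimic an unpadded 192/128 layout with 64 tokens per block):

// Illustrative sketch, not part of the commit: when does V need its own TMA descriptor?
#include <cstdint>
#include <cstdio>

static bool needs_separate_v_desc(uint64_t v_stride_in_bytes, uint64_t kv_stride_in_bytes)
{
    // v_stride == 0 means "V is laid out exactly like K" (e.g. V padded to the K head
    // size), so tma_desc_kv can be reused for V loads.
    return v_stride_in_bytes != 0 && v_stride_in_bytes != kv_stride_in_bytes;
}

int main()
{
    // Hypothetical bf16 paged-KV block with 64 tokens per block:
    // K rows are 192 elements wide, V rows are 128 elements wide.
    uint64_t const kv_stride = 192ull * 64 * 2; // bytes per head slab of K in a block
    uint64_t const v_stride = 128ull * 64 * 2;  // bytes per head slab of V in a block
    std::printf("192/128 unpadded: separate desc_v = %d\n", needs_separate_v_desc(v_stride, kv_stride));
    std::printf("V padded to 192 : separate desc_v = %d\n", needs_separate_v_desc(0, kv_stride));
    return 0;
}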
@@ -575,7 +586,8 @@ struct DMA
                 bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks,
                     params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load,
                     params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq,
-                    paged_block_offsets, desc_kv, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch);
+                    paged_block_offsets, desc_kv, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch,
+                    cbr_v_scratch);
             }
             else
             {
@@ -670,7 +682,7 @@ struct DMA
     // Load k,v tiles from gmem to smem by TMA.
     template <typename BufferWriter>
     inline __device__ void load_kv_impl(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv,
-        Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v)
+        cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v)
     {
         int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES);
@@ -685,8 +697,6 @@ struct DMA
         // split D into multiple groups in order to satisfy the TMA 128B sizzle mode
         int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh;
         int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1;
-        int32_t const v_coord_dim1 = HEADS_INTERLEAVED ? 2 : bidh;
-        int32_t const v_coord_dim2 = HEADS_INTERLEAVED ? bidh : 2;
 #pragma unroll
         for (int di = 0; di < Kernel_traits::D_GROUPS; ++di)
@@ -699,12 +709,14 @@ struct DMA
                 __cvta_generic_to_shared(
                     &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP]),
                 __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_);
+        }
+#pragma unroll
+        for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di)
+        {
             int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP,
-                multi_query_attention_ ? h + h_kv + bidh / (h / h_kv) : v_coord_dim1,
-                multi_query_attention_ ? 0 : v_coord_dim2, sum_s_q_ + kv_step_idx * STEP_KV};
+                multi_query_attention_ ? bidh / (h / h_kv) : bidh, 0, sum_s_q_ + kv_step_idx * STEP_KV};
-            fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv,
+            fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v,
                 __cvta_generic_to_shared(
                     &shared->smem_v[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]),
                 __cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_);
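Because V now has its own descriptor whose head dimension is sized for the KV heads, the head coordinate in v_coords reduces to the KV head index. A small sketch of that mapping (not part of the commit; the head counts are hypothetical GQA values):

// Illustrative sketch, not part of the commit: with a dedicated V descriptor the head
// coordinate is simply the KV head index, i.e. the query head divided by the GQA group size.
#include <cstdio>

int main()
{
    int const h = 32;    // hypothetical number of query heads
    int const h_kv = 8;  // hypothetical number of KV heads (GQA)
    for (int bidh : {0, 3, 4, 31})
    {
        int const v_head = bidh / (h / h_kv); // matches the new v_coords computation
        std::printf("query head %2d -> V head %d\n", bidh, v_head);
    }
    // Before this change, V shared the packed-QKV descriptor, so the coordinate had to
    // carry the Q/K head offset (h + h_kv + ...) to land on the V section of the row.
    return 0;
}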
@ -748,8 +760,8 @@ struct DMA
template <typename BufferWriter> template <typename BufferWriter>
inline __device__ void load_paged_kv_impl(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks, inline __device__ void load_paged_kv_impl(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks,
int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2, int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2,
int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, Shared* shared, int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv,
BufferWriter& cbw_k, BufferWriter& cbw_v) cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v)
{ {
int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES); int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES);
@ -783,11 +795,14 @@ struct DMA
__cvta_generic_to_shared(&shared->smem_k[k_barrier_id * TILE_SIZE_K __cvta_generic_to_shared(&shared->smem_k[k_barrier_id * TILE_SIZE_K
+ di * TILE_SIZE_K_PER_D_GROUP + bi * tile_size_k_per_block]), + di * TILE_SIZE_K_PER_D_GROUP + bi * tile_size_k_per_block]),
__cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_);
}
#pragma unroll
for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di)
{
int32_t const v_coords[4] int32_t const v_coords[4]
= {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh, v_paged_block_offset}; = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh, v_paged_block_offset};
fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv, fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v,
__cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V
+ di * TILE_SIZE_V_PER_D_GROUP + bi * tile_size_k_per_block]), + di * TILE_SIZE_V_PER_D_GROUP + bi * tile_size_k_per_block]),
__cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_); __cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_);
@ -877,8 +892,8 @@ struct DMA
// Load k,v tiles from gmem to smem by TMA. // Load k,v tiles from gmem to smem by TMA.
template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch> template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch>
inline __device__ int load_kv_transpose_v_impl(int bidh, int h, int h_kv, int kv_step_idx, inline __device__ int load_kv_transpose_v_impl(int bidh, int h, int h_kv, int kv_step_idx,
cudaTmaDesc const* desc_kv, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, cudaTmaDesc const* desc_kv, cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k,
BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch) BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
{ {
int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES); int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES);
@ -890,8 +905,6 @@ struct DMA
// split D into multiple groups in order to satisfy the TMA 128B sizzle mode // split D into multiple groups in order to satisfy the TMA 128B sizzle mode
int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh; int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh;
int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1; int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1;
int32_t const v_coord_dim1 = HEADS_INTERLEAVED ? 2 : bidh;
int32_t const v_coord_dim2 = HEADS_INTERLEAVED ? bidh : 2;
#pragma unroll #pragma unroll
for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) for (int di = 0; di < Kernel_traits::D_GROUPS; ++di)
@ -910,13 +923,12 @@ struct DMA
= cbw_v_scratch.tmaReserve(elect_one_, (TILE_SIZE_V) *Kernel_traits::ELEMENT_BYTES); = cbw_v_scratch.tmaReserve(elect_one_, (TILE_SIZE_V) *Kernel_traits::ELEMENT_BYTES);
#pragma unroll #pragma unroll
for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di)
{ {
int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP, int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP,
multi_query_attention_ ? h + h_kv + bidh / (h / h_kv) : v_coord_dim1, multi_query_attention_ ? bidh / (h / h_kv) : bidh, 0, sum_s_q_ + kv_step_idx * STEP_KV};
multi_query_attention_ ? 0 : v_coord_dim2, sum_s_q_ + kv_step_idx * STEP_KV};
fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv, fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v,
__cvta_generic_to_shared( __cvta_generic_to_shared(
&shared->smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]), &shared->smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]),
__cvta_generic_to_shared(cbw_v_scratch.barrier_ptr(v_scratch_barrier_id)), v_coords, elect_one_); __cvta_generic_to_shared(cbw_v_scratch.barrier_ptr(v_scratch_barrier_id)), v_coords, elect_one_);
@ -1030,19 +1042,19 @@ struct DMA
// Load k,v tiles from gmem to smem by TMA. // Load k,v tiles from gmem to smem by TMA.
template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch> template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch>
inline __device__ int load_kv(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv, inline __device__ int load_kv(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv,
Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v,
BufferReaderScratch& cbr_v_scratch) BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
{ {
if constexpr (DMA_GROUP_TRANSPOSE_V) if constexpr (DMA_GROUP_TRANSPOSE_V)
{ {
int v_scratch_barrier_id = load_kv_transpose_v_impl( int v_scratch_barrier_id = load_kv_transpose_v_impl(
bidh, h, h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch); bidh, h, h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch);
return v_scratch_barrier_id; return v_scratch_barrier_id;
} }
else else
{ {
load_kv_impl(bidh, h, h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v); load_kv_impl(bidh, h, h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, cbw_v);
return 0; return 0;
} }
} }
@ -1071,9 +1083,9 @@ struct DMA
template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch> template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch>
inline __device__ int load_paged_kv(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks, inline __device__ int load_paged_kv(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks,
int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2, int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2,
int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, Shared* shared, int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv,
BufferWriter& cbw_k, BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v,
BufferReaderScratch& cbr_v_scratch) BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
{ {
if constexpr (DMA_GROUP_TRANSPOSE_V) if constexpr (DMA_GROUP_TRANSPOSE_V)
@ -1088,7 +1100,7 @@ struct DMA
{ {
load_paged_kv_impl(bidh, kv_tile_start_offset, num_valid_kv_blocks, tokens_per_block_log2, load_paged_kv_impl(bidh, kv_tile_start_offset, num_valid_kv_blocks, tokens_per_block_log2,
blocks_per_tma_load, blocks_per_tma_load_log2, max_blocks_per_sequence, paged_block_offsets, blocks_per_tma_load, blocks_per_tma_load_log2, max_blocks_per_sequence, paged_block_offsets,
desc_kv, shared, cbw_k, cbw_v); desc_kv, desc_v, shared, cbw_k, cbw_v);
return 0; return 0;
} }
} }
@ -1141,32 +1153,46 @@ struct DMA
// Per batch tensor size. // Per batch tensor size.
uint32_t tensor_size_qkv[4]; uint32_t tensor_size_qkv[4];
// Stride size in bytes. Assumes least significant dim is 1 (?)
uint64_t tensor_size_qk[3], tensor_size_v[3];
uint32_t v_offset;
// Total sequence length. // Total sequence length.
int const total_seqlen = params.is_s_padded ? (params.b * params.s) : launch_params.total_q_seqlen; int const total_seqlen = params.is_s_padded ? (params.b * params.s) : launch_params.total_q_seqlen;
tensor_size_qkv[0] = params.d; // params.d;
tensor_size_qkv[3] = total_seqlen; tensor_size_qkv[3] = total_seqlen;
tensor_size_qk[0] = params.d * Kernel_traits::ELEMENT_BYTES;
tensor_size_qk[2] = params.qkv_stride_in_bytes;
tensor_size_v[1] = 0;
tensor_size_v[2] = params.qkv_stride_in_bytes;
if (params.h_kv < params.h) if (params.h_kv < params.h)
{ {
// Take MQA as non-heads-interleaved. // Take MQA as non-heads-interleaved.
tensor_size_qkv[1] = params.h + params.h_kv;
tensor_size_qkv[2] = 1; tensor_size_qkv[2] = 1;
tensor_size_qkv[1] = (params.h + 2 * params.h_kv); tensor_size_qk[1] = 0;
tensor_size_qkv[0] = params.d; // params.d; tensor_size_v[0] = params.dv * Kernel_traits::ELEMENT_BYTES;
v_offset = (params.h + params.h_kv) * params.d * Kernel_traits::ELEMENT_BYTES;
} }
else if (HEADS_INTERLEAVED) else if (HEADS_INTERLEAVED)
{ {
tensor_size_qkv[1] = 2;
tensor_size_qkv[2] = params.h; tensor_size_qkv[2] = params.h;
tensor_size_qkv[1] = 3; tensor_size_qk[1] = (2 * params.d + params.dv) * Kernel_traits::ELEMENT_BYTES;
tensor_size_qkv[0] = params.d; // params.d; tensor_size_v[0] = tensor_size_qk[1];
v_offset = 2 * params.d * Kernel_traits::ELEMENT_BYTES;
} }
else else
{ {
tensor_size_qkv[2] = 3;
tensor_size_qkv[1] = params.h; tensor_size_qkv[1] = params.h;
tensor_size_qkv[0] = params.d; // params.d; tensor_size_qkv[2] = 2;
tensor_size_qk[1] = params.h * tensor_size_qk[0];
tensor_size_v[0] = params.dv * Kernel_traits::ELEMENT_BYTES;
v_offset = 2 * params.h * params.d * Kernel_traits::ELEMENT_BYTES;
} }
// O : [TOTAL, 1, h, d] // O : [TOTAL, 1, h, d]
uint32_t tensor_size_o[4]; uint32_t tensor_size_o[4];
tensor_size_o[0] = params.d; tensor_size_o[0] = params.dv;
tensor_size_o[1] = params.h; tensor_size_o[1] = params.h;
tensor_size_o[2] = 1; tensor_size_o[2] = 1;
tensor_size_o[3] = total_seqlen; tensor_size_o[3] = total_seqlen;
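The three branches above position V inside one packed-QKV token differently, which is what v_offset and the per-layout strides encode. The sketch below works the offsets out numerically (not part of the commit; the head counts and sizes are hypothetical, and the layout comments reflect how I read the three branches):

// Illustrative sketch, not part of the commit: byte offset of the V section within one
// packed-QKV token, for each of the three layout branches above.
#include <cstdint>
#include <cstdio>

int main()
{
    uint64_t const elt = 2;           // bf16
    uint64_t const d = 192, dv = 128; // K and V head sizes
    uint64_t const h = 16, h_kv = 4;  // query heads and KV heads

    // h_kv < h (GQA/MQA, non-interleaved): Q heads, then K heads, then V heads.
    uint64_t const v_offset_gqa = (h + h_kv) * d * elt;
    // Heads-interleaved: each head stores [q | k | v], so V starts 2*d elements in,
    // and the per-head stride is (2*d + dv) * elt bytes.
    uint64_t const v_offset_interleaved = 2 * d * elt;
    // h_kv == h, non-interleaved: all Q heads, then all K heads, then all V heads.
    uint64_t const v_offset_flat = 2 * h * d * elt;

    std::printf("gqa=%llu interleaved=%llu flat=%llu bytes\n",
        (unsigned long long) v_offset_gqa,
        (unsigned long long) v_offset_interleaved,
        (unsigned long long) v_offset_flat);
    return 0;
}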
@ -1178,16 +1204,10 @@ struct DMA
box_size[1] = 1; box_size[1] = 1;
box_size[0] = Kernel_traits::D_PER_GROUP; box_size[0] = Kernel_traits::D_PER_GROUP;
// Stride size in bytes. Assumes least significant dim is 1 (?)
uint64_t tensor_stride_qkv[3];
tensor_stride_qkv[0] = tensor_size_qkv[0] * Kernel_traits::ELEMENT_BYTES; // d
tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; // d*h
tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1]; // d*h*3
uint64_t tensor_stride_o[3]; uint64_t tensor_stride_o[3];
tensor_stride_o[0] = tensor_size_o[0] * Kernel_traits::ELEMENT_BYTES; // d tensor_stride_o[0] = tensor_size_o[0] * Kernel_traits::ELEMENT_BYTES; // dv
tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0]; // d*h tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0]; // dv*h
tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1]; // d*h*1 tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1]; // dv*h*1
// Traversal stride. // Traversal stride.
uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1}; uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1};
@ -1225,7 +1245,7 @@ struct DMA
box_size[3] = STEP_Q; box_size[3] = STEP_Q;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format,
fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv, fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_qk,
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_q); traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_q);
// O: 16 // O: 16
@ -1242,8 +1262,18 @@ struct DMA
box_size[3] = STEP_KV; box_size[3] = STEP_KV;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format,
fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv, fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_qk,
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_kv); traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_kv);
// V: STEP_KV.
tensor_size_qkv[0] = params.dv;
tensor_size_qkv[1] = params.h_kv;
tensor_size_qkv[2] = 1;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr + v_offset, desc_format,
fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_v,
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_v);
} }
else else
{ {
@ -1353,7 +1383,16 @@ struct DMA
// Paged KV: [UINT32_MAX, H, TokensPerBlock, D] // Paged KV: [UINT32_MAX, H, TokensPerBlock, D]
// Per batch tensor size. // Per batch tensor size.
uint32_t tensor_size_kv[4]; uint32_t tensor_size_kv[4];
tensor_size_kv[3] = params.b * 2 * params.paged_kv_cache.mMaxBlocksPerSeq; // The original code is:
// tensor_size_kv[3] = params.b * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
// If d != dv and v is not padded, then the code should be:
// tensor_size_kv[3] = params.b * params.paged_kv_cache.mMaxBlocksPerSeq
// * ((params.d + params.dv) / std::gcd(params.d, params.dv));
// TensorRT-LLM uses:
// tensor_size_kv[3] = mLaunchParams.total_device_memory /
// mKernelParams.paged_kv_cache.mBytesPerBlock;
// I think the simplest way is:
tensor_size_kv[3] = INT_MAX;
tensor_size_kv[2] = params.h_kv; tensor_size_kv[2] = params.h_kv;
tensor_size_kv[1] = params.paged_kv_cache.mTokensPerBlock; tensor_size_kv[1] = params.paged_kv_cache.mTokensPerBlock;
tensor_size_kv[0] = params.d; // params.d; tensor_size_kv[0] = params.d; // params.d;
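The gcd expression in the comment accounts for K blocks and V blocks having different byte sizes when d != dv and V is left unpadded. A quick worked instance of that factor for the 192/128 case (a sketch of the arithmetic only, not part of the commit):

// Illustrative sketch, not part of the commit: the (d + dv) / gcd(d, dv) factor from the
// comment above, evaluated for a 192/128 MLA-style cache.
#include <cstdio>
#include <numeric>

int main()
{
    int const d = 192, dv = 128;
    int const factor = (d + dv) / std::gcd(d, dv); // (192 + 128) / 64 = 5
    std::printf("factor = %d\n", factor);
    // Rather than computing an exact bound from b * mMaxBlocksPerSeq * factor (or from the
    // pool size, as TensorRT-LLM does), the descriptor simply uses INT_MAX so that any
    // block index produced by the offsets table stays in range.
    return 0;
}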
@ -1373,14 +1412,28 @@ struct DMA
// Stride size in bytes. Assumes least significant dim is 1 (?) // Stride size in bytes. Assumes least significant dim is 1 (?)
uint64_t tensor_stride_kv[3]; uint64_t tensor_stride_kv[3];
tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // d tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // d
tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0]; // d*h // The original code is:
tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1]; // d*h*3 // tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0]; // d*mTokensPerBlock
// tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1]; // d*mTokensPerBlock*h
// This can be simplified to:
tensor_stride_kv[1] = params.kv_stride_in_bytes;
tensor_stride_kv[2] = params.paged_kv_cache.mBytesPerBlock;
// Paged KV pool tma descriptors. // Paged KV pool tma descriptors.
paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr), paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr),
desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv, fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv,
traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_kv); traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_kv);
#ifndef GENERATE_CUBIN
tensor_size_kv[0] = params.dv;
tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // dv
tensor_stride_kv[1] = params.v_stride_in_bytes; // dv*mTokensPerBlock
paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr),
desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv,
traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_v);
#endif
} }
} }
} }

View File

@@ -36,6 +36,8 @@ template <
     int STEP_KV_,
     // The head dimension.
     int D_,
+    // The head dimension of V.
+    int DV_,
     // The number of smem buffers for Q tiles.
     int Q_BUFFERS_,
     // The number of smem buffers for K, and V tiles.
@@ -83,18 +85,17 @@ struct Kernel_traits
         STEP_KV = STEP_KV_
     };
-    // The padded head dimension.
-    enum
-    {
-        D = Next_power_of_two<D_>::VALUE
-    };
     // The valid head dimension.
     enum
     {
         VALID_D = D_
     };
+    enum
+    {
+        VALID_DV = (DV_ == 0 ? D_ : DV_)
+    };
     // Bootstrap GMMA_K from dummy Instruction_traits where FP16/BF16 K = 16, FP8 K = 32.
     enum
     {
@@ -113,6 +114,16 @@ struct Kernel_traits
         ELEMENT_BYTES = sizeof(Element_data_type)
     };
+    enum
+    {
+        D = Next_power_of_two<VALID_D>::VALUE
+    };
+    enum
+    {
+        DV = Next_power_of_two<VALID_DV>::VALUE
+    };
     // The number of smem buffers for Q tiles.
     enum
     {
@@ -326,6 +337,18 @@ struct Kernel_traits
         D_BYTES_PER_GROUP = D_BYTES / D_GROUPS
     };
+    // The bytes of head dimension of V.
+    enum
+    {
+        DV_BYTES = DV * ELEMENT_BYTES
+    };
+    // The number of head_dimension groups of V.
+    enum
+    {
+        DV_GROUPS = fmha::Div_up<DV_BYTES, 128>::VALUE
+    };
     // QGMMA: BMM2 will be split into multiple K groups as we explicitly transpose v (128 * D) in the smem.
     // HGMMA: BMM2 will load from row-major (K * N) smem_v, so we don't need to explicitly split K.
     static constexpr auto BMM2_LEADING_DIM_BYTES = ELEMENT_BYTES == 1 ? 128 : STEP_KV * ELEMENT_BYTES;
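For a bf16 kernel with head_size 192 and head_size_v 128, these enums resolve as below. The sketch only mirrors the Next_power_of_two / Div_up arithmetic above (not part of the commit), assuming 2-byte elements and the 128-byte TMA swizzle granularity used for the group split:

// Illustrative sketch, not part of the commit: D/DV padding and group counts for a
// bf16 kernel with head_size 192 and head_size_v 128.
#include <cstdio>

constexpr int next_power_of_two(int n)
{
    int p = 1;
    while (p < n)
    {
        p *= 2;
    }
    return p;
}

constexpr int div_up(int a, int b)
{
    return (a + b - 1) / b;
}

int main()
{
    constexpr int kEltBytes = 2;                            // bf16
    constexpr int kD = next_power_of_two(192);              // 256
    constexpr int kDV = next_power_of_two(128);             // 128
    constexpr int kDGroups = div_up(kD * kEltBytes, 128);   // 512B / 128B = 4
    constexpr int kDVGroups = div_up(kDV * kEltBytes, 128); // 256B / 128B = 2
    std::printf("D=%d DV=%d D_GROUPS=%d DV_GROUPS=%d\n", kD, kDV, kDGroups, kDVGroups);
    return 0;
}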
@@ -364,7 +387,7 @@ struct Kernel_traits
     // The instruction traits for the BMM2.
     // FP16/BF16 K = 16, FP8 K = 32.
-    using Traits_o = Instruction_traits<STEP_Q, D, GMMA_K, true, false>;
+    using Traits_o = Instruction_traits<STEP_Q, DV, GMMA_K, true, false>;
     // The CTA description for BMM1.
     using Cta_tile_p =

@@ -375,7 +398,7 @@ struct Kernel_traits
         typename Traits_p::template Cta_tile<STEP_Q, STEP_KV, D_PER_GROUP, WARP_GROUP_M, WARP_GROUP_N, WARP_GROUP_K>;
     // The CTA description for BMM2.
-    using Cta_tile_o = typename Traits_o::template Cta_padded_tile<STEP_Q, D, STEP_KV, VALID_D, STEP_KV, WARP_GROUP_M,
+    using Cta_tile_o = typename Traits_o::template Cta_padded_tile<STEP_Q, DV, STEP_KV, VALID_DV, STEP_KV, WARP_GROUP_M,
         WARP_GROUP_K, WARP_GROUP_N>;
     // The MMA tile for the 1st GEMM.
@@ -415,9 +438,9 @@ struct Kernel_traits
     // The q, k, v tile buffer.
     using Buffer_q_t = cuda::std::array<Element_data_type, D * STEP_Q * Q_BUFFERS>;
     using Buffer_k_t = cuda::std::array<Element_data_type, D * STEP_KV * KV_BUFFERS>;
-    using Buffer_v_t = cuda::std::array<Element_data_type, D * STEP_KV * KV_BUFFERS>;
+    using Buffer_v_t = cuda::std::array<Element_data_type, DV * STEP_KV * KV_BUFFERS>;
     // We need one kv buffer to explicitly transose fp8 smem_tile.
-    using Buffer_v_scratch_t = cuda::std::array<Element_data_type, D * STEP_KV * V_SCRATCH_BUFFERS>;
+    using Buffer_v_scratch_t = cuda::std::array<Element_data_type, DV * STEP_KV * V_SCRATCH_BUFFERS>;
     // The smem bytes of q, k, v tiles.
     enum
@@ -521,6 +544,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
     int STEP_KV_,
     // The head dimension.
     int D_,
+    // The head dimension of V.
+    int DV_,
     // The number of smem buffers for Q tiles.
     int Q_BUFFERS_,
     // The number of smem buffers for K, and V tiles.
@@ -554,14 +579,14 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
     // The sage attention block size for Q, K and V
     int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0>
 struct Kernel_traits_Hopper_qgmma_e4m3_fp32
-    : public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, Q_BUFFERS_, KV_BUFFERS_,
+    : public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
         NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_,
         ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
         RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>
 {
     // Base class.
-    using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, Q_BUFFERS_, KV_BUFFERS_,
+    using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
         NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
         SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_>;
@@ -601,7 +626,7 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
     using Buffer_v_scratch_t = typename Base::Buffer_v_scratch_t;
     // Extra O buffer if TMA is used for epilogue
     using Element_data_type = typename Base::Element_data_type;
-    using Buffer_o_t = cuda::std::array<Element_data_type, Base::D * Base::STEP_Q * O_BUFFERS>;
+    using Buffer_o_t = cuda::std::array<Element_data_type, Base::DV * Base::STEP_Q * O_BUFFERS>;
     // The struct of shared memory buffers.
     struct __align__(128) Shared

View File

@@ -208,6 +208,8 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
     // Contiguous kv layout: [B, 2, H, S, D].
     // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
     fmha::cudaTmaDesc tma_desc_kv;
+    // Tma descriptor for v if v_stride_in_bytes not in [0, kv_stride_in_bytes]
+    fmha::cudaTmaDesc tma_desc_v;
     // Tma descriptor for o
     fmha::cudaTmaDesc tma_desc_o;

View File

@@ -111,6 +111,7 @@ struct Fused_multihead_attention_params_v2
     // Contiguous kv layout: [B, 2, H, S, D].
     // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
     fmha::cudaTmaDesc tma_desc_kv;
+    fmha::cudaTmaDesc tma_desc_v;
     // Tma descriptor for o
     fmha::cudaTmaDesc tma_desc_o;