From 180b91f95771993bc8995af426069d4df3d17b20 Mon Sep 17 00:00:00 2001 From: qsang-nv <200703406+qsang-nv@users.noreply.github.com> Date: Thu, 5 Jun 2025 22:14:28 +0800 Subject: [PATCH] update fmha_v2 (#4895) Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- cpp/kernels/fmha_v2/fmha_test.py | 2 +- cpp/kernels/fmha_v2/setup.py | 77 ++++++++- .../fmha_v2/src/fmha/warpspec/compute.h | 2 +- cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h | 159 ++++++++++++------ .../fmha_v2/src/fmha/warpspec/kernel_traits.h | 51 ++++-- .../fmha_v2/src/fused_multihead_attention.h | 2 + ...sed_multihead_attention_demo_bert_params.h | 1 + 7 files changed, 219 insertions(+), 75 deletions(-) diff --git a/cpp/kernels/fmha_v2/fmha_test.py b/cpp/kernels/fmha_v2/fmha_test.py index 3617d7e207..3523ee1d10 100644 --- a/cpp/kernels/fmha_v2/fmha_test.py +++ b/cpp/kernels/fmha_v2/fmha_test.py @@ -157,7 +157,7 @@ def test_trtllm_context_mla_attention_fmha(dtype, s): epsilon += ' -epsilon 0.03' sm_version = getSMVersion() - if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89: + if sm_version != 89: pytest.skip("FP8 MLAs only supported on sm89 currently.") # Context phase kernels. diff --git a/cpp/kernels/fmha_v2/setup.py b/cpp/kernels/fmha_v2/setup.py index 48f8b7e590..ac705e719d 100644 --- a/cpp/kernels/fmha_v2/setup.py +++ b/cpp/kernels/fmha_v2/setup.py @@ -189,8 +189,7 @@ namespace kernels ns_close = r""" // clang-format on } // namespace kernels -} // namespace tensorrt_llm -""" if generate_cu_trtllm else "" +} // namespace tensorrt_llm""" if generate_cu_trtllm else "" copyright = '''\ /*************************************************************************************************** @@ -1344,7 +1343,7 @@ void {sliding_or_chunked_causal_kernel_name}_nl({params_type} params){{ #endif // sliding_or_chunked_causal_mask -void {launcher_name}_nl({params_type} ¶ms, +void {launcher_name}_nl({fused_multihead_attention_params_v2_str} ¶ms, const Launch_params& launch_params, cudaStream_t stream){{ constexpr int loop_iters = {seq_len} / {noloop_step}; static_assert(loop_iters * {noloop_step} == {seq_len}, ""); @@ -1431,6 +1430,7 @@ using Ktraits = {kernel_traits_header} {loop_step}, {kv_loop_step}, {head_size}, + {head_size_v}, {q_tile_buffers}, {kv_tile_buffers}, NUM_COMPUTE_GROUPS, @@ -1453,6 +1453,7 @@ using Ktraits_causal = {kernel_traits_header} {loop_step}, {kv_loop_step}, {head_size}, + {head_size_v}, {q_tile_buffers}, {kv_tile_buffers}, NUM_COMPUTE_GROUPS, @@ -1472,6 +1473,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header} {loop_step}, {kv_loop_step}, {head_size}, + {head_size_v}, {q_tile_buffers}, {kv_tile_buffers}, NUM_COMPUTE_GROUPS, @@ -1491,6 +1493,7 @@ using Ktraits_custom_mask = {kernel_traits_header} {loop_step}, {kv_loop_step}, {head_size}, + {head_size_v}, {q_tile_buffers}, {kv_tile_buffers}, NUM_COMPUTE_GROUPS, @@ -2881,6 +2884,7 @@ def get_kernel_traits_code(specs_names): {loop_step}, {kv_loop_step}, {head_size}, + {head_size_v}, {q_tile_buffers}, {kv_tile_buffers}, NUM_COMPUTE_GROUPS, @@ -3213,7 +3217,7 @@ def get_cubin_header(kernel_traits, specs_names): return 'nullptr' lname = kname.replace('_kernel', '') mask_types = [ - '_sliding_window_causal', '_custom_mask', '_causal' + '_sliding_or_chunked_causal', '_custom_mask', '_causal' ] for mask_type in mask_types: lname = lname.replace(mask_type, '') @@ -3228,6 +3232,12 @@ def get_cubin_header(kernel_traits, specs_names): {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, 
{attention_mask_type_value}, \ {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \ {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\ +'''.format(**locals()) if 'sage' in kname and 'sm90' in kname else '''\ +{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \ +{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \ +0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \ +{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \ +{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\ '''.format(**locals()) else: code = '''\ @@ -3332,7 +3342,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 {metadata_v2} }}; {local_ns_close} - '''.format(**locals(), copyright=copyright) else: @@ -3540,7 +3549,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'): # Note this will be used in TRT-LLM. -def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): +def enumerate_hgmma_flash_warpspec_kernels(specs, + sm=90, + dtype='fp16', + head_size_v=0): scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1')) @@ -3563,6 +3575,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): dtype=dtype, seq_len=0, # support any sequence length head_size=[32, 40, 48, 64], + head_size_v=head_size_v, warps_m=4, #4x1 warpgroups warps_n=1, version=2, @@ -3595,6 +3608,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): dtype=dtype, seq_len=0, # support any sequence length head_size=[72, 80, 96, 104, 128], + head_size_v=head_size_v, warps_m=4, #4x1 warpgroups warps_n=1, version=2, @@ -3627,6 +3641,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): dtype=dtype, seq_len=0, # support any sequence length head_size=[160, 192, 256], + head_size_v=head_size_v, warps_m=4, #4x1 warpgroups warps_n=1, version=2, @@ -3652,6 +3667,40 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'): scheduling_mode=scheduling_mode, input_layout=input_layout)) + # for deepseek context 192/128, kv_step=128 + specs.append( + kernel_spec( + sm=sm, + sm_mma=90, + dtype=dtype, + seq_len=0, # support any sequence length + head_size=192, + head_size_v=128, + warps_m=4, #4x1 warpgroups + warps_n=1, + version=2, + interleaved=False, + ldgsts_q= + False, # for Hopper kernels, ldgsts = False signals TMA usage. + ldgsts_k=False, + ldgsts_v=False, + share_smem_k_v=False, + loop_step=64, + q_tile_buffers=1, # only used by warp specialized kernels + has_noloop=0, + noloop_step=64, + kv_loop_step=128, + kv_tile_buffers=2, # only used by warp specialized kernels + unroll_threshold=1, + has_scale_max=False, + flash_attention=True, + warp_specialization=True, + alibi=alibi, + enable_attn_logit_softcapping=enable_attn_logit_softcapping, + return_softmax_stats=return_softmax, + scheduling_mode=scheduling_mode, + input_layout=input_layout)) + # Note this will be used in TRT-LLM. 
def enumerate_qgmma_flash_warpspec_kernels(specs, @@ -6215,7 +6264,21 @@ def enumerate_kernels(): and kspec.cross_mha == False and kspec.flash_attention == True and kspec.warp_specialization == False - and kspec.tiled == True) + and kspec.tiled == True + and not (kspec.sm == 90 and (kspec.head_size, kspec.head_size_v) == (192, 128))) + # Deepseek MLA (hopper-style context 192/128 packed + paged) + or (kspec.sm == 90 + and kspec.dtype == 'bf16' + and kspec.head_size == 192 + and kspec.head_size_v == 128 + and kspec.sage_block_sizes is None + and kspec.version == 2 + and kspec.cross_mha == False + and kspec.flash_attention == True + and kspec.warp_specialization == True + and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV] + and kspec.alibi == False + and kspec.enable_attn_logit_softcapping == False) # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask) or (kspec.sm == 90 and kspec.head_size in [80, 128] diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h index 1df784d3ed..b95316e184 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h @@ -173,7 +173,7 @@ struct Compute enum { - TILE_SIZE_V = STEP_KV * Kernel_traits::D + TILE_SIZE_V = STEP_KV * Kernel_traits::DV }; enum diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h index cdea942885..0a353e992d 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/dma.h @@ -76,7 +76,7 @@ struct DMA // The tile size of V. enum { - TILE_SIZE_V = TILE_SIZE_K + TILE_SIZE_V = STEP_KV * Kernel_traits::DV }; // The tile size of V after head_dimension split. @@ -280,6 +280,7 @@ struct DMA cudaTmaDesc const* desc_q = ¶ms.tma_desc_q; cudaTmaDesc const* desc_kv = ¶ms.tma_desc_kv; + cudaTmaDesc const* desc_v = ¶ms.tma_desc_v; int actual_seqlen; if (params.is_s_padded) { @@ -342,8 +343,8 @@ struct DMA // Iterate over the kv tiles for this q step. for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++) { - int bar_id = load_kv(bidh, params.h, params.h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v, - cbw_v_scratch, cbr_v_scratch); + int bar_id = load_kv(bidh, params.h, params.h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, + cbw_v, cbw_v_scratch, cbr_v_scratch); // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor if (q_step_idx == 0 && kv_step_idx == kv_idx_start) @@ -511,7 +512,17 @@ struct DMA int32_t const* paged_block_offsets = params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq; cudaTmaDesc const* desc_kv = ¶ms.tma_desc_kv; - + // If a separate v_stride_in_bytes is set, we have to use separate tma_desc_v, + // otherwise share with tma_desc_kv. + // This is for the compatibility that TensorRT-LLM needs no modification if padding V to 192. +#ifndef GENERATE_CUBIN + cudaTmaDesc const* desc_v + = (params.v_stride_in_bytes == 0 || params.v_stride_in_bytes == params.kv_stride_in_bytes) + ? 
desc_kv + : ¶ms.tma_desc_v; +#else + cudaTmaDesc const* desc_v = desc_kv; +#endif if (SCHEDULING_MODE == 0) { // split work across M @@ -575,7 +586,8 @@ struct DMA bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks, params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load, params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq, - paged_block_offsets, desc_kv, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch); + paged_block_offsets, desc_kv, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch, + cbr_v_scratch); } else { @@ -670,7 +682,7 @@ struct DMA // Load k,v tiles from gmem to smem by TMA. template inline __device__ void load_kv_impl(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv, - Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v) + cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v) { int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES); @@ -685,8 +697,6 @@ struct DMA // split D into multiple groups in order to satisfy the TMA 128B sizzle mode int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh; int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1; - int32_t const v_coord_dim1 = HEADS_INTERLEAVED ? 2 : bidh; - int32_t const v_coord_dim2 = HEADS_INTERLEAVED ? bidh : 2; #pragma unroll for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) @@ -699,12 +709,14 @@ struct DMA __cvta_generic_to_shared( &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP]), __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); - + } +#pragma unroll + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) + { int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP, - multi_query_attention_ ? h + h_kv + bidh / (h / h_kv) : v_coord_dim1, - multi_query_attention_ ? 0 : v_coord_dim2, sum_s_q_ + kv_step_idx * STEP_KV}; + multi_query_attention_ ? 
bidh / (h / h_kv) : bidh, 0, sum_s_q_ + kv_step_idx * STEP_KV}; - fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv, + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v, __cvta_generic_to_shared( &shared->smem_v[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]), __cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_); @@ -748,8 +760,8 @@ struct DMA template inline __device__ void load_paged_kv_impl(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks, int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2, - int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, Shared* shared, - BufferWriter& cbw_k, BufferWriter& cbw_v) + int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, + cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v) { int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES); @@ -783,11 +795,14 @@ struct DMA __cvta_generic_to_shared(&shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP + bi * tile_size_k_per_block]), __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_); - + } +#pragma unroll + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) + { int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh, v_paged_block_offset}; - fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv, + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v, __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP + bi * tile_size_k_per_block]), __cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_); @@ -877,8 +892,8 @@ struct DMA // Load k,v tiles from gmem to smem by TMA. template inline __device__ int load_kv_transpose_v_impl(int bidh, int h, int h_kv, int kv_step_idx, - cudaTmaDesc const* desc_kv, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, - BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch) + cudaTmaDesc const* desc_kv, cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, + BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch) { int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES); @@ -890,8 +905,6 @@ struct DMA // split D into multiple groups in order to satisfy the TMA 128B sizzle mode int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh; int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1; - int32_t const v_coord_dim1 = HEADS_INTERLEAVED ? 2 : bidh; - int32_t const v_coord_dim2 = HEADS_INTERLEAVED ? bidh : 2; #pragma unroll for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) @@ -910,13 +923,12 @@ struct DMA = cbw_v_scratch.tmaReserve(elect_one_, (TILE_SIZE_V) *Kernel_traits::ELEMENT_BYTES); #pragma unroll - for (int di = 0; di < Kernel_traits::D_GROUPS; ++di) + for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di) { int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP, - multi_query_attention_ ? h + h_kv + bidh / (h / h_kv) : v_coord_dim1, - multi_query_attention_ ? 0 : v_coord_dim2, sum_s_q_ + kv_step_idx * STEP_KV}; + multi_query_attention_ ? 
bidh / (h / h_kv) : bidh, 0, sum_s_q_ + kv_step_idx * STEP_KV}; - fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv, + fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v, __cvta_generic_to_shared( &shared->smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]), __cvta_generic_to_shared(cbw_v_scratch.barrier_ptr(v_scratch_barrier_id)), v_coords, elect_one_); @@ -1030,19 +1042,19 @@ struct DMA // Load k,v tiles from gmem to smem by TMA. template inline __device__ int load_kv(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv, - Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, - BufferReaderScratch& cbr_v_scratch) + cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, + BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch) { if constexpr (DMA_GROUP_TRANSPOSE_V) { int v_scratch_barrier_id = load_kv_transpose_v_impl( - bidh, h, h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch); + bidh, h, h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch); return v_scratch_barrier_id; } else { - load_kv_impl(bidh, h, h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v); + load_kv_impl(bidh, h, h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, cbw_v); return 0; } } @@ -1071,9 +1083,9 @@ struct DMA template inline __device__ int load_paged_kv(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks, int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2, - int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, Shared* shared, - BufferWriter& cbw_k, BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, - BufferReaderScratch& cbr_v_scratch) + int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, + cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, + BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch) { if constexpr (DMA_GROUP_TRANSPOSE_V) @@ -1088,7 +1100,7 @@ struct DMA { load_paged_kv_impl(bidh, kv_tile_start_offset, num_valid_kv_blocks, tokens_per_block_log2, blocks_per_tma_load, blocks_per_tma_load_log2, max_blocks_per_sequence, paged_block_offsets, - desc_kv, shared, cbw_k, cbw_v); + desc_kv, desc_v, shared, cbw_k, cbw_v); return 0; } } @@ -1141,32 +1153,46 @@ struct DMA // Per batch tensor size. uint32_t tensor_size_qkv[4]; + // Stride size in bytes. Assumes least significant dim is 1 (?) + uint64_t tensor_size_qk[3], tensor_size_v[3]; + uint32_t v_offset; // Total sequence length. int const total_seqlen = params.is_s_padded ? (params.b * params.s) : launch_params.total_q_seqlen; + tensor_size_qkv[0] = params.d; // params.d; tensor_size_qkv[3] = total_seqlen; + tensor_size_qk[0] = params.d * Kernel_traits::ELEMENT_BYTES; + tensor_size_qk[2] = params.qkv_stride_in_bytes; + tensor_size_v[1] = 0; + tensor_size_v[2] = params.qkv_stride_in_bytes; if (params.h_kv < params.h) { // Take MQA as non-heads-interleaved. 
+ tensor_size_qkv[1] = params.h + params.h_kv; tensor_size_qkv[2] = 1; - tensor_size_qkv[1] = (params.h + 2 * params.h_kv); - tensor_size_qkv[0] = params.d; // params.d; + tensor_size_qk[1] = 0; + tensor_size_v[0] = params.dv * Kernel_traits::ELEMENT_BYTES; + v_offset = (params.h + params.h_kv) * params.d * Kernel_traits::ELEMENT_BYTES; } else if (HEADS_INTERLEAVED) { + tensor_size_qkv[1] = 2; tensor_size_qkv[2] = params.h; - tensor_size_qkv[1] = 3; - tensor_size_qkv[0] = params.d; // params.d; + tensor_size_qk[1] = (2 * params.d + params.dv) * Kernel_traits::ELEMENT_BYTES; + tensor_size_v[0] = tensor_size_qk[1]; + v_offset = 2 * params.d * Kernel_traits::ELEMENT_BYTES; } else { - tensor_size_qkv[2] = 3; tensor_size_qkv[1] = params.h; - tensor_size_qkv[0] = params.d; // params.d; + tensor_size_qkv[2] = 2; + tensor_size_qk[1] = params.h * tensor_size_qk[0]; + tensor_size_v[0] = params.dv * Kernel_traits::ELEMENT_BYTES; + v_offset = 2 * params.h * params.d * Kernel_traits::ELEMENT_BYTES; } // O : [TOTAL, 1, h, d] uint32_t tensor_size_o[4]; - tensor_size_o[0] = params.d; + tensor_size_o[0] = params.dv; tensor_size_o[1] = params.h; tensor_size_o[2] = 1; tensor_size_o[3] = total_seqlen; @@ -1178,16 +1204,10 @@ struct DMA box_size[1] = 1; box_size[0] = Kernel_traits::D_PER_GROUP; - // Stride size in bytes. Assumes least significant dim is 1 (?) - uint64_t tensor_stride_qkv[3]; - tensor_stride_qkv[0] = tensor_size_qkv[0] * Kernel_traits::ELEMENT_BYTES; // d - tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; // d*h - tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1]; // d*h*3 - uint64_t tensor_stride_o[3]; - tensor_stride_o[0] = tensor_size_o[0] * Kernel_traits::ELEMENT_BYTES; // d - tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0]; // d*h - tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1]; // d*h*1 + tensor_stride_o[0] = tensor_size_o[0] * Kernel_traits::ELEMENT_BYTES; // dv + tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0]; // dv*h + tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1]; // dv*h*1 // Traversal stride. uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1}; @@ -1225,7 +1245,7 @@ struct DMA box_size[3] = STEP_Q; qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, - fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_qk, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, ¶ms.tma_desc_q); // O: 16 @@ -1242,8 +1262,18 @@ struct DMA box_size[3] = STEP_KV; qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, - fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_qk, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, ¶ms.tma_desc_kv); + + // V: STEP_KV. + tensor_size_qkv[0] = params.dv; + tensor_size_qkv[1] = params.h_kv; + tensor_size_qkv[2] = 1; + + qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr + v_offset, desc_format, + fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_v, + traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, ¶ms.tma_desc_v); } else { @@ -1353,7 +1383,16 @@ struct DMA // Paged KV: [UINT32_MAX, H, TokensPerBlock, D] // Per batch tensor size. 
uint32_t tensor_size_kv[4]; - tensor_size_kv[3] = params.b * 2 * params.paged_kv_cache.mMaxBlocksPerSeq; + // The original code is: + // tensor_size_kv[3] = params.b * 2 * params.paged_kv_cache.mMaxBlocksPerSeq; + // If d != dv and v is not padded, then the code should be: + // tensor_size_kv[3] = params.b * params.paged_kv_cache.mMaxBlocksPerSeq + // * ((params.d + params.dv) / std::gcd(params.d, params.dv)); + // TensorRT-LLM uses: + // tensor_size_kv[3] = mLaunchParams.total_device_memory / + // mKernelParams.paged_kv_cache.mBytesPerBlock; + // I think the simplest way is: + tensor_size_kv[3] = INT_MAX; tensor_size_kv[2] = params.h_kv; tensor_size_kv[1] = params.paged_kv_cache.mTokensPerBlock; tensor_size_kv[0] = params.d; // params.d; @@ -1373,14 +1412,28 @@ struct DMA // Stride size in bytes. Assumes least significant dim is 1 (?) uint64_t tensor_stride_kv[3]; tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // d - tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0]; // d*h - tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1]; // d*h*3 + // The original code is: + // tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0]; // d*mTokensPerBlock + // tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1]; // d*mTokensPerBlock*h + // This can be simplified to: + tensor_stride_kv[1] = params.kv_stride_in_bytes; + tensor_stride_kv[2] = params.paged_kv_cache.mBytesPerBlock; // Paged KV pool tma descriptors. paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast(params.paged_kv_cache.mPoolPtr), desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv, traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, ¶ms.tma_desc_kv); +#ifndef GENERATE_CUBIN + tensor_size_kv[0] = params.dv; + tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // dv + tensor_stride_kv[1] = params.v_stride_in_bytes; // dv*mTokensPerBlock + + paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast(params.paged_kv_cache.mPoolPtr), + desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, + fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv, + traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, ¶ms.tma_desc_v); +#endif } } } diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h index 0e5c208b71..09c5b009d6 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/kernel_traits.h @@ -36,6 +36,8 @@ template < int STEP_KV_, // The head dimension. int D_, + // The head dimension of V. + int DV_, // The number of smem buffers for Q tiles. int Q_BUFFERS_, // The number of smem buffers for K, and V tiles. @@ -83,18 +85,17 @@ struct Kernel_traits STEP_KV = STEP_KV_ }; - // The padded head dimension. - enum - { - D = Next_power_of_two::VALUE - }; - // The valid head dimension. enum { VALID_D = D_ }; + enum + { + VALID_DV = (DV_ == 0 ? D_ : DV_) + }; + // Bootstrap GMMA_K from dummy Instruction_traits where FP16/BF16 K = 16, FP8 K = 32. enum { @@ -113,6 +114,16 @@ struct Kernel_traits ELEMENT_BYTES = sizeof(Element_data_type) }; + enum + { + D = Next_power_of_two::VALUE + }; + + enum + { + DV = Next_power_of_two::VALUE + }; + // The number of smem buffers for Q tiles. 
enum { @@ -326,6 +337,18 @@ struct Kernel_traits D_BYTES_PER_GROUP = D_BYTES / D_GROUPS }; + // The bytes of head dimension of V. + enum + { + DV_BYTES = DV * ELEMENT_BYTES + }; + + // The number of head_dimension groups of V. + enum + { + DV_GROUPS = fmha::Div_up::VALUE + }; + // QGMMA: BMM2 will be split into multiple K groups as we explicitly transpose v (128 * D) in the smem. // HGMMA: BMM2 will load from row-major (K * N) smem_v, so we don't need to explicitly split K. static constexpr auto BMM2_LEADING_DIM_BYTES = ELEMENT_BYTES == 1 ? 128 : STEP_KV * ELEMENT_BYTES; @@ -364,7 +387,7 @@ struct Kernel_traits // The instruction traits for the BMM2. // FP16/BF16 K = 16, FP8 K = 32. - using Traits_o = Instruction_traits; + using Traits_o = Instruction_traits; // The CTA description for BMM1. using Cta_tile_p = @@ -375,7 +398,7 @@ struct Kernel_traits typename Traits_p::template Cta_tile; // The CTA description for BMM2. - using Cta_tile_o = typename Traits_o::template Cta_padded_tile; // The MMA tile for the 1st GEMM. @@ -415,9 +438,9 @@ struct Kernel_traits // The q, k, v tile buffer. using Buffer_q_t = cuda::std::array; using Buffer_k_t = cuda::std::array; - using Buffer_v_t = cuda::std::array; + using Buffer_v_t = cuda::std::array; // We need one kv buffer to explicitly transose fp8 smem_tile. - using Buffer_v_scratch_t = cuda::std::array; + using Buffer_v_scratch_t = cuda::std::array; // The smem bytes of q, k, v tiles. enum @@ -521,6 +544,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2). int STEP_KV_, // The head dimension. int D_, + // The head dimension of V. + int DV_, // The number of smem buffers for Q tiles. int Q_BUFFERS_, // The number of smem buffers for K, and V tiles. @@ -554,14 +579,14 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2). // The sage attention block size for Q, K and V int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0> struct Kernel_traits_Hopper_qgmma_e4m3_fp32 - : public Kernel_traits { // Base class. - using Base = Kernel_traits; @@ -601,7 +626,7 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32 using Buffer_v_scratch_t = typename Base::Buffer_v_scratch_t; // Extra O buffer if TMA is used for epilogue using Element_data_type = typename Base::Element_data_type; - using Buffer_o_t = cuda::std::array; + using Buffer_o_t = cuda::std::array; // The struct of shared memory buffers. struct __align__(128) Shared diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h index 33610dca78..fd1b935258 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h @@ -208,6 +208,8 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba // Contiguous kv layout: [B, 2, H, S, D]. // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D]. 
    fmha::cudaTmaDesc tma_desc_kv;
+    // Tma descriptor for v, used when v_stride_in_bytes is neither 0 nor equal to kv_stride_in_bytes.
+    fmha::cudaTmaDesc tma_desc_v;
     // Tma descriptor for o
     fmha::cudaTmaDesc tma_desc_o;
 
diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h
index ce8522b52f..22730ef0b3 100644
--- a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h
+++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h
@@ -111,6 +111,7 @@ struct Fused_multihead_attention_params_v2
     // Contiguous kv layout: [B, 2, H, S, D].
     // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
     fmha::cudaTmaDesc tma_desc_kv;
+    fmha::cudaTmaDesc tma_desc_v;
     // Tma descriptor for o
     fmha::cudaTmaDesc tma_desc_o;
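
The V-descriptor selection added in src/fmha/warpspec/dma.h reduces to one rule: reuse tma_desc_kv whenever v_stride_in_bytes is 0 or equal to kv_stride_in_bytes, and only fall back to the dedicated tma_desc_v otherwise (for example the DeepSeek MLA context case with head_size 192 and head_size_v 128, where V is stored unpadded with its own stride). The sketch below illustrates that rule in isolation; the Params_sketch struct is a hypothetical stand-in for Fused_multihead_attention_params_v2, not the real API, and the stride values are illustrative only.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for Fused_multihead_attention_params_v2 (illustration only).
struct Params_sketch
{
    int d;                        // K head dimension, e.g. 192 for the MLA context kernel.
    int dv;                       // V head dimension, e.g. 128 for the MLA context kernel.
    uint64_t kv_stride_in_bytes;  // Stride between consecutive KV tokens.
    uint64_t v_stride_in_bytes;   // 0 (unset) or a separate stride for V.
};

// Mirrors the selection in the #ifndef GENERATE_CUBIN branch of warpspec/dma.h:
// a separate V descriptor is only needed when V has its own, different stride.
static bool needs_separate_v_desc(Params_sketch const& p)
{
    return p.v_stride_in_bytes != 0 && p.v_stride_in_bytes != p.kv_stride_in_bytes;
}

int main()
{
    int const element_bytes = 2;  // bf16

    // V padded to 192 (previous behavior): V shares the KV stride, so tma_desc_kv is reused.
    Params_sketch padded{192, 128, 192ull * element_bytes, 192ull * element_bytes};

    // V stored unpadded with a 128-element stride: the dedicated tma_desc_v is used.
    Params_sketch unpadded{192, 128, 192ull * element_bytes, 128ull * element_bytes};

    std::printf("padded V   -> separate tma_desc_v: %s\n", needs_separate_v_desc(padded) ? "yes" : "no");
    std::printf("unpadded V -> separate tma_desc_v: %s\n", needs_separate_v_desc(unpadded) ? "yes" : "no");
    return 0;
}

This is also why the patch keeps existing TensorRT-LLM integrations that pad V to 192 working without changes: with v_stride_in_bytes left at 0 (or equal to kv_stride_in_bytes), the kernel behaves exactly as before and never touches tma_desc_v.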