mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

commit 180b91f957 (parent 51652b9b2b)
update fmha_v2 (#4895)

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@@ -157,7 +157,7 @@ def test_trtllm_context_mla_attention_fmha(dtype, s):
         epsilon += ' -epsilon 0.03'
 
     sm_version = getSMVersion()
-    if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
+    if sm_version != 89:
         pytest.skip("FP8 MLAs only supported on sm89 currently.")
 
     # Context phase kernels.
@@ -189,8 +189,7 @@ namespace kernels
 ns_close = r"""
 // clang-format on
 } // namespace kernels
-} // namespace tensorrt_llm
-""" if generate_cu_trtllm else ""
+} // namespace tensorrt_llm""" if generate_cu_trtllm else ""
 
 copyright = '''\
 /***************************************************************************************************
@@ -1344,7 +1343,7 @@ void {sliding_or_chunked_causal_kernel_name}_nl({params_type} params){{
 
 #endif // sliding_or_chunked_causal_mask
 
-void {launcher_name}_nl({params_type} &params,
+void {launcher_name}_nl({fused_multihead_attention_params_v2_str} &params,
     const Launch_params& launch_params, cudaStream_t stream){{
     constexpr int loop_iters = {seq_len} / {noloop_step};
     static_assert(loop_iters * {noloop_step} == {seq_len}, "");
@@ -1431,6 +1430,7 @@ using Ktraits = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1453,6 +1453,7 @@ using Ktraits_causal = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1472,6 +1473,7 @@ using Ktraits_sliding_or_chunked_causal = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -1491,6 +1493,7 @@ using Ktraits_custom_mask = {kernel_traits_header}
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -2881,6 +2884,7 @@ def get_kernel_traits_code(specs_names):
     {loop_step},
     {kv_loop_step},
     {head_size},
+    {head_size_v},
     {q_tile_buffers},
     {kv_tile_buffers},
     NUM_COMPUTE_GROUPS,
@@ -3213,7 +3217,7 @@ def get_cubin_header(kernel_traits, specs_names):
             return 'nullptr'
         lname = kname.replace('_kernel', '')
         mask_types = [
-            '_sliding_window_causal', '_custom_mask', '_causal'
+            '_sliding_or_chunked_causal', '_custom_mask', '_causal'
         ]
         for mask_type in mask_types:
             lname = lname.replace(mask_type, '')
@@ -3228,6 +3232,12 @@ def get_cubin_header(kernel_traits, specs_names):
 {cubin_name}_len, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
 {attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
 {is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
+'''.format(**locals()) if 'sage' in kname and 'sm90' in kname else '''\
+{{ DATA_TYPE_{prec}, DATA_TYPE_{output_prec}, {seq_len}, {q_step}, {kv_step}, {head_size}, {head_size_v}, \
+{sage_block_sizes[0]}, {sage_block_sizes[1]}, {sage_block_sizes[2]}, kSM_{sm}, nullptr, \
+0, \"{kname}\", {smem}, {threads}, {meta_unroll_step}, {attention_mask_type_value}, \
+{attention_input_layout_value}, {is_il}, {is_flash_atten}, {is_warp_specialization}, {is_fp32_accu}, \
+{is_alibi_supported}, {is_tiled}, {has_softcapping_scale}, {return_softmax_stats_flag}, {lname}}}\
 '''.format(**locals())
         else:
             code = '''\
@@ -3332,7 +3342,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
 {metadata_v2}
 }};
 {local_ns_close}
-
 '''.format(**locals(), copyright=copyright)
 
         else:
@@ -3540,7 +3549,10 @@ def enumerate_hgmma_ldgsts_kernels(specs, sm=90, dtype='fp16'):
 
 
 # Note this will be used in TRT-LLM.
-def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
+def enumerate_hgmma_flash_warpspec_kernels(specs,
+                                           sm=90,
+                                           dtype='fp16',
+                                           head_size_v=0):
 
     scheduling_mode = int(os.getenv('SCHEDULING_MODE', '1'))
 
@@ -3563,6 +3575,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0,  # support any sequence length
             head_size=[32, 40, 48, 64],
+            head_size_v=head_size_v,
             warps_m=4,  #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3595,6 +3608,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0,  # support any sequence length
             head_size=[72, 80, 96, 104, 128],
+            head_size_v=head_size_v,
             warps_m=4,  #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3627,6 +3641,7 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             dtype=dtype,
             seq_len=0,  # support any sequence length
             head_size=[160, 192, 256],
+            head_size_v=head_size_v,
             warps_m=4,  #4x1 warpgroups
             warps_n=1,
             version=2,
@@ -3652,6 +3667,40 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
             scheduling_mode=scheduling_mode,
             input_layout=input_layout))
 
+    # for deepseek context 192/128, kv_step=128
+    specs.append(
+        kernel_spec(
+            sm=sm,
+            sm_mma=90,
+            dtype=dtype,
+            seq_len=0,  # support any sequence length
+            head_size=192,
+            head_size_v=128,
+            warps_m=4,  #4x1 warpgroups
+            warps_n=1,
+            version=2,
+            interleaved=False,
+            ldgsts_q=
+            False,  # for Hopper kernels, ldgsts = False signals TMA usage.
+            ldgsts_k=False,
+            ldgsts_v=False,
+            share_smem_k_v=False,
+            loop_step=64,
+            q_tile_buffers=1,  # only used by warp specialized kernels
+            has_noloop=0,
+            noloop_step=64,
+            kv_loop_step=128,
+            kv_tile_buffers=2,  # only used by warp specialized kernels
+            unroll_threshold=1,
+            has_scale_max=False,
+            flash_attention=True,
+            warp_specialization=True,
+            alibi=alibi,
+            enable_attn_logit_softcapping=enable_attn_logit_softcapping,
+            return_softmax_stats=return_softmax,
+            scheduling_mode=scheduling_mode,
+            input_layout=input_layout))
+
 
 # Note this will be used in TRT-LLM.
 def enumerate_qgmma_flash_warpspec_kernels(specs,
@@ -6215,7 +6264,21 @@ def enumerate_kernels():
                 and kspec.cross_mha == False
                 and kspec.flash_attention == True
                 and kspec.warp_specialization == False
-                and kspec.tiled == True)
+                and kspec.tiled == True
+                and not (kspec.sm == 90 and (kspec.head_size, kspec.head_size_v) == (192, 128)))
+            # Deepseek MLA (hopper-style context 192/128 packed + paged)
+            or (kspec.sm == 90
+                and kspec.dtype == 'bf16'
+                and kspec.head_size == 192
+                and kspec.head_size_v == 128
+                and kspec.sage_block_sizes is None
+                and kspec.version == 2
+                and kspec.cross_mha == False
+                and kspec.flash_attention == True
+                and kspec.warp_specialization == True
+                and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV]
+                and kspec.alibi == False
+                and kspec.enable_attn_logit_softcapping == False)
             # SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
             or (kspec.sm == 90
                 and kspec.head_size in [80, 128]
@@ -173,7 +173,7 @@ struct Compute
 
     enum
     {
-        TILE_SIZE_V = STEP_KV * Kernel_traits::D
+        TILE_SIZE_V = STEP_KV * Kernel_traits::DV
    };
 
    enum
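
The hunk above is the heart of the change on the compute side: the V tile is now sized by the V head dimension DV instead of the shared, padded D. A minimal constexpr sketch of the arithmetic for the deepseek 192/128 bf16 case; Mock_kernel_traits and its values are assumptions for illustration, not the real fmha headers:

// Illustrative stand-in for the real Kernel_traits (assumed values).
struct Mock_kernel_traits
{
    static constexpr int D = 256;           // Q/K head size 192, padded to a power of two
    static constexpr int DV = 128;          // V head size, already a power of two
    static constexpr int ELEMENT_BYTES = 2; // bf16
};

constexpr int STEP_KV = 128; // KV tile length, matching the kv_loop_step=128 spec above

// The old sizing used D for the V tile; the new sizing uses DV.
constexpr int TILE_SIZE_V_OLD = STEP_KV * Mock_kernel_traits::D;  // 32768 elements
constexpr int TILE_SIZE_V_NEW = STEP_KV * Mock_kernel_traits::DV; // 16384 elements

// In bytes: 64 KiB of smem per V tile before, 32 KiB after.
static_assert(TILE_SIZE_V_OLD * Mock_kernel_traits::ELEMENT_BYTES == 64 * 1024, "");
static_assert(TILE_SIZE_V_NEW * Mock_kernel_traits::ELEMENT_BYTES == 32 * 1024, "");
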
@@ -76,7 +76,7 @@ struct DMA
    // The tile size of V.
    enum
    {
-        TILE_SIZE_V = TILE_SIZE_K
+        TILE_SIZE_V = STEP_KV * Kernel_traits::DV
    };
 
    // The tile size of V after head_dimension split.
@@ -280,6 +280,7 @@ struct DMA
 
        cudaTmaDesc const* desc_q = &params.tma_desc_q;
        cudaTmaDesc const* desc_kv = &params.tma_desc_kv;
+        cudaTmaDesc const* desc_v = &params.tma_desc_v;
        int actual_seqlen;
        if (params.is_s_padded)
        {
@@ -342,8 +343,8 @@ struct DMA
            // Iterate over the kv tiles for this q step.
            for (int kv_step_idx = kv_idx_start; kv_step_idx < kv_idx_end; kv_step_idx++)
            {
-                int bar_id = load_kv(bidh, params.h, params.h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v,
-                    cbw_v_scratch, cbr_v_scratch);
+                int bar_id = load_kv(bidh, params.h, params.h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k,
+                    cbw_v, cbw_v_scratch, cbr_v_scratch);
 
                // Opportunistically hide headinfo in the shadow of UTMALDGs of the QKV tensor
                if (q_step_idx == 0 && kv_step_idx == kv_idx_start)
@@ -511,7 +512,17 @@ struct DMA
        int32_t const* paged_block_offsets
            = params.paged_kv_cache.mBlockOffsets + bidb * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
        cudaTmaDesc const* desc_kv = &params.tma_desc_kv;
-
+        // If a separate v_stride_in_bytes is set, we have to use separate tma_desc_v,
+        // otherwise share with tma_desc_kv.
+        // This is for the compatibility that TensorRT-LLM needs no modification if padding V to 192.
+#ifndef GENERATE_CUBIN
+        cudaTmaDesc const* desc_v
+            = (params.v_stride_in_bytes == 0 || params.v_stride_in_bytes == params.kv_stride_in_bytes)
+            ? desc_kv
+            : &params.tma_desc_v;
+#else
+        cudaTmaDesc const* desc_v = desc_kv;
+#endif
        if (SCHEDULING_MODE == 0)
        {
            // split work across M
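
The #ifndef GENERATE_CUBIN block above decides at runtime whether V can share the KV descriptor. A host-side paraphrase of that selection with trimmed stand-in types (the structs here are reduced mocks of the real fmha params, kept only to the fields the choice reads):

#include <cstdint>

struct TmaDescMock {}; // stand-in for the opaque fmha::cudaTmaDesc

struct ParamsMock // only the fields the selection depends on
{
    TmaDescMock tma_desc_kv;
    TmaDescMock tma_desc_v;
    int64_t kv_stride_in_bytes;
    int64_t v_stride_in_bytes;
};

// v_stride_in_bytes == 0 means "unset"; an equal stride means the V rows sit
// inside the same (possibly padded) KV tensor, so tma_desc_kv already covers V.
TmaDescMock const* select_desc_v(ParamsMock const& p)
{
    bool const share = p.v_stride_in_bytes == 0 || p.v_stride_in_bytes == p.kv_stride_in_bytes;
    return share ? &p.tma_desc_kv : &p.tma_desc_v;
}
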
@@ -575,7 +586,8 @@ struct DMA
                bar_id = load_paged_kv(bidh_kv, remapped_kv_step_idx * STEP_KV, num_valid_kv_blocks,
                    params.paged_kv_cache.mTokensPerBlockLog2, params.blocks_per_tma_load,
                    params.blocks_per_tma_load_log2, params.paged_kv_cache.mMaxBlocksPerSeq,
-                    paged_block_offsets, desc_kv, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch);
+                    paged_block_offsets, desc_kv, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch,
+                    cbr_v_scratch);
            }
            else
            {
@@ -670,7 +682,7 @@ struct DMA
    // Load k,v tiles from gmem to smem by TMA.
    template <typename BufferWriter>
    inline __device__ void load_kv_impl(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv,
-        Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v)
+        cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v)
    {
 
        int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES);
@@ -685,8 +697,6 @@ struct DMA
        // split D into multiple groups in order to satisfy the TMA 128B sizzle mode
        int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh;
        int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1;
-        int32_t const v_coord_dim1 = HEADS_INTERLEAVED ? 2 : bidh;
-        int32_t const v_coord_dim2 = HEADS_INTERLEAVED ? bidh : 2;
 
#pragma unroll
        for (int di = 0; di < Kernel_traits::D_GROUPS; ++di)
@@ -699,12 +709,14 @@ struct DMA
                __cvta_generic_to_shared(
                    &shared->smem_k[k_barrier_id * TILE_SIZE_K + di * TILE_SIZE_K_PER_D_GROUP]),
                __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_);
+        }
+#pragma unroll
+        for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di)
+        {
            int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP,
-                multi_query_attention_ ? h + h_kv + bidh / (h / h_kv) : v_coord_dim1,
-                multi_query_attention_ ? 0 : v_coord_dim2, sum_s_q_ + kv_step_idx * STEP_KV};
+                multi_query_attention_ ? bidh / (h / h_kv) : bidh, 0, sum_s_q_ + kv_step_idx * STEP_KV};
 
-            fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv,
+            fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v,
                __cvta_generic_to_shared(
                    &shared->smem_v[v_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]),
                __cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_);
@@ -748,8 +760,8 @@ struct DMA
    template <typename BufferWriter>
    inline __device__ void load_paged_kv_impl(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks,
        int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2,
-        int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, Shared* shared,
-        BufferWriter& cbw_k, BufferWriter& cbw_v)
+        int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv,
+        cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v)
    {
 
        int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES);
@@ -783,11 +795,14 @@ struct DMA
                    __cvta_generic_to_shared(&shared->smem_k[k_barrier_id * TILE_SIZE_K
                        + di * TILE_SIZE_K_PER_D_GROUP + bi * tile_size_k_per_block]),
                    __cvta_generic_to_shared(cbw_k.barrier_ptr(k_barrier_id)), k_coords, elect_one_);
+            }
+#pragma unroll
+            for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di)
+            {
                int32_t const v_coords[4]
                    = {di * Kernel_traits::D_PER_GROUP, kv_offset_in_block, bidh, v_paged_block_offset};
 
-                fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv,
+                fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v,
                    __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V
                        + di * TILE_SIZE_V_PER_D_GROUP + bi * tile_size_k_per_block]),
                    __cvta_generic_to_shared(cbw_v.barrier_ptr(v_barrier_id)), v_coords, elect_one_);
@@ -877,8 +892,8 @@ struct DMA
    // Load k,v tiles from gmem to smem by TMA.
    template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch>
    inline __device__ int load_kv_transpose_v_impl(int bidh, int h, int h_kv, int kv_step_idx,
-        cudaTmaDesc const* desc_kv, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v,
-        BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
+        cudaTmaDesc const* desc_kv, cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k,
+        BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
    {
        int k_barrier_id = cbw_k.tmaReserve(elect_one_, (TILE_SIZE_K) *Kernel_traits::ELEMENT_BYTES);
 
@@ -890,8 +905,6 @@ struct DMA
        // split D into multiple groups in order to satisfy the TMA 128B sizzle mode
        int32_t const k_coord_dim1 = HEADS_INTERLEAVED ? 1 : bidh;
        int32_t const k_coord_dim2 = HEADS_INTERLEAVED ? bidh : 1;
-        int32_t const v_coord_dim1 = HEADS_INTERLEAVED ? 2 : bidh;
-        int32_t const v_coord_dim2 = HEADS_INTERLEAVED ? bidh : 2;
 
#pragma unroll
        for (int di = 0; di < Kernel_traits::D_GROUPS; ++di)
@@ -910,13 +923,12 @@ struct DMA
            = cbw_v_scratch.tmaReserve(elect_one_, (TILE_SIZE_V) *Kernel_traits::ELEMENT_BYTES);
 
#pragma unroll
-        for (int di = 0; di < Kernel_traits::D_GROUPS; ++di)
+        for (int di = 0; di < Kernel_traits::DV_GROUPS; ++di)
        {
            int32_t const v_coords[4] = {di * Kernel_traits::D_PER_GROUP,
-                multi_query_attention_ ? h + h_kv + bidh / (h / h_kv) : v_coord_dim1,
-                multi_query_attention_ ? 0 : v_coord_dim2, sum_s_q_ + kv_step_idx * STEP_KV};
+                multi_query_attention_ ? bidh / (h / h_kv) : bidh, 0, sum_s_q_ + kv_step_idx * STEP_KV};
 
-            fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_kv,
+            fmha::utmaldg<4, fmha::cudaTmaDescType::TILED, false>(desc_v,
                __cvta_generic_to_shared(
                    &shared->smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V + di * TILE_SIZE_V_PER_D_GROUP]),
                __cvta_generic_to_shared(cbw_v_scratch.barrier_ptr(v_scratch_barrier_id)), v_coords, elect_one_);
@@ -1030,19 +1042,19 @@ struct DMA
    // Load k,v tiles from gmem to smem by TMA.
    template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch>
    inline __device__ int load_kv(int bidh, int h, int h_kv, int kv_step_idx, cudaTmaDesc const* desc_kv,
-        Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch,
-        BufferReaderScratch& cbr_v_scratch)
+        cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v,
+        BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
    {
 
        if constexpr (DMA_GROUP_TRANSPOSE_V)
        {
            int v_scratch_barrier_id = load_kv_transpose_v_impl(
-                bidh, h, h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch);
+                bidh, h, h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, cbw_v, cbw_v_scratch, cbr_v_scratch);
            return v_scratch_barrier_id;
        }
        else
        {
-            load_kv_impl(bidh, h, h_kv, kv_step_idx, desc_kv, shared, cbw_k, cbw_v);
+            load_kv_impl(bidh, h, h_kv, kv_step_idx, desc_kv, desc_v, shared, cbw_k, cbw_v);
            return 0;
        }
    }
@@ -1071,9 +1083,9 @@ struct DMA
    template <typename BufferWriter, typename BufferWriterScratch, typename BufferReaderScratch>
    inline __device__ int load_paged_kv(int bidh, int kv_tile_start_offset, int num_valid_kv_blocks,
        int tokens_per_block_log2, int blocks_per_tma_load, int blocks_per_tma_load_log2,
-        int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv, Shared* shared,
-        BufferWriter& cbw_k, BufferWriter& cbw_v, BufferWriterScratch& cbw_v_scratch,
-        BufferReaderScratch& cbr_v_scratch)
+        int max_blocks_per_sequence, int32_t const* paged_block_offsets, cudaTmaDesc const* desc_kv,
+        cudaTmaDesc const* desc_v, Shared* shared, BufferWriter& cbw_k, BufferWriter& cbw_v,
+        BufferWriterScratch& cbw_v_scratch, BufferReaderScratch& cbr_v_scratch)
    {
 
        if constexpr (DMA_GROUP_TRANSPOSE_V)
@@ -1088,7 +1100,7 @@ struct DMA
        {
            load_paged_kv_impl(bidh, kv_tile_start_offset, num_valid_kv_blocks, tokens_per_block_log2,
                blocks_per_tma_load, blocks_per_tma_load_log2, max_blocks_per_sequence, paged_block_offsets,
-                desc_kv, shared, cbw_k, cbw_v);
+                desc_kv, desc_v, shared, cbw_k, cbw_v);
            return 0;
        }
    }
@@ -1141,32 +1153,46 @@ struct DMA
 
        // Per batch tensor size.
        uint32_t tensor_size_qkv[4];
+        // Stride size in bytes. Assumes least significant dim is 1 (?)
+        uint64_t tensor_size_qk[3], tensor_size_v[3];
+        uint32_t v_offset;
        // Total sequence length.
        int const total_seqlen = params.is_s_padded ? (params.b * params.s) : launch_params.total_q_seqlen;
+        tensor_size_qkv[0] = params.d; // params.d;
        tensor_size_qkv[3] = total_seqlen;
+        tensor_size_qk[0] = params.d * Kernel_traits::ELEMENT_BYTES;
+        tensor_size_qk[2] = params.qkv_stride_in_bytes;
+        tensor_size_v[1] = 0;
+        tensor_size_v[2] = params.qkv_stride_in_bytes;
        if (params.h_kv < params.h)
        {
            // Take MQA as non-heads-interleaved.
+            tensor_size_qkv[1] = params.h + params.h_kv;
            tensor_size_qkv[2] = 1;
-            tensor_size_qkv[1] = (params.h + 2 * params.h_kv);
-            tensor_size_qkv[0] = params.d; // params.d;
+            tensor_size_qk[1] = 0;
+            tensor_size_v[0] = params.dv * Kernel_traits::ELEMENT_BYTES;
+            v_offset = (params.h + params.h_kv) * params.d * Kernel_traits::ELEMENT_BYTES;
        }
        else if (HEADS_INTERLEAVED)
        {
+            tensor_size_qkv[1] = 2;
            tensor_size_qkv[2] = params.h;
-            tensor_size_qkv[1] = 3;
-            tensor_size_qkv[0] = params.d; // params.d;
+            tensor_size_qk[1] = (2 * params.d + params.dv) * Kernel_traits::ELEMENT_BYTES;
+            tensor_size_v[0] = tensor_size_qk[1];
+            v_offset = 2 * params.d * Kernel_traits::ELEMENT_BYTES;
        }
        else
        {
-            tensor_size_qkv[2] = 3;
            tensor_size_qkv[1] = params.h;
-            tensor_size_qkv[0] = params.d; // params.d;
+            tensor_size_qkv[2] = 2;
+            tensor_size_qk[1] = params.h * tensor_size_qk[0];
+            tensor_size_v[0] = params.dv * Kernel_traits::ELEMENT_BYTES;
+            v_offset = 2 * params.h * params.d * Kernel_traits::ELEMENT_BYTES;
        }
 
        // O : [TOTAL, 1, h, d]
        uint32_t tensor_size_o[4];
-        tensor_size_o[0] = params.d;
+        tensor_size_o[0] = params.dv;
        tensor_size_o[1] = params.h;
        tensor_size_o[2] = 1;
        tensor_size_o[3] = total_seqlen;
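
The three branches above also derive v_offset, the byte offset of the first V element inside one token's packed QKV record. Note that the offset never depends on dv itself, only on how many d-wide Q and K heads precede V. A small runnable restatement under local stand-in names (only the arithmetic mirrors the hunk):

#include <cstdint>
#include <cstdio>

enum class Layout { MQA_GQA, HEADS_INTERLEAVED, NON_INTERLEAVED };

// Byte offset of the first V element within one token's packed QKV record.
uint32_t v_offset_bytes(Layout layout, uint32_t h, uint32_t h_kv, uint32_t d, uint32_t elt_bytes)
{
    switch (layout)
    {
    // Packed [Q(h*d) | K(h_kv*d) | V(h_kv*dv)]: V starts after the Q and K heads.
    case Layout::MQA_GQA: return (h + h_kv) * d * elt_bytes;
    // Heads interleaved, per head [q(d) | k(d) | v(dv)]: V starts after 2*d.
    case Layout::HEADS_INTERLEAVED: return 2 * d * elt_bytes;
    // Non-interleaved [Q(h*d) | K(h*d) | V(h*dv)]: V starts after both blocks.
    case Layout::NON_INTERLEAVED: return 2 * h * d * elt_bytes;
    }
    return 0;
}

int main()
{
    // Deepseek-style context MLA: d = 192 (Q/K), 16 heads, bf16.
    printf("non-interleaved v_offset = %u bytes\n",
        v_offset_bytes(Layout::NON_INTERLEAVED, 16, 16, 192, 2));
    return 0;
}
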
@@ -1178,16 +1204,10 @@ struct DMA
        box_size[1] = 1;
        box_size[0] = Kernel_traits::D_PER_GROUP;
 
-        // Stride size in bytes. Assumes least significant dim is 1 (?)
-        uint64_t tensor_stride_qkv[3];
-        tensor_stride_qkv[0] = tensor_size_qkv[0] * Kernel_traits::ELEMENT_BYTES; // d
-        tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0];         // d*h
-        tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1];         // d*h*3
-
        uint64_t tensor_stride_o[3];
-        tensor_stride_o[0] = tensor_size_o[0] * Kernel_traits::ELEMENT_BYTES; // d
-        tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0];           // d*h
-        tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1];           // d*h*1
+        tensor_stride_o[0] = tensor_size_o[0] * Kernel_traits::ELEMENT_BYTES; // dv
+        tensor_stride_o[1] = tensor_size_o[1] * tensor_stride_o[0];           // dv*h
+        tensor_stride_o[2] = tensor_size_o[2] * tensor_stride_o[1];           // dv*h*1
 
        // Traversal stride.
        uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1};
@@ -1225,7 +1245,7 @@ struct DMA
            box_size[3] = STEP_Q;
            qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format,
                fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
-                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
+                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_qk,
                traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_q);
 
            // O: 16
@@ -1242,8 +1262,18 @@ struct DMA
            box_size[3] = STEP_KV;
            qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format,
                fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
-                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
+                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_qk,
                traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_kv);
 
+            // V: STEP_KV.
+            tensor_size_qkv[0] = params.dv;
+            tensor_size_qkv[1] = params.h_kv;
+            tensor_size_qkv[2] = 1;
+
+            qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr + v_offset, desc_format,
+                fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_size_v,
+                traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &params.tma_desc_v);
        }
        else
        {
@@ -1353,7 +1383,16 @@ struct DMA
            // Paged KV: [UINT32_MAX, H, TokensPerBlock, D]
            // Per batch tensor size.
            uint32_t tensor_size_kv[4];
-            tensor_size_kv[3] = params.b * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
+            // The original code is:
+            // tensor_size_kv[3] = params.b * 2 * params.paged_kv_cache.mMaxBlocksPerSeq;
+            // If d != dv and v is not padded, then the code should be:
+            // tensor_size_kv[3] = params.b * params.paged_kv_cache.mMaxBlocksPerSeq
+            //     * ((params.d + params.dv) / std::gcd(params.d, params.dv));
+            // TensorRT-LLM uses:
+            // tensor_size_kv[3] = mLaunchParams.total_device_memory /
+            //     mKernelParams.paged_kv_cache.mBytesPerBlock;
+            // I think the simplest way is:
+            tensor_size_kv[3] = INT_MAX;
            tensor_size_kv[2] = params.h_kv;
            tensor_size_kv[1] = params.paged_kv_cache.mTokensPerBlock;
            tensor_size_kv[0] = params.d; // params.d;
@@ -1373,14 +1412,28 @@ struct DMA
            // Stride size in bytes. Assumes least significant dim is 1 (?)
            uint64_t tensor_stride_kv[3];
            tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // d
-            tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0];          // d*h
-            tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1];          // d*h*3
+            // The original code is:
+            // tensor_stride_kv[1] = tensor_size_kv[1] * tensor_stride_kv[0]; // d*mTokensPerBlock
+            // tensor_stride_kv[2] = tensor_size_kv[2] * tensor_stride_kv[1]; // d*mTokensPerBlock*h
+            // This can be simplified to:
+            tensor_stride_kv[1] = params.kv_stride_in_bytes;
+            tensor_stride_kv[2] = params.paged_kv_cache.mBytesPerBlock;
 
            // Paged KV pool tma descriptors.
            paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr),
                desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv,
                traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_kv);
+#ifndef GENERATE_CUBIN
+            tensor_size_kv[0] = params.dv;
+            tensor_stride_kv[0] = tensor_size_kv[0] * Kernel_traits::ELEMENT_BYTES; // dv
+            tensor_stride_kv[1] = params.v_stride_in_bytes;                         // dv*mTokensPerBlock
+
+            paged_kv_tma_descriptor.set_tma_desctriptor(reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr),
+                desc_format, fmha::cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode,
+                fmha::cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_kv, tensor_stride_kv,
+                traversal_stride, box_size_kv, oob_fill, fp32_to_tf32, &params.tma_desc_v);
+#endif
        }
    }
}
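
The comments in these two paged-KV hunks weigh several sizings for the descriptor's outermost dimension. Restated as plain functions over local stand-in parameters (std::gcd is C++17); INT_MAX is the simplification the diff settles on, presumably because real addressing comes from paged_block_offsets, so the extent only needs to be a safe upper bound:

#include <climits>
#include <cstdint>
#include <numeric> // std::gcd

// Option 1: the original sizing, two block rows (K and V) per logical block.
uint32_t outer_dim_original(uint32_t b, uint32_t max_blocks_per_seq)
{
    return b * 2 * max_blocks_per_seq;
}

// Option 2: the unpadded-V sizing from the comment, for d != dv.
uint32_t outer_dim_unpadded_v(uint32_t b, uint32_t max_blocks_per_seq, uint32_t d, uint32_t dv)
{
    return b * max_blocks_per_seq * ((d + dv) / std::gcd(d, dv));
}

// Option 3: derive it from the pool size, as TensorRT-LLM does.
uint32_t outer_dim_from_pool(uint64_t total_device_memory, uint64_t bytes_per_block)
{
    return static_cast<uint32_t>(total_device_memory / bytes_per_block);
}

// Option 4: the simplification actually used above.
uint32_t outer_dim_simplest() { return INT_MAX; }
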
@@ -36,6 +36,8 @@ template <
    int STEP_KV_,
    // The head dimension.
    int D_,
+    // The head dimension of V.
+    int DV_,
    // The number of smem buffers for Q tiles.
    int Q_BUFFERS_,
    // The number of smem buffers for K, and V tiles.
@@ -83,18 +85,17 @@ struct Kernel_traits
        STEP_KV = STEP_KV_
    };
 
-    // The padded head dimension.
-    enum
-    {
-        D = Next_power_of_two<D_>::VALUE
-    };
-
    // The valid head dimension.
    enum
    {
        VALID_D = D_
    };
 
+    enum
+    {
+        VALID_DV = (DV_ == 0 ? D_ : DV_)
+    };
+
    // Bootstrap GMMA_K from dummy Instruction_traits where FP16/BF16 K = 16, FP8 K = 32.
    enum
    {
@@ -113,6 +114,16 @@ struct Kernel_traits
        ELEMENT_BYTES = sizeof(Element_data_type)
    };
 
+    enum
+    {
+        D = Next_power_of_two<VALID_D>::VALUE
+    };
+
+    enum
+    {
+        DV = Next_power_of_two<VALID_DV>::VALUE
+    };
+
    // The number of smem buffers for Q tiles.
    enum
    {
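
Taken together with VALID_DV above, the padded dimensions D and DV are both next-power-of-two roundings, and DV_ == 0 falls back to the Q/K head size so existing instantiations keep their old behaviour. A constexpr sketch of that arithmetic (next_power_of_two and valid_dv are stand-ins for the repo's Next_power_of_two trait and the VALID_DV enum):

constexpr int next_power_of_two(int n, int p = 1)
{
    return p >= n ? p : next_power_of_two(n, p * 2);
}

// DV_ == 0 keeps the legacy behaviour: V uses the same head size as Q/K.
constexpr int valid_dv(int d_, int dv_)
{
    return dv_ == 0 ? d_ : dv_;
}

// Deepseek context MLA: the Q/K head size 192 pads to 256, V stays at 128.
static_assert(next_power_of_two(192) == 256, "");
static_assert(next_power_of_two(valid_dv(192, 128)) == 128, "");
// A legacy 128-wide kernel passing DV_ = 0 pads both dimensions to 128.
static_assert(next_power_of_two(valid_dv(128, 0)) == 128, "");
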
@@ -326,6 +337,18 @@ struct Kernel_traits
        D_BYTES_PER_GROUP = D_BYTES / D_GROUPS
    };
 
+    // The bytes of head dimension of V.
+    enum
+    {
+        DV_BYTES = DV * ELEMENT_BYTES
+    };
+
+    // The number of head_dimension groups of V.
+    enum
+    {
+        DV_GROUPS = fmha::Div_up<DV_BYTES, 128>::VALUE
+    };
+
    // QGMMA: BMM2 will be split into multiple K groups as we explicitly transpose v (128 * D) in the smem.
    // HGMMA: BMM2 will load from row-major (K * N) smem_v, so we don't need to explicitly split K.
    static constexpr auto BMM2_LEADING_DIM_BYTES = ELEMENT_BYTES == 1 ? 128 : STEP_KV * ELEMENT_BYTES;
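
DV_GROUPS splits the V head dimension into 128-byte chunks for the TMA swizzle, mirroring what D_GROUPS does for Q/K. A tiny worked check under assumed element sizes (div_up stands in for fmha::Div_up):

constexpr int div_up(int a, int b)
{
    return (a + b - 1) / b;
}

// bf16: DV = 128 -> DV_BYTES = 256 -> two 128B swizzle groups per V row.
static_assert(div_up(128 * 2, 128) == 2, "");
// fp8:  DV = 128 -> DV_BYTES = 128 -> a single group.
static_assert(div_up(128 * 1, 128) == 1, "");
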
@@ -364,7 +387,7 @@ struct Kernel_traits
 
    // The instruction traits for the BMM2.
    // FP16/BF16 K = 16, FP8 K = 32.
-    using Traits_o = Instruction_traits<STEP_Q, D, GMMA_K, true, false>;
+    using Traits_o = Instruction_traits<STEP_Q, DV, GMMA_K, true, false>;
 
    // The CTA description for BMM1.
    using Cta_tile_p =
@@ -375,7 +398,7 @@ struct Kernel_traits
        typename Traits_p::template Cta_tile<STEP_Q, STEP_KV, D_PER_GROUP, WARP_GROUP_M, WARP_GROUP_N, WARP_GROUP_K>;
 
    // The CTA description for BMM2.
-    using Cta_tile_o = typename Traits_o::template Cta_padded_tile<STEP_Q, D, STEP_KV, VALID_D, STEP_KV, WARP_GROUP_M,
+    using Cta_tile_o = typename Traits_o::template Cta_padded_tile<STEP_Q, DV, STEP_KV, VALID_DV, STEP_KV, WARP_GROUP_M,
        WARP_GROUP_K, WARP_GROUP_N>;
 
    // The MMA tile for the 1st GEMM.
@@ -415,9 +438,9 @@ struct Kernel_traits
    // The q, k, v tile buffer.
    using Buffer_q_t = cuda::std::array<Element_data_type, D * STEP_Q * Q_BUFFERS>;
    using Buffer_k_t = cuda::std::array<Element_data_type, D * STEP_KV * KV_BUFFERS>;
-    using Buffer_v_t = cuda::std::array<Element_data_type, D * STEP_KV * KV_BUFFERS>;
+    using Buffer_v_t = cuda::std::array<Element_data_type, DV * STEP_KV * KV_BUFFERS>;
    // We need one kv buffer to explicitly transose fp8 smem_tile.
-    using Buffer_v_scratch_t = cuda::std::array<Element_data_type, D * STEP_KV * V_SCRATCH_BUFFERS>;
+    using Buffer_v_scratch_t = cuda::std::array<Element_data_type, DV * STEP_KV * V_SCRATCH_BUFFERS>;
 
    // The smem bytes of q, k, v tiles.
    enum
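
With Buffer_v_t and Buffer_v_scratch_t keyed to DV, the V shared-memory footprint of a 192/128 kernel halves relative to sizing by the padded D. Extending the earlier tile arithmetic with double buffering (same assumed numbers: STEP_KV = 128, two KV buffers, bf16):

#include <cstddef>

constexpr std::size_t ELEMENT_BYTES = 2; // bf16
constexpr std::size_t STEP_KV = 128;
constexpr std::size_t KV_BUFFERS = 2;
constexpr std::size_t D = 256;  // 192 padded up
constexpr std::size_t DV = 128; // needs no padding

// Buffer_v_t element counts, old (D-based) vs. new (DV-based).
constexpr std::size_t v_elems_old = D * STEP_KV * KV_BUFFERS;
constexpr std::size_t v_elems_new = DV * STEP_KV * KV_BUFFERS;

static_assert(v_elems_old * ELEMENT_BYTES == 128 * 1024, ""); // 128 KiB
static_assert(v_elems_new * ELEMENT_BYTES == 64 * 1024, "");  // 64 KiB
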
@@ -521,6 +544,8 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
    int STEP_KV_,
    // The head dimension.
    int D_,
+    // The head dimension of V.
+    int DV_,
    // The number of smem buffers for Q tiles.
    int Q_BUFFERS_,
    // The number of smem buffers for K, and V tiles.
@@ -554,14 +579,14 @@ template < // The step size in query sequence dimension (M of BMM1 and BMM2).
    // The sage attention block size for Q, K and V
    int SAGE_BLOCK_SIZE_Q_ = 0, int SAGE_BLOCK_SIZE_K_ = 0, int SAGE_BLOCK_SIZE_V_ = 0>
struct Kernel_traits_Hopper_qgmma_e4m3_fp32
-    : public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, Q_BUFFERS_, KV_BUFFERS_,
+    : public Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
          NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_,
          ENABLE_MUTEX_, SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_,
          RETURN_SOFTMAX_STATS_, OutputType, SAGE_BLOCK_SIZE_Q_, SAGE_BLOCK_SIZE_K_, SAGE_BLOCK_SIZE_V_>
{
 
    // Base class.
-    using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, Q_BUFFERS_, KV_BUFFERS_,
+    using Base = Kernel_traits<Hopper_qgmma_e4m3_fp32_traits, STEP_Q_, STEP_KV_, D_, DV_, Q_BUFFERS_, KV_BUFFERS_,
        NUM_COMPUTE_GROUPS_, DMA2COMPUTE_DEPTH_, ATTENTION_MASK_TYPE_, HEADS_INTERLEAVED_, APPLY_ALIBI_, ENABLE_MUTEX_,
        SCHEDULING_MODE_, INPUT_LAYOUT_, USE_TMA_STORE_, ENABLE_BMM1_SOFTCAPPING_SCALE_>;
 
@@ -601,7 +626,7 @@ struct Kernel_traits_Hopper_qgmma_e4m3_fp32
    using Buffer_v_scratch_t = typename Base::Buffer_v_scratch_t;
    // Extra O buffer if TMA is used for epilogue
    using Element_data_type = typename Base::Element_data_type;
-    using Buffer_o_t = cuda::std::array<Element_data_type, Base::D * Base::STEP_Q * O_BUFFERS>;
+    using Buffer_o_t = cuda::std::array<Element_data_type, Base::DV * Base::STEP_Q * O_BUFFERS>;
 
    // The struct of shared memory buffers.
    struct __align__(128) Shared
@@ -208,6 +208,8 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_base
    // Contiguous kv layout: [B, 2, H, S, D].
    // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
    fmha::cudaTmaDesc tma_desc_kv;
+    // Tma descriptor for v if v_stride_in_bytes not in [0, kv_stride_in_bytes]
+    fmha::cudaTmaDesc tma_desc_v;
    // Tma descriptor for o
    fmha::cudaTmaDesc tma_desc_o;
 
@@ -111,6 +111,7 @@ struct Fused_multihead_attention_params_v2
    // Contiguous kv layout: [B, 2, H, S, D].
    // Paged kv layout: [UINT32_MAX, H, Tokens_per_block, D].
    fmha::cudaTmaDesc tma_desc_kv;
+    fmha::cudaTmaDesc tma_desc_v;
    // Tma descriptor for o
    fmha::cudaTmaDesc tma_desc_o;
 