[None][feat] Use Separate QKV Input Layout for Context MLA (#6538)

Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
zhhuang-nv 2025-08-19 22:04:48 +08:00 committed by GitHub
parent 8f95f35503
commit 7e135d2ea7
101 changed files with 854 additions and 1475 deletions

View File

@ -155,50 +155,41 @@ def test_trtllm_sage_attention_fmha(d, s):
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize(
'input_layout', ["", "-paged-kv", "-contiguous-q-kv", "-separate-q-k-v"],
ids=["packed-qkv", "paged-kv", "q-contiguous-kv", "separate-q-k-v"])
def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
def test_trtllm_context_mla_attention_fmha(dtype, s):
sm_version = getSMVersion()
if sm_version < 90:
pytest.skip("MLA kernels are only tested on sm90 and above currently.")
# use higher error tolerance for bf16 and s = 4096.
epsilon = ''
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'
sm_version = getSMVersion()
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
pytest.skip("FP8 MLAs only supported on sm89 currently.")
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
pytest.skip("FP8 MLAs are only supported on sm120 currently.")
# Context phase kernels.
# Context phase kernels always use the separate-q-k-v layout.
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-force-non-warp-specialization -causal-mask {epsilon}",
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"-causal-mask {epsilon} -separate-q-k-v",
shell=True,
check=True)
if sm_version == 90:
# Now only hopper-style supports separate-q-k-v
# For chunked prefill, we need to enable -save-softmax (dtype: bf16, layout: separate-q-k-v).
# Currently the fp8 kernel doesn't support saving softmax.
if dtype == "-bf16":
# padding mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-causal-mask {epsilon} {input_layout}",
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"{epsilon} -separate-q-k-v -save-softmax",
shell=True,
check=True)
# causal mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"-causal-mask {epsilon} -separate-q-k-v -save-softmax",
shell=True,
check=True)
# For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v).
if dtype == "-bf16" and input_layout in [
"-paged-kv", "-separate-q-k-v"
]:
# padding mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
{epsilon} {input_layout} -save-softmax",
shell=True,
check=True)
# causal mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-causal-mask {epsilon} {input_layout} -save-softmax",
shell=True,
check=True)
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
@ -210,14 +201,17 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
"num-grouped-heads-64", "num-grouped-heads-128"
])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
sm_version = getSMVersion()
if sm_version < 90:
pytest.skip("MLA kernels are only tested on sm90 and above currently.")
# use higher error tolerance for bf16 and s = 4096.
epsilon = ''
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'
sm_version = getSMVersion()
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
pytest.skip("FP8 MLAs only supported on sm89 currently.")
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
pytest.skip("FP8 MLAs are only supported on sm120 currently.")
# Generation phase kernels.
subprocess.run(

View File

@ -2075,6 +2075,8 @@ def get_kernel_code(kspec, kname, lname):
kernel_traits += '_paged_kv_cache'
elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
kernel_traits += '_contiguous_kv_cache'
elif kspec.input_layout == InputLayout.SEPARATE_Q_K_V:
kernel_traits += '_q_k_v'
flags = 0
if kspec.ldgsts_q:
@ -3183,7 +3185,7 @@ def get_cubin_header(kernel_traits, specs_names):
attention_mask_type_value = attention_mask_type.value
# Attention input layout:
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2).
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2), separate_q_k_v (3).
attention_input_layout = InputLayout.PACKED_QKV
if '_q_kv' in kname:
attention_input_layout = InputLayout.CONTIGUOUS_Q_KV
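For reference, a minimal C++ sketch of the four layout values listed in the comment above; the names mirror InputLayout / AttentionInputLayout used elsewhere in this change, but the exact definition shown here is an assumption rather than code from this diff.

// Sketch only: mirrors the numbering in the generator comment above.
enum class AttentionInputLayout : int
{
    PACKED_QKV = 0,      // Q, K and V packed into a single [B, S, 3, H, D] tensor.
    CONTIGUOUS_Q_KV = 1, // Separate Q plus a contiguous [B, S, 2, H, D] KV tensor.
    Q_PAGED_KV = 2,      // Separate Q plus a paged KV cache.
    SEPARATE_Q_K_V = 3,  // Fully separate Q, K and V tensors (the new context MLA layout).
};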
@ -3652,12 +3654,9 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
if alibi and enable_attn_logit_softcapping:
continue
# for normal attention, we only need contiguous kv as input layout when returning softmax.
skip_combination = return_softmax and (input_layout
!= InputLayout.CONTIGUOUS_Q_KV)
# for context mla, we need paged kv or separate qkv as input layout when returning softmax.
skip_mla_combination = return_softmax and (
input_layout != InputLayout.Q_PAGED_KV
and input_layout != InputLayout.SEPARATE_Q_K_V)
skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
# for context mla, we need separate qkv as input layout when returning softmax.
skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
if not skip_combination:
# only specify
specs.append(
@ -4702,9 +4701,16 @@ def enumerate_hmma_paged_kv_flash_kernels(specs, sm=80, dtype='fp16'):
def enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16', head_size_v=0):
for (input_layout, enable_attn_logit_softcapping) in \
product([InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, InputLayout.Q_PAGED_KV], \
[False, True]):
input_layouts = [
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
InputLayout.Q_PAGED_KV
]
# Deepseek MLA (context 192/128 separate-q-k-v)
if head_size_v == 128:
input_layouts.append(InputLayout.SEPARATE_Q_K_V)
for (input_layout,
enable_attn_logit_softcapping) in product(input_layouts,
[False, True]):
enumerate_hmma_flash_kernels_base(specs, sm, dtype, input_layout,
enable_attn_logit_softcapping,
head_size_v)
@ -5080,7 +5086,7 @@ def enumerate_qmma_flash_kernels(specs,
]
input_layouts = [
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
InputLayout.Q_PAGED_KV
InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V
]
for (head_size_params, (q_loop_step, kv_loop_step), tiled), input_layout in \
product(params_q_kv_step, input_layouts):
@ -5094,6 +5100,9 @@ def enumerate_qmma_flash_kernels(specs,
# skip if head_size is not in head_sizes
if head_sizes is not None and head_size not in head_sizes:
continue
# skip if head_size_v is not 128 for separate-q-k-v
if input_layout == InputLayout.SEPARATE_Q_K_V and head_size_v != 128:
continue
specs.append(
kernel_spec(sm=sm,
sm_mma=89,
@ -6354,28 +6363,30 @@ def enumerate_kernels():
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == False)
# Deepseek MLA (192/128 packed + 576/512 paged)
or (kspec.sm in [80, 86, 89, 90, 100, 120]
# Deepseek MLA (generation 576/512 paged)
or (kspec.sm in [90, 100, 120]
and kspec.dtype in ['bf16', 'e4m3_fp32']
and (((kspec.head_size, kspec.head_size_v) == (192, 128) and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV])
or ((kspec.head_size, kspec.head_size_v) == (576, 512) and kspec.input_layout == InputLayout.Q_PAGED_KV))
and kspec.head_size == 576
and kspec.head_size_v == 512
and kspec.input_layout == InputLayout.Q_PAGED_KV
and kspec.sage_block_sizes is None
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.warp_specialization == False
and kspec.tiled == True)
# Deepseek MLA (hopper-style context 192/128)
or (kspec.sm == 90
and kspec.dtype == 'bf16'
# Deepseek MLA (context 192/128 separate-q-k-v)
or (kspec.sm in [90, 100, 120]
and kspec.dtype in ['bf16', 'e4m3_fp32']
and kspec.head_size == 192
and kspec.head_size_v == 128
and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
and kspec.sage_block_sizes is None
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.warp_specialization == True
and kspec.alibi == False
and ((kspec.warp_specialization == True and kspec.alibi == False) # sm90
or (kspec.warp_specialization == False and kspec.tiled == True)) # non-sm90
and kspec.enable_attn_logit_softcapping == False)
# SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
or (kspec.sm == 90
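As a reading aid, the context-MLA branch of the filter above can be restated as a small C++ sketch; the struct and function below are illustrative only, since the real filter operates on the generator's Python kernel_spec entries.

// Sketch only: which specs count as DeepSeek context MLA (192/128, separate-q-k-v) after this change.
struct SpecLite
{
    int sm;
    bool bf16_or_e4m3_fp32;
    int head_size, head_size_v;
    bool separate_q_k_v, flash_attention, warp_specialization, tiled, alibi, softcapping;
};

bool isContextMlaSpec(SpecLite const& s)
{
    bool sm_ok = (s.sm == 90 || s.sm == 100 || s.sm == 120);
    bool shape_ok = (s.head_size == 192 && s.head_size_v == 128 && s.separate_q_k_v);
    // sm90 uses the warp-specialized kernels; the other SMs use non-warp-specialized tiled ones.
    bool schedule_ok = (s.warp_specialization && !s.alibi) || (!s.warp_specialization && s.tiled);
    return sm_ok && s.bf16_or_e4m3_fp32 && shape_ok && s.flash_attention && schedule_ok && !s.softcapping;
}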

View File

@ -418,7 +418,7 @@ struct Gmem_tile_qkv
////////////////////////////////////////////////////////////////////////////////////////////////////
// We expect the Q layout to be [B, S, H, D] with variable sequence length support.
// We expect the Q/K/V layout to be [B, S, H, D] with variable sequence length support.
template <
// The instruction traits.
typename Traits,
@ -440,7 +440,7 @@ template <
int NUM_MATS = 1,
// Is sliding window attention used ?
bool SLIDING_WINDOW_ATTENTION = false>
struct Gmem_tile_q
struct Gmem_tile_q_k_v
{
// The size of each LDG.
@ -523,22 +523,38 @@ struct Gmem_tile_q
USE_LDGSTS = USE_LDGSTS_
};
// Ctor (keep qkv_offset for compatibility)
// Ctor
// qkv_offset: 0 for Q, 1 for K, 2 for V
template <typename Block_info>
inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
inline __device__ Gmem_tile_q_k_v(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
Block_info const& binfo, int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
: Gmem_tile_q(params, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes)
{
}
// Ctor.
template <typename Block_info>
inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, Block_info const& binfo,
int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
: params_q_stride_in_bytes_(params.q_stride_in_bytes)
, actual_seqlen_(binfo.actual_q_seqlen)
, q_ptr_(reinterpret_cast<char*>(params.q_ptr))
{
int seq_offset = 0;
if (qkv_offset == 0)
{
// Q tensor
params_q_k_v_stride_in_bytes_ = params.q_stride_in_bytes;
q_k_v_ptr_ = reinterpret_cast<char*>(params.q_ptr);
actual_seqlen_ = binfo.actual_q_seqlen;
seq_offset = binfo.sum_s;
}
else if (qkv_offset == 1)
{
// K tensor
params_q_k_v_stride_in_bytes_ = params.k_stride_in_bytes;
q_k_v_ptr_ = reinterpret_cast<char*>(params.k_ptr);
actual_seqlen_ = binfo.actual_kv_seqlen;
seq_offset = binfo.sum_s_kv;
}
else if (qkv_offset == 2)
{
// V tensor
params_q_k_v_stride_in_bytes_ = params.v_stride_in_bytes;
q_k_v_ptr_ = reinterpret_cast<char*>(params.v_ptr);
actual_seqlen_ = binfo.actual_kv_seqlen;
seq_offset = binfo.sum_s_kv;
}
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
@ -550,17 +566,20 @@ struct Gmem_tile_q
// Do not load/store if the thread is in the padded area
col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG;
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// We won't consider past_q_length when loading from gmem_q.
int64_t row_offset = (int64_t) (row + cta_row_offset) * params_q_stride_in_bytes_;
// Add the block index. (sum_s * h + hidx).
int64_t idx = binfo.bidx;
// The row offset in the batched GEMM, including the sequence offset.
int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_;
// Add the head index.
int64_t idx = binfo.bidh;
// Assemble the final pointer.
q_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
q_k_v_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
// Take the CTA offset to modify the sequence length.
actual_seqlen_ -= cta_row_offset;
// Save the initial seq_len and pointer in case of re-iterating
actual_seqlen_init_ = actual_seqlen_;
q_k_v_ptr_init_ = q_k_v_ptr_;
}
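To make the addressing in this constructor concrete, here is a small standalone host-side sketch (illustrative names, not the kernel code) of how one LDG row of a separate Q, K or V tensor is located from the per-token stride, the per-sequence offset (binfo.sum_s for Q, binfo.sum_s_kv for K/V) and the head index:

#include <cstdint>
#include <cstdio>

// Sketch only: byte offset of one row in a [total_tokens, H, D] tensor, mirroring the
// row_offset / idx computation in the Gmem_tile_q_k_v constructor above.
int64_t rowByteOffset(int64_t stride_in_bytes, // bytes per token row: H * D * sizeof(elt)
    int row, int cta_row_offset,               // row within the tile and the CTA row offset
    int seq_offset,                            // sum_s (Q) or sum_s_kv (K/V) of this sequence
    int head_idx,                              // binfo.bidh
    int64_t valid_bytes_per_row,               // D * sizeof(elt), bytes of one head
    int64_t col_in_bytes)                      // column offset handled by this thread
{
    return (int64_t) (row + cta_row_offset + seq_offset) * stride_in_bytes
        + (int64_t) head_idx * valid_bytes_per_row + col_in_bytes;
}

int main()
{
    // Example: a bf16 K tensor with 8 KV heads of 192 elements -> 8 * 192 * 2 bytes per token row.
    int64_t const k_stride = 8 * 192 * 2;
    std::printf("%lld\n", (long long) rowByteOffset(k_stride, 3, 0, 128, 2, 192 * 2, 0));
    return 0;
}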
// Store data to shared memory.
@ -590,7 +609,7 @@ struct Gmem_tile_q
#pragma unroll
for (int ii = 0; ii < LDGS; ++ii)
{
ptrs[ii] = q_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_stride_in_bytes_;
ptrs[ii] = q_k_v_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_k_v_stride_in_bytes_;
}
// Trigger LDGSTS or the LDGs.
@ -598,10 +617,24 @@ struct Gmem_tile_q
Ldgsts_helper<USE_LDGSTS>::load(this, smem_tile, ptrs, preds);
}
// Move the pointer to the next row location.
inline __device__ void move(int const steps = 1)
{
q_k_v_ptr_ += (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * steps;
actual_seqlen_ -= (int) ROWS * steps;
}
// Move the pointer to the next row location by the offset (not step).
inline __device__ void move_by_offset(int const offset)
{
q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) offset * params_q_k_v_stride_in_bytes_;
actual_seqlen_ = actual_seqlen_init_ - (int) offset;
}
// Move the pointer to the next column location
inline __device__ void move_col()
{
q_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
// Update col_in_bytes_ to ensure load predicates work
col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG;
}
@ -609,15 +642,29 @@ struct Gmem_tile_q
// Rewind the pointer back to previous column location
inline __device__ void rewind_col(int const steps)
{
q_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
q_k_v_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
// Update col_in_bytes_ to ensure load predicates work
col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps;
}
// The stride between rows for the QKV matrice.
int64_t params_q_stride_in_bytes_;
// Move the pointer to the specified step.
inline __device__ void move_to(int const step)
{
q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * step;
actual_seqlen_ = actual_seqlen_init_ - (int) ROWS * step;
}
inline __device__ void reset()
{
q_k_v_ptr_ = q_k_v_ptr_init_;
actual_seqlen_ = actual_seqlen_init_;
}
// The stride between rows for the Q/K/V matrix.
int64_t params_q_k_v_stride_in_bytes_;
// The pointer.
char* q_ptr_;
char* q_k_v_ptr_;
char* q_k_v_ptr_init_;
// The register to store predicates.
uint32_t preds_[PRED_REGS];
// The fetch registers.
@ -627,6 +674,7 @@ struct Gmem_tile_q
int64_t col_in_bytes_;
// The sequence length.
int actual_seqlen_;
int actual_seqlen_init_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1015,8 +1015,8 @@ template <
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_paged_kv_cache
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q, fmha::v2::Gmem_tile_paged_kv, fmha::v2::Gmem_tile_paged_kv,
using Kernel_traits_v2_q_k_v
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v,
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
@ -1049,7 +1049,41 @@ template <
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q,
using Kernel_traits_v2_paged_kv_cache
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_paged_kv, fmha::v2::Gmem_tile_paged_kv,
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
// The instruction traits.
typename Traits,
// The sequence length.
int S,
// The hidden size per head.
int D,
// The hidden dimension of V.
int DV,
// The number of timesteps per iteration of the main loop.
int STEP,
// The number of vertical warps.
int WARPS_M,
// The number of horizontal warps.
int WARPS_N,
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
int CTAS_PER_HEAD,
// The flags.
uint32_t FLAGS = 0x8,
// The attention mask version (see src/mask.h).
int MASK_VERSION = 2,
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
bool BMM2_FP16_EPILOGUE = true,
// The output type.
typename OutputType = typename Traits::A_type,
// The sage attention block size for Q, K and V
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v,
fmha::v2::Gmem_tile_contiguous_kv, fmha::v2::Gmem_tile_contiguous_kv,
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2,
MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;

View File

@ -890,18 +890,17 @@ int main(int argc, char** argv)
{
bool is_MLA = (d == 192 && dv == 128);
if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
|| (is_MLA && input_layout != Attention_input_layout::Q_PAGED_KV
&& input_layout != Attention_input_layout::SEPARATE_Q_K_V))
|| (is_MLA && input_layout != Attention_input_layout::SEPARATE_Q_K_V))
{
fprintf(stderr,
"For normal attention, Only '--contiguous-q-kv' layout supports "
"'-save-softmax'. For MLA only '-paged-kv' and '-separate-q-k-v' layout supports "
"'-save-softmax'. For MLA only '-separate-q-k-v' layout supports "
"'-save-softmax'.\n");
exit(1);
}
if (data_type == DATA_TYPE_E4M3)
{
fprintf(stderr, "Currently fp8 kernel doesn't support fp8.\n");
fprintf(stderr, "Currently fp8 kernel doesn't support saving softmax.\n");
exit(1);
}
}

View File

@ -747,17 +747,19 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
int const num_total_qkv_elements
= max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
// Packed fp8 qkv buffer size for normal fp8 context FMHA
size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
if (mFP8ContextMLA)
// Separate fp8 q/k/v buffer size for fp8 context MLA
size_t fp8_q_buf_size = 0;
size_t fp8_k_buf_size = 0;
size_t fp8_v_buf_size = 0;
if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
{
fp8_qkv_buffer_size
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
fp8_q_buf_size = max_num_tokens * static_cast<size_t>(total_q_dim_all_heads);
fp8_k_buf_size = max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
fp8_v_buf_size = max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
}
size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
@ -774,7 +776,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
? 0
: (2 * size * cpMaxPaddedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads) + cu_seqlens_size);
int const NUM_BUFFERS = 20;
int const NUM_BUFFERS = 23;
size_t workspaces[NUM_BUFFERS];
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
workspaces[1] = attention_mask_size;
@ -789,13 +791,16 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
workspaces[10] = qkv_buf_2_size;
workspaces[11] = qk_buf_float_size;
workspaces[12] = fp8_qkv_buffer_size;
workspaces[13] = padding_offset_size;
workspaces[14] = encoder_padding_offset_size;
workspaces[15] = tokens_info_size;
workspaces[16] = fmha_scheduler_counter;
workspaces[17] = fmha_bmm1_scale_size;
workspaces[18] = fmha_bmm2_scale_size;
workspaces[19] = cpWorkspaceSize;
workspaces[13] = fp8_q_buf_size;
workspaces[14] = fp8_k_buf_size;
workspaces[15] = fp8_v_buf_size;
workspaces[16] = padding_offset_size;
workspaces[17] = encoder_padding_offset_size;
workspaces[18] = tokens_info_size;
workspaces[19] = fmha_scheduler_counter;
workspaces[20] = fmha_bmm1_scale_size;
workspaces[21] = fmha_bmm2_scale_size;
workspaces[22] = cpWorkspaceSize;
context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
return context_workspace_size;
@ -979,7 +984,7 @@ int AttentionOp::mlaGeneration(
params.fmha_tile_counter = fmha_tile_counter_ptr;
params.bmm1_scale = mla_bmm1_scale_ptr;
params.bmm2_scale = mla_bmm2_scale_ptr;
params.quant_attention_input_buf = quant_q_buffer_ptr;
params.quant_q_buf = quant_q_buffer_ptr;
params.quant_scale_o = generation_params.attention_output_orig_quant;
params.quant_scale_q = generation_params.kv_scale_orig_quant;
@ -1020,8 +1025,8 @@ int AttentionOp::mlaGeneration(
tllmRunnerParams.mTileScheduler = mMultiBlockMode ? TileScheduler::Static : TileScheduler::Persistent;
// Q buffer.
tllmRunnerParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_attention_input_buf)
: reinterpret_cast<void const*>(params.attention_input_buf);
tllmRunnerParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_q_buf)
: reinterpret_cast<void const*>(params.q_buf);
// KV buffer
// Paged KV
@ -1146,9 +1151,8 @@ int AttentionOp::mlaGeneration(
flashMlaParams.scale_softmax = softmax_scale;
flashMlaParams.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
flashMlaParams.q_ptr = mFP8GenerationMLA
? const_cast<void*>(reinterpret_cast<void const*>(params.quant_attention_input_buf))
: const_cast<void*>(reinterpret_cast<void const*>(params.attention_input_buf));
flashMlaParams.q_ptr = mFP8GenerationMLA ? const_cast<void*>(reinterpret_cast<void const*>(params.quant_q_buf))
: const_cast<void*>(reinterpret_cast<void const*>(params.q_buf));
flashMlaParams.k_ptr = kv_cache_buffer.mPrimaryPoolPtr;
flashMlaParams.v_ptr = flashMlaParams.k_ptr;
flashMlaParams.o_ptr = reinterpret_cast<void*>(params.context_buf);
@ -1253,8 +1257,8 @@ int AttentionOp::mlaGeneration(
// fmhaParams.totalKvSeqLen = params.num_tokens;
// Device buffer pointers.
// fmhaParams.qkvPtr = reinterpret_cast<void const*>(params.attention_input);
fmhaParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_attention_input_buf)
: reinterpret_cast<void const*>(params.attention_input_buf);
fmhaParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_q_buf)
: reinterpret_cast<void const*>(params.q_buf);
// TODO: add contiguous kv buffer (cross-attention).
fmhaParams.kvPtr = nullptr;
@ -1379,15 +1383,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
int const num_total_qkv_elements
= params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
// Packed fp8 qkv buffer size for normal fp8 context FMHA
size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
if (mFP8ContextMLA)
// Separate fp8 q/k/v buffer size for fp8 context MLA
size_t fp8_q_buf_size = 0;
size_t fp8_k_buf_size = 0;
size_t fp8_v_buf_size = 0;
if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
{
fp8_qkv_buffer_size
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
fp8_q_buf_size = params.num_tokens * static_cast<size_t>(total_q_dim_all_heads);
fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
}
size_t const padding_offset_size
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
@ -1424,6 +1432,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
float* qk_buf_float_ = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_float_size));
__nv_fp8_e4m3* fp8_qkv_buffer
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_qkv_buffer_size));
__nv_fp8_e4m3* fp8_q_buf
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_q_buf_size));
__nv_fp8_e4m3* fp8_k_buf
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_k_buf_size));
__nv_fp8_e4m3* fp8_v_buf
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_v_buf_size));
int* padding_offset = mEnableContextFMHA
? nullptr
: reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
@ -1638,16 +1652,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
"Found invalid number (NaN or Inf) in " + beforeRopeStr);
}
KVBlockArray mla_context_paged_kv_cache_buffer;
if (mIsMLAEnabled)
{
TLLM_CHECK_WITH_INFO(params.mla_param != nullptr, "MLA param is nullptr");
params.mla_param->cache_type = cache_type;
params.mla_param->cu_q_seqlens = cu_q_seqlens;
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
// Set BMM scales for FP8 context computation
params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
params.mla_param->quant_q_buf = mFP8ContextMLA ? fp8_q_buf : nullptr;
params.mla_param->quant_k_buf = mFP8ContextMLA ? fp8_k_buf : nullptr;
params.mla_param->quant_v_buf = mFP8ContextMLA ? fp8_v_buf : nullptr;
// Set additional scales for context phase
params.mla_param->quant_scale_o = params.attention_output_orig_quant;
params.mla_param->quant_scale_q = params.kv_scale_orig_quant;
@ -1656,37 +1672,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
params.mla_param->host_bmm1_scale
= 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
if (mPagedContextFMHA && mPagedKVCache)
if (params.mla_param->latent_cache != nullptr)
{
TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
"Paged kv cache is not set for MLA context kernel");
TLLM_CHECK_WITH_INFO(params.mla_param->context_kv_cache_block_offsets_ptr != nullptr,
"Paged kv cache block offsets is not set for MLA context kernel");
// build another KVBlockArray for MLA context kernel to read paged kv cache, which is built by the
// PyTorch backend assume the dtype of paged kv cache is the same as the T
auto const elemSize = sizeof(T);
auto const headSize = params.mla_param->meta.qk_nope_head_dim + params.mla_param->meta.qk_rope_head_dim;
// mNumKVHeads is 1 for writing, we use mNumHeads for reading paged kv cache
auto sizePerToken = mNumHeads * headSize * elemSize;
auto maxBlocksPerSeq = params.mla_param->context_paged_kv_max_blocks_per_seq;
TLLM_LOG_DEBUG(
"AttentionOp building KVBlockArray for MLA context kernel, elemSize: %d, headSize: %d, mNumHeads: "
"%d, sizePerToken: %d, batchSize: %d, maxBlocksPerSeq: %d, tokensPerBlock: %d, maxAttentionWindow: "
"%d, "
"sinkTokenLen: %d, canUseOneMoreBlock: %d",
elemSize, headSize, mNumHeads, sizePerToken, params.batch_size, maxBlocksPerSeq, mTokensPerBlock,
params.cyclic_attention_window_size, params.sink_token_length, params.can_use_one_more_block);
mla_context_paged_kv_cache_buffer = KVBlockArray(params.batch_size, maxBlocksPerSeq, mTokensPerBlock,
sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
params.sink_token_length, params.can_use_one_more_block, params.mla_param->context_paged_kv_ptr,
nullptr,
static_cast<KVBlockArray::DataType*>(params.mla_param->context_kv_cache_block_offsets_ptr));
}
else
{
// compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if not using paged context FMHA
// compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if latent_cache is not nullptr
invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
}
if (mFP8ContextMLA)
{
invokeMLAContextFp8Quantize(*params.mla_param, params.total_kv_len, stream);
}
}
else
{
@ -1709,7 +1703,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
mFMHAForceFP32Acc = mFMHAForceFP32Acc || enable_context_fmha_fp32_acc_val == 1;
}
// Unified FMHA runner interface for both packed QKV FMHA, contiguous Q_KV and paged KV FMHA.
// Unified FMHA runner interface for packed QKV FMHA, contiguous Q_KV FMHA, paged KV FMHA, and separate QKV
// FMHA.
// Paged KV input layout:
// - q_ptr: [B, S, H, D], which supports variable sequence length
// - paged_kv_cache: paged kv buffer
@ -1721,6 +1716,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
// - kv_ptr: [B, S, 2, H, D], which supports variable sequence length
// - cu_q_seqlens: the cumulative query sequence lengths, needed for variable sequence length.
// - cu_kv_seqlens: the cumulative kv sequence lengths, needed for variable sequence length.
//
// Separate QKV input layout (only for context MLA now):
// - q_ptr: [B, S, H, D], which supports variable sequence length
// - k_ptr: [B, S, H_kv, D], which supports variable sequence length
// - v_ptr: [B, S, H_kv, D_v], which supports variable sequence length
// - cu_q_seqlens: the cumulative query sequence lengths, needed for variable sequence length.
// - cu_kv_seqlens: the cumulative kv sequence lengths, needed for variable sequence length.
// - total_kv_len: the total kv sequence length, needed for variable sequence length.
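To make the variable-sequence-length bookkeeping above concrete, here is a minimal host-side sketch (illustrative only, not TensorRT-LLM code) of how cu_q_seqlens, cu_kv_seqlens and total_kv_len relate for a batch of ragged sequences:

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch only: exclusive prefix sums over per-sequence lengths, as expected by the separate-QKV layout.
struct VarlenInfo
{
    std::vector<int32_t> cu_q_seqlens;  // size batch_size + 1
    std::vector<int32_t> cu_kv_seqlens; // size batch_size + 1
    int32_t total_kv_len = 0;           // == cu_kv_seqlens.back()
};

VarlenInfo buildVarlenInfo(std::vector<int32_t> const& q_lens, std::vector<int32_t> const& kv_lens)
{
    VarlenInfo info;
    info.cu_q_seqlens.push_back(0);
    info.cu_kv_seqlens.push_back(0);
    for (size_t b = 0; b < q_lens.size(); ++b)
    {
        info.cu_q_seqlens.push_back(info.cu_q_seqlens.back() + q_lens[b]);
        info.cu_kv_seqlens.push_back(info.cu_kv_seqlens.back() + kv_lens[b]);
    }
    info.total_kv_len = info.cu_kv_seqlens.back();
    return info;
}

int main()
{
    // Chunked-prefill style example: each sequence attends to more KV tokens than it has new Q tokens.
    auto info = buildVarlenInfo({16, 32}, {48, 64});
    std::printf("total_kv_len = %d\n", info.total_kv_len); // 112
    return 0;
}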
// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
@ -1732,11 +1735,35 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
= (mDenseContextFMHA || isCrossAttention()) ? max_kv_seq_len : params.cyclic_attention_window_size;
fmhaParams.totalQSeqLen = params.num_tokens;
// TODO: set it correctly for contiguous kv buffer (cross-attention).
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.total_kv_len;
// Device buffer pointers.
fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(attention_input);
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
if (mIsMLAEnabled)
{
// separate QKV input for context MLA
if (mFP8ContextMLA)
{
TLLM_CHECK_WITH_INFO(
mFmhaDispatcher->isSeparateQAndKvInput(), "Separate QKV input is required for fp8 context MLA");
TLLM_CHECK_WITH_INFO(fp8_q_buf != nullptr, "FP8 q buffer is required for fp8 context MLA");
TLLM_CHECK_WITH_INFO(fp8_k_buf != nullptr, "FP8 k buffer is required for fp8 context MLA");
TLLM_CHECK_WITH_INFO(fp8_v_buf != nullptr, "FP8 v buffer is required for fp8 context MLA");
fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
}
else
{
fmhaParams.qPtr = attention_input;
fmhaParams.kPtr = params.k_ptr;
fmhaParams.vPtr = params.v_ptr;
}
}
else
{
fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(attention_input);
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
}
// TODO: add contiguous kv buffer (cross-attention).
fmhaParams.kvPtr = nullptr;
if (isCrossAttention() && !useKVCache())
@ -1750,15 +1777,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
fmhaParams.packedMaskPtr = params.attention_packed_mask;
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
if (mIsMLAEnabled && mPagedContextFMHA && mPagedKVCache)
{
fmhaParams.pagedKvCache = mla_context_paged_kv_cache_buffer;
fmhaParams.qPtr = reinterpret_cast<void const*>(attention_input);
}
else
{
fmhaParams.pagedKvCache = kv_cache_buffer;
}
fmhaParams.pagedKvCache = kv_cache_buffer;
}
fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths;
@ -2606,6 +2625,8 @@ int AttentionOp::initialize() noexcept
// the wrong kernel, no matter mIsGenerationMLA is true or false
if (mIsMLAEnabled)
{
// Context MLA always uses the separate_q_k_v layout
fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
// Context attention of MLA is different
fmhaParams.numKvHeads = mNumHeads;
fmhaParams.headSize = mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim;
@ -2614,6 +2635,7 @@ int AttentionOp::initialize() noexcept
// attention op and that could fail to create the FmhaDispatcher for context phase.
// Luckily, for deepseek, qk_nope_head_dim is the same as v_head_dim in context phase.
fmhaParams.headSizeV = mMLAParams.qk_nope_head_dim;
fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
}
fmhaParams.qScaling = mQScaling;
fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
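For concreteness, a small worked example of where the 192/128 context shape and the 576/512 generation shape come from, assuming DeepSeek-style MLA dimensions (qk_nope_head_dim = 128, qk_rope_head_dim = 64, kv_lora_rank = 512); these numbers are assumptions for illustration and are not read from this diff:

#include <cstdio>

int main()
{
    // Illustrative arithmetic only (assumed DeepSeek-style MLA dims).
    int const qk_nope_head_dim = 128; // also equals v_head_dim, which is why headSizeV reuses it above
    int const qk_rope_head_dim = 64;
    int const kv_lora_rank = 512;

    std::printf("context    d x dv = %d x %d\n", qk_nope_head_dim + qk_rope_head_dim, qk_nope_head_dim); // 192 x 128
    std::printf("generation d x dv = %d x %d\n", kv_lora_rank + qk_rope_head_dim, kv_lora_rank);         // 576 x 512
    return 0;
}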

View File

@ -95,6 +95,7 @@ public:
void* host_primary_pool_pointer = nullptr;
void* host_secondary_pool_pointer = nullptr;
int32_t num_tokens = 0;
int32_t total_kv_len = 0;
int32_t max_blocks_per_sequence = 0;
int32_t const* sequence_lengths = nullptr;
int32_t const* context_lengths = nullptr;
@ -128,6 +129,9 @@ public:
// For MLA chunked prefill
void* softmaxStatsPtr = nullptr;
// Optional pointers for separate QKV input; currently only used for context MLA.
T const* k_ptr = nullptr;
T const* v_ptr = nullptr;
std::string enqueueContextParamsToString() const
{
@ -169,6 +173,7 @@ public:
ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl;
ss << "batch_size: " << this->batch_size << std::endl;
ss << "num_tokens: " << this->num_tokens << std::endl;
ss << "total_kv_len: " << this->total_kv_len << std::endl;
ss << "max_blocks_per_sequence: " << this->max_blocks_per_sequence << std::endl;
ss << "workspace: " << this->workspace << std::endl;
ss << "logn_scaling_ptr: " << this->logn_scaling_ptr << std::endl;
@ -179,6 +184,8 @@ public:
ss << "encoder_input_lengths: " << this->encoder_input_lengths << std::endl;
ss << "num_encoder_tokens: " << this->num_encoder_tokens << std::endl;
ss << "softmaxStatsPtr: " << this->softmaxStatsPtr << std::endl;
ss << "k_ptr: " << this->k_ptr << std::endl;
ss << "v_ptr: " << this->v_ptr << std::endl;
return ss.str();
}
};

View File

@ -193,12 +193,10 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_104_tma_ws_sm90(Fused_
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_160_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_192_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_64_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_40_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_48_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@ -210,13 +208,12 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_104_tma_ws_sm90
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_160_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_192_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_64_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@ -450,14 +447,8 @@ extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm8
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin[];
@ -666,9 +657,6 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softca
extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_256_softcapping_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm89_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
#endif
#ifndef EXCLUDE_SM_80
@ -880,9 +868,6 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softca
extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_256_softcapping_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm80_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
#endif
#ifndef EXCLUDE_SM_86
@ -1092,14 +1077,10 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softca
extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_256_softcapping_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm86_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
#endif
#ifndef EXCLUDE_SM_100
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
#endif
@ -1240,8 +1221,7 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_256_softcapping_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_softcapping_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@ -1271,13 +1251,11 @@ extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm120_nl(Fused
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_32_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@ -1426,14 +1404,8 @@ extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89_cu_
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin_len;
@ -1820,7 +1792,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, true, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, true, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_softcapping_tma_ws_sm90},
@ -1833,7 +1804,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90_kernel", 73984, 384, 64, 0, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_sliding_or_chunked_causal_tma_ws_sm90_kernel", 73984, 384, 64, 2, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
@ -1873,11 +1843,11 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 2, false, true, true, true, false, false, true, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 2, false, true, true, true, false, false, true, false, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 0, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_custom_mask_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90},
@ -1887,7 +1857,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, true, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, nullptr},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90},
@ -2522,18 +2493,12 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sliding_or_chunked_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 2, 0, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_custom_mask_output_bf16_sm89_kernel_nl", 65536, 128, 64, 3, 0, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
{ DATA_TYPE_E4M3, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
@ -3144,9 +3109,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm89_nl},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_nl_tiled},
#endif
#ifndef EXCLUDE_SM_80
@ -3756,9 +3718,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm80_nl},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_nl_tiled},
#endif
#ifndef EXCLUDE_SM_86
@ -4368,14 +4327,11 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm86_nl},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_nl_tiled},
#endif
#ifndef EXCLUDE_SM_100
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_nl_tiled},
#endif
@ -4784,8 +4740,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_softcapping_sm120_nl},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 128, 128, 32, 32, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_kernel_nl", 12288, 128, 128, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 128, 128, 32, 32, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_causal_sm120_kernel_nl", 12288, 128, 128, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl},
@ -4874,8 +4830,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sliding_or_chunked_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl},
@ -4883,8 +4839,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_nl_tiled},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 16, 16, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_causal_sm120_kernel_nl_tiled", 16384, 128, 128, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 16, 16, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 16384, 128, 128, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled},

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c709dce149c0f4500539e495c90d1da2d86cec28c4187ee9494b015642e158cf
size 363441

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9170581da010aca67f4bafd9f6f59aaaf5fd1958a1fdd336aa208146599ac06
size 1094770

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2147a246067f7ea74ca382fbc8c02a26332479e5205ecfbe08fb84161a3a87ec
size 1483888

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:279bd48b8ac53690bb4e37dffbe9060428db80c1417ff29c6f4d4a10ab35a7c9
oid sha256:f7cd70cc37451a7b7a43679dad30ef15d1cd0017762cb716ec412a4ebe0c3e1a
size 700094

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db5d186ce70d7a94cae2b6619b3449ca557903944beba1ee738d2ee425792d74
oid sha256:3d4f0a4e3d19dec07331ea48e38fc0f25beef3c0e29e4688dca5ba488c55ec54
size 652718

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:089a98cf8ab0bbd7530e69821c42220ea02578b740bff62a3e6e33de45209114
oid sha256:f2b44305b58da85faac69dd59921a8ff889174f690ae89b2dcddc7d704046a51
size 416335

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f0cc486ec5e9c1720f495a2a5e7c26d42e737694d307d4746a08b6ead5cc225
oid sha256:48eed98ece216ad1e339949020d5d1e99af3ac4893ec6f502ed8f669fa91f88a
size 1197394

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:398965e34c1a4c747b42d8836c04934daaa43903b7931586ed12120e17a61f76
oid sha256:7a980c264dbab18c9b528d28e2d5887818aab94d1e1097fe0f56a411f41ff3a7
size 1672548

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77cbd7d45164d24be73e021bc0a8745b4f021e4369a254e216ee00b36d3c7263
oid sha256:c15693202fa72a88bf2ee7a1fe742238909988e0d57744a67a513bd921506ac2
size 366593

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e26f3b8cc173301b3cf07ba1ca7893b6f140432410b0b298361ecff597604c2
oid sha256:826b74c39f5e59e600caa36d926b6ace29a7e46ba2d1a8cf2fe153f993f80dba
size 1095556

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32220d11bc3542e9edcc36d51b4866bf40044213114d7e237e003afc1fc7c464
oid sha256:232b414a2ae4a7db0eac36c90b2345c4a353cd34c798501b9407b926f2d356ec
size 1478358

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ee5ae75df4866d848e90616562345d3740b17b68c90f06329dc074dba5217a9
size 482709
oid sha256:bb1231035d1664f4e297b4f9791e2faba45a72cd32395a43a6746b7122477f2c
size 480341

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e1ecaa635067924b692b665241d86e1d8c1d60a19290de7adde1ff2ca7dbeb0
oid sha256:6471f6d9d5202376d80c8f7c4120a566957a5898b360ae547a77624fa7870251
size 956612

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3018c622303f89c6f22f037ec99eaeaeea9cfe8911e22463b48a22c13116805
oid sha256:a9ec6512514e4ed352ded776ae591b8cfbfc23b6414eeedc3861b5a47141eb4d
size 592357

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7a381f2855236f418a40124a5254401c95001d5e15c074a704e22cc7ed89aa2
oid sha256:d6c7c58e214c5dc789bfa5ab42846664723eccfe603a50961068c4c4db35d846
size 1818600

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9bb49ace4dedc4faa3de2b9c22e09db0f3990129ce7ab4afb6419c38a5d48a16
oid sha256:3dce2bfb8e79278b80f5e3f77dac6949b9f763c7dc5a80910fd3ef361aba5955
size 2427152

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9769d7cb9754718798be515c84c45ff48e43322573f3f12e31c2e42e99d8dbd4
oid sha256:a0bcf2b7464ea2873c9d8e74884df7977eda39e2df6acc203534290cdd7e0892
size 557613

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:134f4a73e0e6b02b717319ec49e3b3ea0a585cad385a1f300e6c5761f12de9d7
oid sha256:40b93e65748ccf03c381dab7844480e2b89b1c3c808a7c33e94cf2842c432256
size 671320

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7935b0f053a79a7e620c0efe274fa5b4c840fc9c6e439a381c4d380446e1cb68
oid sha256:d51fcaad4d1f2d094baf94f66c0cd1e4d24322da546e328436c7a60e0dc37823
size 1744388

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74ecbbaa19b2efe97a3b12c488f0e03c2102f16c460239df4bfc19976fc4365e
oid sha256:5277b7be251586814d4a2cd9e1de1619279e7668d0081faf153056267cd0f350
size 2266902

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:813265d25709bd2d39982efbaf092c9163b124bd990fccab505b3c22134522aa
size 595585
oid sha256:5cd5e5880a553637230aeb78eedc765afb4f8cd8abd05703579e2545452f80c9
size 593217

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd36195c01bf7c2a2013d5f31d2e74c2579c471385d7b45be7e35ea2f0652608
size 908162
oid sha256:bcd4344770e379f65fde20d25d20d3f7854aaf968b9a299b33658a1995ea5e32
size 905004

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31d4d6dca68c4632d1f435e9179582cfe2ad7a75ee0f7625ee67b0044c914f10
size 1371512
oid sha256:315cabc45bc7a8290c6a8a12d7b750154e94eee920811401f215763e8ce719eb
size 1366776

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6570d3ee7b651dec797e82b31eb21fd3261c6e2639fb7c9b157f251bf98bb3bf
size 1419662
oid sha256:f20388cef55675a265e790be3cded63d8999373d8eb19386d0c7bbea432381da
size 1417294

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88b972677c5436b90fe85870278e3b23d6f709608f99295bddf0be3861d95d1a
size 1419662
oid sha256:dde99f0026396063b68e94063c69e7fd799284da02152d308f68e2e728c46e8f
size 1417294

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d975f605d62c3070d6cf72f6114d98642c520e66989ed2d2845c3213e921ebf7
size 1965880
oid sha256:87f8d1d345231bf20d7b4553e92fa1f52b8ca1694c0da7535867b5105aa2c063
size 1961144

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c2d7ab0692de5405b26d19a0c57d720285366ac12a8550bbabca1613cce7f0c
size 305897
oid sha256:b1290d40043da35b674f832685cf9f4c0c0534002298b5187c14bf7d614ecd24
size 302741

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91a26adfddc0bcaf8b42249f59f1a0b9f74be0f82c7378fe4b56f3a2fa3d4bf1
size 290109
oid sha256:67695ec794f5746b1757e273f8eaceb974d57e91ed3063e3873ddc0d144d46f1
size 288531

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ef79c9e2e2d8bba55d7803dc8dc147b5d8babc29e906a43407a8722bbd8d939
size 498507
oid sha256:a8c253eafb26d52f79de54a9856be97b988e4adc8e6824ecb50b4999ff3f9607
size 496139

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0eef025f8e8581868b02bcea37ff225afebcbb2966450fb29fb0e32ac54eccd4
size 668214
oid sha256:6f45790b8f859c6ccdc1c848f6321b583bc06e3a3b93681e066edc588b990170
size 667426

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:abb2857ffb85cc36aae90ebb674635dffee2b2c5f7ad1ea81bb8002b65d5a0f8
size 711628
oid sha256:465a23bd4c7604cfd8d8a78b1f117e1d45172d3fe9e0d59804b3c82ed1283ebf
size 703734

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49a3661535314b139e2794fe16f6f3e0a8d45742b68ea59ba99a9113068adf2c
size 752698
oid sha256:cb808cb241cb58f5c98a2f3de87797799a44021020a5537c1fbb1a3c84f7f416
size 749540

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d76fb6c4f8bb2de687bc5f9f275389356934119c1f0db9983dcf0ec7b68c6197
size 748726
oid sha256:c17d374897fd92df92adcc717b2b17b2781ea0cfc8f3be63f160aff078ab5ca3
size 746358

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be8ee89f4489c430d0ff6e9c6cf4e07379ac05abf468d47e34e084ad594b2037
size 946060
oid sha256:0adf8ae7688e2613eef57d110d7185fd0267d9d97d93b57f3bb9f67dcacf2127
size 943692

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa4be8ca2dd52e56c9a6af76b90ac353d217fad5fa931b21129ac5a811b5283a
size 489823
oid sha256:8ee19a3d57b2795b547c5f5e0220313f3b8a59afaa56a29610c4a444f106ece3
size 487455

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb0482b768a40bc7f8a86fa23a84bab62fb82c205f3237ff60becda50cbafc90
size 489823
oid sha256:c8b1ebcea7fcf90c2a48ef118cb9c58294aded13d38ec682acefb414e107b99e
size 487455

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95b1796f4e7c905eca82ed3691427025f68e765797440b962b0114a5ab32b1d7
size 500083
oid sha256:63b1556854d992884134d26dbdfb717661ce85056f51387b9ada2ecb325bd578
size 497715

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f685b6b2a0a573953f31fad89fa37e949361db245de69c0c06ce0bbb14eacef
size 443285

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:834f0f3601c589893a21b957be2864df594f96b34b2cfd6018ada8319986aa21
size 441683

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d81a070e7ed49f1e1a322d38a757a3505186cf5cbded99814e950e07229a46a
size 298049

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9de5bc49d888699da1880d24ccf6a9cb6c0049d7a244d1ae9ab64b7365ecd5a
size 296445

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e30ed0df4b0d0b1da1ace5831dc0a7a526e04001b25860f862345c78acff5a43
size 427485

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:030015dc1811e3dc2ae36ed770f51063a3f46deae42ead5e1523c977b438a133
size 425883

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6921a204892e1336cef2a308be38855f3c888e56bd6a16752d2806aa9e93c431
size 1524634
oid sha256:b74e330f275a99c8ba94c5eaa600c24b5c8beb589bf95c242b81dab04a49db98
size 1523844

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:200df98fb2fcc734e8fc012c98c5d78c2061e5718eef6ffd50c2358a3d664197
size 406065

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:430194fe07e526ad01a1e0fb43273b240c269215b132c9af248ba386dcbda23e
size 1124766

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53a07904a7bfbf82380c96af99c5e24bc86f77906c5d6fdc85ef9720639d76d2
size 1569136

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1ce4d27b11fee3e5f6489510b55613177e174660b6c7a6fb4efed862b62c50d7
oid sha256:3956c73db35ea3988aa0bdf3798c388fe35448918c7a1ae5f2b7783b8cdf17f3
size 731668

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3992d7bd34e72089c5cffc4fc6de3f70a3995145b989811f83b00b47c96b5159
oid sha256:6d7f601b937eb5007c507655fd0a5e0e3788d230f35219c792d1f35580c29e97
size 681924

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:521417177fc0447809c07ff86b58725fedbf1a6b9412ace4c50268a20bc2680d
oid sha256:d2f062ee799ae89f394ed5092d9adb557510e51b9474a1535a3cd2548f32f923
size 447119

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb063c946558e6928faabb85df9775fecd2b9444b40b3e06cf0f863db80a5ad8
size 1242842
oid sha256:5350484be4826fdc7da6bb03d96421158afa7423bca7569234bd887564bee003
size 1240474

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31e6b7442b277f5206cc1d70fa6021f36170265b311106281e88b4611d1a5b6b
oid sha256:78698afcaaf4eb325f240ef4ff512798c321394793b444e636a316a2dad496bc
size 1220284

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c1342769efa91794d5bd35ac623b3014738b075b2671441668e2f0d5c1eef78a
oid sha256:f542e4eb88c6040c96d23d6e1ab50b9a2d6da5eab64c9fa20e792b09bf4ac951
size 1739642

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a49dd8abcca57a64eb2ab4e00e4e0d26edf68488fb67086a4b466f8e6651522e
oid sha256:b2d4aec095c9e9763484987e868acf2182cd35fd6ec8254acc90057ebcf028fa
size 410007

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7d4526887fe860e0d9c482fc7fe2cfe646c7a20bc8a0813ce33a01fd9cc733c
oid sha256:4048a1adcd670df5dea695a0b1a09e73629ecf0e430fab9a4529f4cc5695869b
size 1125550

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b880e78ffc354edb541bd612e543dd894843fc4163f7bd65ce53282892381b8a
oid sha256:85cf87e375d3b05f47f57b67f48c35bf516c60286a16ddac32fd1b914b74ba27
size 1566764

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b332d4c6047c98b504cd3be72cc5028d240621c8e0a3260d64c17804982104db
size 365029

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a16c23767a2e5efbd7330728ed87af2ec62a7731debe1da557705c6db6d3268e
size 1096360

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66950bc137b734d509f0574152bcf9cf7efcb17a7483450d5fdbf480e9f83001
size 1486266

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bba586d9fe487c49cef2abfbfb0a078dde907d28e04b4d2335018cdb7031879c
oid sha256:6a30185801336d52c40d06b41d631ed6651d1db563cd06edd7534deedb78e3f0
size 701682

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3e45ab30e471f4649807f5b7640512e2c6678cf623cadfcb26c93eb4ad60ec0
oid sha256:bf706aecad7cd6177ae318723b1c55f4f9108e960f50540e3538eaaf24218633
size 654306

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1932937b7f4ad0370341c77a03db133dd676bdf844b13eb45ec10243d1dfd16b
oid sha256:5ba643582110007f29bbb03fd2bc34243255b4bc0d24355448249ae7fe7374ba
size 417135

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c11f5d464b0486023b78babfdfe9d2768e4b0d13caeb436d6f73110ede72498c
oid sha256:e2468e449b0361230e724b5551bd1c6d899bdd748438e7d47a3007dc369ce383
size 1198982

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3bac9b40302bbfc6ee5a49e5c45d3238f46cff45619acd1b098d90e758d3ce30
oid sha256:014e24c9f00859db417ed48d9372fde79a191d559268d069a9c0dfe4b44e15ec
size 1675716

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26f09ab86b52c40b283652e555f677850f00902151d17e375e016b9a99a97794
oid sha256:ca2568bf3ac5fd23c74d739cb948a465d0f7d8cacd40e880b6f3c51f4f7ee30f
size 368183

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d0cf59a8114940070448d87d02d9e83d53bb371ca9915c3983e03626d17024e
oid sha256:b400cd55b4ac4832a4160d3f51fe42d21ebf0a840b99ff937ed11fcf0e2994e5
size 1097144

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff1449b6795f5beda0b6a62e8a1171ce952b07c4e63b607c06f5fedddb2debe9
oid sha256:3ae97a5d9070592673ffcb80f2e710f00edc6c89564086814dc53739fd6395c0
size 1480736

View File

@ -179,6 +179,31 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams)
// Thus, v_stride_in_bytes always equals k_stride_in_bytes so far.
mKernelParams.v_stride_in_bytes = mKernelParams.k_stride_in_bytes;
}
else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
{
// Separate QKV input layout, [total_kv_seqlen, H_KV, D] + [total_kv_seqlen, H_KV, DV]
TLLM_CHECK_WITH_INFO(runnerParams.kPtr != nullptr && runnerParams.vPtr != nullptr,
"SEPARATE_Q_K_V requires valid K and V pointers.");
mKernelParams.k_ptr = runnerParams.kPtr;
mKernelParams.v_ptr = runnerParams.vPtr;
// Tensor K is contiguous.
mKernelParams.k_stride_in_bytes
= get_size_in_bytes(mFixedParams.numKvHeads * mFixedParams.headSize, mFixedParams.dataType);
if (mFixedParams.headSizeQkNope > 0 && mFixedParams.dataType != DATA_TYPE_E4M3)
{
// Non-FP8 context MLA: tensor V is not contiguous. The token stride is numKvHeads * (headSizeQkNope +
// headSizeV).
mKernelParams.v_stride_in_bytes = get_size_in_bytes(
mFixedParams.numKvHeads * (mFixedParams.headSizeQkNope + mFixedParams.headSizeV),
mFixedParams.dataType);
}
else
{
// Tensor V is contiguous for other cases.
mKernelParams.v_stride_in_bytes
= get_size_in_bytes(mFixedParams.numKvHeads * mFixedParams.headSizeV, mFixedParams.dataType);
}
}
}
mKernelParams.o_ptr = runnerParams.outputPtr;
@ -464,13 +489,12 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
{
bool isHopperBF16ContextMLA = (mFixedParams.headSize == mFixedParams.headSizeV + 64) && isSm90
&& mFixedParams.dataType == DATA_TYPE_BF16 && mFixedParams.headSizeV == 128;
// TODO: add support for separate QKV input layout
mLaunchParams.supportReturnSoftmaxStats = (runnerParams.softmaxStatsPtr != nullptr
&& mLaunchParams.flash_attention && mLaunchParams.warp_specialization
&& ((!isHopperBF16ContextMLA
&& mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
|| (isHopperBF16ContextMLA
&& (mLaunchParams.attention_input_layout == AttentionInputLayout::Q_PAGED_KV))));
&& (mLaunchParams.attention_input_layout == AttentionInputLayout::SEPARATE_Q_K_V))));
}
}
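For reference, the softmax-stats condition as it reads after this change can be restated as a standalone predicate. This is a sketch for readability only, not the library code; it assumes the AttentionInputLayout enum declared in the header further below.

bool supportsReturnSoftmaxStats(bool hasSoftmaxStatsPtr, bool flashAttention, bool warpSpecialization,
    bool isHopperBF16ContextMLA, AttentionInputLayout layout)
{
    // The non-MLA path keeps the Q_CONTIGUOUS_KV requirement; Hopper BF16 context MLA
    // now requires the new SEPARATE_Q_K_V layout (previously Q_PAGED_KV).
    return hasSoftmaxStatsPtr && flashAttention && warpSpecialization
        && ((!isHopperBF16ContextMLA && layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
            || (isHopperBF16ContextMLA && layout == AttentionInputLayout::SEPARATE_Q_K_V));
}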
@ -623,6 +647,12 @@ void FusedMHARunnerV2::setTmaDescriptors(MHARunnerParams runnerParams)
k_ptr = reinterpret_cast<char const*>(mKernelParams.kv_ptr);
v_ptr = k_ptr + h_kv * d_in_bytes;
}
else if (layout == AttentionInputLayout::SEPARATE_Q_K_V)
{
// Layout: [total_kv_seqlen, H_KV, D] + [total_kv_seqlen, H_KV, DV]
k_ptr = reinterpret_cast<char const*>(mKernelParams.k_ptr);
v_ptr = reinterpret_cast<char const*>(mKernelParams.v_ptr);
}
Multiple_tma_descriptor<3> kv_tma_descriptor;
// K

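To make the SEPARATE_Q_K_V stride setup above concrete, here is a minimal self-contained sketch of the byte-stride arithmetic for the non-FP8 bf16 context MLA case (192x128 heads, as in the kernels listed earlier). The head count of 8 is an illustrative value only, not something this diff fixes.

#include <cstdio>

int main()
{
    int const numKvHeads = 8;       // illustrative only
    int const headSize = 192;       // non-RoPE (128) + RoPE (64)
    int const headSizeV = 128;
    int const headSizeQkNope = 128;
    int const bytesPerElem = 2;     // bf16

    // K is contiguous: one token advances by numKvHeads * headSize elements.
    int const kStrideInBytes = numKvHeads * headSize * bytesPerElem;
    // V is not contiguous for non-FP8 context MLA: the token stride is
    // numKvHeads * (headSizeQkNope + headSizeV) rather than numKvHeads * headSizeV.
    int const vStrideInBytes = numKvHeads * (headSizeQkNope + headSizeV) * bytesPerElem;

    std::printf("k_stride_in_bytes = %d, v_stride_in_bytes = %d\n", kStrideInBytes, vStrideInBytes);
    // Prints: k_stride_in_bytes = 3072, v_stride_in_bytes = 4096
    return 0;
}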
View File

@ -81,7 +81,10 @@ enum class AttentionInputLayout
// Q has contiguous [B, S, H, D] layout, while paged KV has [B, 2, Max_blocks_per_seq] layout
// that contains paged block indices. The indices indicate the block offset to the pool ptr in
// global memory
Q_PAGED_KV
Q_PAGED_KV,
// Q has contiguous [B, S, H, D] layout, while K has contiguous [B, S, H_kv, D] layout, and V has
// contiguous [B, S, H_kv, D_v] layout. Only used for context MLA now.
SEPARATE_Q_K_V,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -115,6 +118,8 @@ struct MHARunnerFixedParams
int headSize;
// The head size of V.
int headSizeV = 0;
// The head size of Q/K non-RoPE part, only used for MLA now.
int headSizeQkNope = 0;
// The scaling applied to bmm1_scale.
float qScaling;
// The attention logit softcapping scale.
@ -166,6 +171,7 @@ struct MHARunnerFixedParams
case AttentionInputLayout::PACKED_QKV: output += "packed_qkv"; break;
case AttentionInputLayout::Q_CONTIGUOUS_KV: output += "q_contiguous_kv"; break;
case AttentionInputLayout::Q_PAGED_KV: output += "q_paged_kv"; break;
case AttentionInputLayout::SEPARATE_Q_K_V: output += "separate_q_k_v"; break;
default: output += std::to_string(static_cast<int>(attentionInputLayout)) + " (unknown)"; break;
}
@ -255,6 +261,10 @@ struct MHARunnerParams
void const* qPtr;
// The contiguous Kv buffer ptr;
void const* kvPtr;
// The K buffer ptr (for separate K input).
void const* kPtr;
// The V buffer ptr (for separate V input).
void const* vPtr;
// The paged kv cache array.
KVBlockArray pagedKvCache;
// The output buffer ptr.

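As a usage note, a hypothetical caller-side sketch of the new fields follows. Only members visible in this diff are set; qDev/kDev/vDev/outDev and the head counts are placeholders, and a real call site needs many more fields (sequence lengths, scales, stream, and so on).

void setupSeparateQkvContextMla(void const* qDev, void const* kDev, void const* vDev, void* outDev)
{
    MHARunnerFixedParams fixedParams{};
    fixedParams.dataType = DATA_TYPE_BF16;
    fixedParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
    fixedParams.headSize = 192;       // non-RoPE (128) + RoPE (64) part of Q/K
    fixedParams.headSizeV = 128;
    fixedParams.headSizeQkNope = 128; // new field, context MLA only
    fixedParams.numQHeads = 8;        // illustrative head counts
    fixedParams.numKvHeads = 8;

    MHARunnerParams runnerParams{};
    runnerParams.qPtr = qDev;         // [total_q_seqlen,  H,    headSize]
    runnerParams.kPtr = kDev;         // [total_kv_seqlen, H_kv, headSize]
    runnerParams.vPtr = vDev;         // [total_kv_seqlen, H_kv, headSizeV]
    runnerParams.outputPtr = outDev;
}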
View File

@ -36,6 +36,10 @@ QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)
{
return QkvLayout::PagedKv;
}
else if (layout == AttentionInputLayout::SEPARATE_Q_K_V)
{
return QkvLayout::SeparateQkv;
}
TLLM_CHECK_WITH_INFO(false, "Unexpected AttentionInputLayout");
return QkvLayout::SeparateQkv;
}
@ -148,6 +152,10 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
numTokensPerBlock = pagedKvCache.mTokensPerBlock;
}
else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
{
qkvLayout = kernels::QkvLayout::SeparateQkv;
}
TllmGenFmhaRunnerParams tllmRunnerParams;
memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
@ -161,8 +169,8 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
tllmRunnerParams.mMultiCtasKvMode = false;
tllmRunnerParams.qPtr = runnerParams.qPtr;
tllmRunnerParams.kPtr = nullptr;
tllmRunnerParams.vPtr = nullptr;
tllmRunnerParams.kPtr = runnerParams.kPtr;
tllmRunnerParams.vPtr = runnerParams.vPtr;
tllmRunnerParams.kvPtr = kvPoolPtr;
tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
tllmRunnerParams.attentionSinksPtr = runnerParams.attentionSinksPtr;
@ -181,6 +189,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
// Assume same headDim for Qk and V here.
tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
tllmRunnerParams.mHeadDimQkNope = mFixedParams.headSizeQkNope;
tllmRunnerParams.mNumHeadsQ = mFixedParams.numQHeads;
tllmRunnerParams.mNumHeadsKv = mFixedParams.numKvHeads;
tllmRunnerParams.mNumHeadsQPerKv = tllmRunnerParams.mNumHeadsQ / tllmRunnerParams.mNumHeadsKv;
@ -202,6 +211,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
// For mla chunked prefill
tllmRunnerParams.softmaxStatsPtr = reinterpret_cast<float2*>(runnerParams.softmaxStatsPtr);
tllmRunnerParams.stream = runnerParams.stream;
mTllmGenFMHARunner->run(tllmRunnerParams);
}
else

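A tiny sanity sketch of the new layout mapping shown above; it assumes the headers defining AttentionInputLayoutToQkvLayout, AttentionInputLayout, and QkvLayout are on the include path, and it is not part of the commit.

#include <cassert>

void checkSeparateQkvMapping()
{
    // New in this change: the separate-Q/K/V layout maps to the trtllm-gen SeparateQkv layout,
    // and the dispatcher forwards runnerParams.kPtr / runnerParams.vPtr instead of nullptr.
    assert(AttentionInputLayoutToQkvLayout(AttentionInputLayout::SEPARATE_Q_K_V) == QkvLayout::SeparateQkv);
}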
View File

@ -71,26 +71,6 @@ struct loadChunkedKVKernelTraits
static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
};
template <typename T>
struct setChunkedKVKernelTraits
{
using VecT = uint4;
static constexpr int kQKNopeSize = 128;
static constexpr int kVHeadSize = 128;
static_assert(kQKNopeSize == kVHeadSize);
static constexpr int kRopeSize = 64;
static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
static constexpr int kBytesPerElem = sizeof(T);
static constexpr int kBytesPerLoad = 16;
static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
"kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
static constexpr int kCpTokenPerBlock = 16;
static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
};
template <typename SrcType, int NUM>
inline __device__ void quantCopy(
__nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
@ -311,76 +291,6 @@ __global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_
}
}
// in the most of cases, chunk_size = max_seq_len
// output_kv {B, 2, ceil(max_seq_len / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
// zero
// kv {token_size = B*chunked_unit_size, 2, H=128, uncompressed_h=128}, k_pe {token_size = B*chunked_unit_size, h=1,
// rope_h}
// cu_seq_lens {batch + 1}, fake cu_seq_len, for chunked prefill is {0, chunk_size, chunk_size * 2 ....}
template <typename T>
__global__ void setChunkedKVCacheForMLAKernel(T* output_kv, T const* kv, T const* k_pe, int const max_seq_len,
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
int kv_cache_tokens_per_block)
{
using KT = setChunkedKVKernelTraits<T>;
int const batch_idx = static_cast<int>(blockIdx.y);
int const head_idx = static_cast<int>(blockIdx.z);
int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead);
int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
bool const is_valid_kv = head_dim_idx < KT::kQKNopeSize;
int64_t const global_token_offset = cu_seq_lens[batch_idx];
int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx];
int const kv_cache_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
int const kv_cache_block_size = num_heads * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size);
int64_t const offset_for_kv_in_mem_pool = kv_cache_block_num * kv_cache_block_size;
int64_t const kv_offset = num_heads * uncompressed_head_size;
size_t const seq_len_loop_end = cache_kv_len;
for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock;
local_token_idx < seq_len_loop_end; local_token_idx += gridDim.x * KT::kCpTokenPerBlock)
{
if (local_token_idx >= cache_kv_len)
{
break;
}
if (is_valid_kv)
{
int64_t ld_kv_global_offset
= int64_t(global_token_offset + local_token_idx) * 2 * num_heads * uncompressed_head_size
+ head_idx * uncompressed_head_size;
int64_t ld_kv_local_offset = head_dim_vec_idx;
auto k_data = (reinterpret_cast<typename KT::VecT const*>(kv + ld_kv_global_offset))[ld_kv_local_offset];
auto v_data = (reinterpret_cast<typename KT::VecT const*>(
kv + kv_offset + ld_kv_global_offset))[ld_kv_local_offset];
int64_t st_k_global_offset = int64_t(batch_idx) * 2 * offset_for_kv_in_mem_pool
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
+ head_idx * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
+ (local_token_idx % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size);
int64_t st_v_global_offset = st_k_global_offset + offset_for_kv_in_mem_pool;
int64_t st_k_local_offset = head_dim_vec_idx;
int64_t st_v_local_offset = head_dim_vec_idx;
(reinterpret_cast<typename KT::VecT*>(output_kv + st_k_global_offset))[st_k_local_offset] = k_data;
(reinterpret_cast<typename KT::VecT*>(output_kv + st_v_global_offset))[st_v_local_offset] = v_data;
}
else
{
// RoPE part: k_pe has a single head (h = 1).
int64_t ld_rope_global_offset = int64_t(global_token_offset + local_token_idx) * rope_size;
int64_t ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead;
auto rope_data
= (reinterpret_cast<typename KT::VecT const*>(k_pe + ld_rope_global_offset))[ld_rope_local_offset];
int64_t st_rope_global_offset = int64_t(batch_idx) * 2 * offset_for_kv_in_mem_pool
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
+ head_idx * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
+ (local_token_idx % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size);
int64_t st_rope_local_offset = head_dim_vec_idx;
(reinterpret_cast<typename KT::VecT*>(output_kv + st_rope_global_offset))[st_rope_local_offset] = rope_data;
}
}
}
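// Illustrative sketch (hypothetical helper, not called by the kernels above): host-side form of the element-offset
// arithmetic used in setChunkedKVCacheForMLAKernel for the
// {B, 2, ceil(max_seq_len / tokens_per_block), h, tokens_per_block, d} pool layout. The matching V offset is this K
// offset shifted by one full per-sequence K span (k_span below).
inline int64_t exampleChunkedKVDstOffsetK(int batch_idx, int token_idx, int head_idx, int max_seq_len, int num_heads,
    int head_dim /* uncompressed_head_size + rope_size */, int tokens_per_block)
{
    int64_t const block_size = int64_t(num_heads) * tokens_per_block * head_dim;
    int64_t const blocks_per_seq = (max_seq_len + tokens_per_block - 1) / tokens_per_block;
    int64_t const k_span = blocks_per_seq * block_size; // elements between the K and V halves of one sequence
    return int64_t(batch_idx) * 2 * k_span + (token_idx / tokens_per_block) * block_size
        + int64_t(head_idx) * tokens_per_block * head_dim + (token_idx % tokens_per_block) * head_dim;
}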
} // namespace
namespace tensorrt_llm
@ -427,26 +337,6 @@ void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray c
kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx, kv_scale_quant_orig_ptr);
}
// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padded with
// zeros
// kv {total_token, 2, H, uncompressed_h=128}: index 0 holds K and index 1 holds V; k_pe {total_token, h=1, rope_h}
// The input kv and k_pe may contain cached or uncached tokens.
template <typename T>
void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len,
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
int const kv_cache_tokens_per_block, cudaStream_t stream)
{
using KT = setChunkedKVKernelTraits<T>;
TLLM_CHECK_WITH_INFO(
uncompressed_head_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % KT::kCpTokenPerBlock == 0,
"kv_cache_tokens_per_block should be multiple of %d", KT::kCpTokenPerBlock);
dim3 grid(tensorrt_llm::common::divUp(max_seq_len, KT::kCpTokenPerBlock), batch_size, num_heads);
setChunkedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>(output_kv, kv, k_pe, max_seq_len, num_heads,
uncompressed_head_size, rope_size, cu_seq_lens, kv_cache_tokens_per_block);
}
#define INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(T) \
template void invokeMergeAttnWithSoftmax<T>(T * merged_attn, float* merged_softmax_stats, T const* pre_attn, \
float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, \
@ -457,10 +347,7 @@ void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const b
int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
template void invokeMLALoadChunkedKV<T, __nv_fp8_e4m3>(T * output_kv_ptr, T * output_k_pe_ptr, \
KVBlockArray const& kv_cache, int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, \
int rope_size, int chunked_size, int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
template void invokeMLASetChunkedKV<T>(T * output_kv, T const* kv, T const* k_pe, int const batch_size, \
int const max_seq_len, int const num_heads, int uncompressed_head_size, int rope_size, \
int64_t const* cu_seq_lens, int const kv_cache_tokens_per_block, cudaStream_t stream);
int rope_size, int chunked_size, int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream);
INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(half);
INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(float);

View File

@ -37,14 +37,5 @@ template <typename T, typename TCache>
void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts,
int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx,
float const* kv_scale_quant_orig_ptr, cudaStream_t stream);
// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padded with
// zeros
// kv {total_token, 2, H, uncompressed_h=128}: index 0 holds K and index 1 holds V; k_pe {total_token, h=1, rope_h}
// The input kv and k_pe may contain cached or uncached tokens.
template <typename T>
void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len,
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
int const kv_cache_tokens_per_block, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -121,27 +121,6 @@ struct loadPagedKVKernelTraits
static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
};
template <typename T>
struct setPagedKVKernelTraits
{
static constexpr int kQKNopeSize = 128;
static constexpr int kVHeadSize = 128;
static_assert(kQKNopeSize == kVHeadSize);
static constexpr int kRopeSize = 64;
static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
using VecT = typename VecType<T>::Type;
static constexpr int kBytesPerElem = sizeof(T);
static constexpr int kBytesPerLoad = 16;
static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
"kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
static constexpr int kNumHeads = 128;
static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
static constexpr int kCpTokenPerBlock = 16;
static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
};
template <typename SrcType, int NUM>
inline __device__ void quantCopy(
__nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
@ -205,11 +184,10 @@ inline __device__ void dequantCopy(
}
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const* fuse_buf, KVCacheBuffer kv_cache,
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* k_ptr, T const* fuse_buf, KVCacheBuffer kv_cache,
float2 const* cos_sin_cache, size_t head_num, int head_size, int c_k, int* cu_q_seqlens,
int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type, float* bmm1_scale,
float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_kv, float const* dequant_scale_q,
float const* dequant_scale_kv, float host_bmm1_scale)
int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type,
float const* quant_scale_kv)
{
// Constants.
@ -232,32 +210,6 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const*
size_t const batch_idx = blockIdx.y;
size_t const head_idx = blockIdx.z;
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
{
// Calculate bmm scale for FP8 MLA
if (cache_type == KvCacheDataType::FP8)
{
float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
if (bmm1_scale)
{
// The scale prepared for log2 optimization.
constexpr float kLog2e = 1.4426950408889634074f;
// The scale after fmha bmm1.
float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
bmm1_scale[0] = bmm1_scale_val;
bmm1_scale[1] = bmm1_scale_val * kLog2e;
}
if (bmm2_scale)
{
// The scale after fmha bmm2.
bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
}
}
}
if (head_idx < head_num)
{
size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
@ -287,11 +239,10 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const*
VecT q, k;
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
auto const src_q_global_offset
= static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
auto const src_q_global_offset = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
+ (head_size + ROPE_DIM) * head_idx + head_size;
q = *reinterpret_cast<VecT const*>(&qkv_output[src_q_global_offset + head_dim_idx]);
q = *reinterpret_cast<VecT const*>(&q_ptr[src_q_global_offset + head_dim_idx]);
k = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
// Pack two elements into one for gptj rotary embedding.
@ -322,14 +273,12 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const*
else
reinterpret_cast<VecT*>(kDst)[inBlockIdx] = k;
}
auto const dst_q_idx
= static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
+ head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
auto const dst_k_idx
= static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
+ head_num * (head_size + ROPE_DIM) + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = q;
reinterpret_cast<VecT*>(qkv_output)[dst_k_idx / ELTS_PER_VEC] = k;
auto const dst_k_idx = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
+ head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
reinterpret_cast<VecT*>(q_ptr)[dst_q_idx / ELTS_PER_VEC] = q;
reinterpret_cast<VecT*>(k_ptr)[dst_k_idx / ELTS_PER_VEC] = k;
}
}
}
@ -712,79 +661,6 @@ __global__ void loadPagedKVCacheForMLAKernel(T* compressed_kv_ptr, T* k_pe_ptr,
}
}
// k {total_token, h, d}, v {total_token, h, d}, k_pe {total_token, h=1, d_rope}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}
template <typename T>
__global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr,
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
int const batch_idx = static_cast<int>(blockIdx.y);
int const head_idx = static_cast<int>(blockIdx.z);
int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead);
int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
bool const is_valid_v = head_dim_idx < KT::kVHeadSize;
size_t const seq_len_loop_end
= (max_input_seq_len + KT::kCpTokenPerBlock - 1) / KT::kCpTokenPerBlock * KT::kCpTokenPerBlock;
size_t const kv_cache_block_size = num_heads * kv_cache_tokens_per_block * (kv_dim + rope_dim);
size_t const kv_cache_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
int64_t const global_token_offset = cu_seq_lens[batch_idx];
int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx];
for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock;
local_token_idx < seq_len_loop_end; local_token_idx += KT::kCpTokenPerBlock * gridDim.x)
{
int token_idx_in_kv_cache = local_token_idx;
bool const valid_token = token_idx_in_kv_cache < cache_kv_len;
if (valid_token)
{
// copy k and v
if (is_valid_v)
{
int ld_kv_global_offset = (global_token_offset + local_token_idx) * kv_token_stride + head_idx * kv_dim;
int ld_kv_local_offset = head_dim_vec_idx;
auto k_data
= (reinterpret_cast<typename KT::VecT const*>(k_ptr + ld_kv_global_offset))[ld_kv_local_offset];
auto v_data
= (reinterpret_cast<typename KT::VecT const*>(v_ptr + ld_kv_global_offset))[ld_kv_local_offset];
// {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
int st_k_global_offset = batch_idx * 2 * kv_cache_block_num * kv_cache_block_size
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
+ head_idx * kv_cache_tokens_per_block * (kv_dim + rope_dim)
+ (local_token_idx % kv_cache_tokens_per_block) * (kv_dim + rope_dim);
// {b, 1, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
int st_v_global_offset = st_k_global_offset + kv_cache_block_num * kv_cache_block_size;
int st_k_local_offset = head_dim_vec_idx;
int st_v_local_offset = head_dim_vec_idx;
(reinterpret_cast<typename KT::VecT*>(output + st_k_global_offset))[st_k_local_offset] = k_data;
(reinterpret_cast<typename KT::VecT*>(output + st_v_global_offset))[st_v_local_offset] = v_data;
}
// Copy k_pe; k_pe has only one head.
else
{
int ld_rope_global_offset = (global_token_offset + local_token_idx) * rope_dim;
int ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead;
auto rope_data = (reinterpret_cast<typename KT::VecT const*>(
k_pe_ptr + ld_rope_global_offset))[ld_rope_local_offset];
// {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
int st_rope_global_offset = batch_idx * 2 * kv_cache_block_num * kv_cache_block_size
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
+ head_idx * kv_cache_tokens_per_block * (kv_dim + rope_dim)
+ (local_token_idx % kv_cache_tokens_per_block) * (kv_dim + rope_dim);
int st_rope_local_offset = head_dim_vec_idx;
(reinterpret_cast<typename KT::VecT*>(output + st_rope_global_offset))[st_rope_local_offset]
= rope_data;
}
}
else
{
break;
}
}
}
// q {total_uncached_tokens, h, d_nope + d_rope}
// latent_cache {total_uncached_tokens, d_k + d_rope}
template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
@ -941,58 +817,152 @@ __global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T*
}
}
template <typename T, int BLOCK_SIZE, int QK_NOPE_HEAD_DIM, int QK_ROPE_HEAD_DIM, int V_HEAD_DIM>
__global__ void quantizeCopyInputToFp8Kernel(T const* q_buf, __nv_fp8_e4m3* quant_q_buf, T const* k_buf,
__nv_fp8_e4m3* quant_k_buf, T const* v_buf, __nv_fp8_e4m3* quant_v_buf, int total_q_len, int total_kv_len,
float const* quant_scale_qkv_ptr, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
float const* dequant_scale_q, float const* dequant_scale_kv, float host_bmm1_scale)
{
// Constants.
using VecT = typename VecType<T>::Type;
constexpr auto BYTES_PER_ELT = sizeof(T);
constexpr auto BYTES_PER_LOAD = 16;
constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
constexpr auto QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
static_assert(
(QK_HEAD_DIM * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "QK head size needs to be multiple of 16 bytes.");
static_assert((V_HEAD_DIM * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "V head size needs to be multiple of 16 bytes.");
constexpr auto QK_VECS_PER_HEAD = QK_HEAD_DIM * BYTES_PER_ELT / BYTES_PER_LOAD;
constexpr auto V_VECS_PER_HEAD = V_HEAD_DIM * BYTES_PER_ELT / BYTES_PER_LOAD;
static_assert(BLOCK_SIZE % QK_VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
static_assert(BLOCK_SIZE % V_VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
constexpr auto QK_TOKENS_PER_BLOCK = BLOCK_SIZE / QK_VECS_PER_HEAD;
constexpr auto V_TOKENS_PER_BLOCK = BLOCK_SIZE / V_VECS_PER_HEAD;
size_t const head_idx = blockIdx.z;
size_t const head_num = gridDim.z;
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
{
// Calculate bmm scale for FP8 MLA
float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
if (bmm1_scale)
{
// The scale prepared for log2 optimization.
constexpr float kLog2e = 1.4426950408889634074f;
// The scale after fmha bmm1.
float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
bmm1_scale[0] = bmm1_scale_val;
bmm1_scale[1] = bmm1_scale_val * kLog2e;
}
if (bmm2_scale)
{
// The scale after fmha bmm2.
bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
}
}
size_t const qk_head_dim_vec_idx = (threadIdx.x % QK_VECS_PER_HEAD);
size_t const v_head_dim_vec_idx = (threadIdx.x % V_VECS_PER_HEAD);
size_t const qk_head_dim_idx = qk_head_dim_vec_idx * ELTS_PER_VEC;
size_t const v_head_dim_idx = v_head_dim_vec_idx * ELTS_PER_VEC;
size_t const q_len_loop_end
= size_t((total_q_len + QK_TOKENS_PER_BLOCK - 1) / QK_TOKENS_PER_BLOCK) * QK_TOKENS_PER_BLOCK;
size_t const k_len_loop_end
= size_t((total_kv_len + QK_TOKENS_PER_BLOCK - 1) / QK_TOKENS_PER_BLOCK) * QK_TOKENS_PER_BLOCK;
size_t const v_len_loop_end
= size_t((total_kv_len + V_TOKENS_PER_BLOCK - 1) / V_TOKENS_PER_BLOCK) * V_TOKENS_PER_BLOCK;
float quant_scale_qkv_val = quant_scale_qkv_ptr ? quant_scale_qkv_ptr[0] : 1.f;
// Quantize Q, both src and dst are contiguous
for (int q_token_idx = (threadIdx.x / QK_VECS_PER_HEAD) + blockIdx.x * QK_TOKENS_PER_BLOCK;
q_token_idx < q_len_loop_end; q_token_idx += QK_TOKENS_PER_BLOCK * gridDim.x)
{
if (q_token_idx < total_q_len)
{
auto const src_q_idx
= static_cast<size_t>(q_token_idx) * QK_HEAD_DIM * head_num + head_idx * QK_HEAD_DIM + qk_head_dim_idx;
auto const dst_q_idx = src_q_idx;
quantCopy<T, ELTS_PER_VEC>(quant_q_buf + dst_q_idx, &q_buf[src_q_idx], quant_scale_qkv_val);
}
}
// Quantize K, both src and dst are contiguous
for (int k_token_idx = (threadIdx.x / QK_VECS_PER_HEAD) + blockIdx.x * QK_TOKENS_PER_BLOCK;
k_token_idx < k_len_loop_end; k_token_idx += QK_TOKENS_PER_BLOCK * gridDim.x)
{
if (k_token_idx < total_kv_len)
{
auto const src_k_idx
= static_cast<size_t>(k_token_idx) * QK_HEAD_DIM * head_num + head_idx * QK_HEAD_DIM + qk_head_dim_idx;
auto const dst_k_idx = src_k_idx;
quantCopy<T, ELTS_PER_VEC>(quant_k_buf + dst_k_idx, &k_buf[src_k_idx], quant_scale_qkv_val);
}
}
// Quantize V. The destination V is contiguous, but the source V is not, so compute its row stride explicitly.
size_t const src_v_token_stride = (QK_NOPE_HEAD_DIM + V_HEAD_DIM) * head_num;
for (int v_token_idx = (threadIdx.x / V_VECS_PER_HEAD) + blockIdx.x * V_TOKENS_PER_BLOCK;
v_token_idx < v_len_loop_end; v_token_idx += V_TOKENS_PER_BLOCK * gridDim.x)
{
if (v_token_idx < total_kv_len)
{
auto const src_v_idx
= static_cast<size_t>(v_token_idx) * src_v_token_stride + head_idx * V_HEAD_DIM + v_head_dim_idx;
auto const dst_v_idx
= static_cast<size_t>(v_token_idx) * V_HEAD_DIM * head_num + head_idx * V_HEAD_DIM + v_head_dim_idx;
quantCopy<T, ELTS_PER_VEC>(quant_v_buf + dst_v_idx, &v_buf[src_v_idx], quant_scale_qkv_val);
}
}
}
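// Illustrative buffer sizing implied by the indexing above (element counts, not bytes):
//   quant_q_buf: total_q_len  * head_num * (QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM)
//   quant_k_buf: total_kv_len * head_num * (QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM)
//   quant_v_buf: total_kv_len * head_num * V_HEAD_DIM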
template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
{
dim3 grid(int(tensorrt_llm::common::divUp(params.max_input_seq_len, 32)), params.batch_size, params.head_num + 8);
auto head_size = params.meta.qk_nope_head_dim;
applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer><<<grid, 256, 0, stream>>>(
params.attention_input_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num,
head_size, params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len,
params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_kv,
params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr
&& params.cache_type == KvCacheDataType::FP8)
applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer><<<grid, 256, 0, stream>>>(params.q_buf,
params.k_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num, head_size,
params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len,
params.cache_type, params.quant_scale_kv);
}
template <typename T>
void invokeMLAContextFp8Quantize(MlaParams<T>& params, int total_kv_len, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(params.cache_type == KvCacheDataType::FP8, "MLA Context: cache_type must be FP8");
TLLM_CHECK_WITH_INFO(params.q_buf != nullptr, "MLA Context: q_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.k_buf != nullptr, "MLA Context: k_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.v_buf != nullptr, "MLA Context: v_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.quant_q_buf != nullptr, "MLA Context: quant_q_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.quant_k_buf != nullptr, "MLA Context: quant_k_buf must be non-null");
TLLM_CHECK_WITH_INFO(params.quant_v_buf != nullptr, "MLA Context: quant_v_buf must be non-null");
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing separate qkv to FP8");
if (params.acc_q_len > 0)
{
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8");
constexpr int threads_per_block = 384;
dim3 grid(int(tensorrt_llm::common::divUp(total_kv_len, 48)), 1, params.head_num);
int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
int const dim_v_per_head = (params.meta.v_head_dim);
TLLM_LOG_DEBUG(
"Launching quantizeCopyInputToFp8Kernel with grid_size: (%d, %d, %d), threads_per_block: %d, "
"total_kv_len: %d, acc_q_len: %d",
grid.x, grid.y, grid.z, threads_per_block, total_kv_len, params.acc_q_len);
// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = params.head_num * dim_q_per_head;
int const total_k_dim_all_heads
= params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
int const num_total_qkv_elements
= params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim;
float const* device_qkv_scale_ptr = params.quant_scale_qkv;
if (num_total_qkv_elements > 0)
{
int const threads_per_block = 256;
int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block;
TLLM_LOG_DEBUG(
"Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d",
num_blocks, threads_per_block, num_total_qkv_elements);
tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
static_cast<T const*>(params.attention_input_buf), // Source
static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination
num_total_qkv_elements, device_qkv_scale_ptr);
sync_check_cuda_error(stream);
cudaStreamSynchronize(stream);
}
else
{
TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization.");
}
quantizeCopyInputToFp8Kernel<T, threads_per_block, 128, 64, 128>
<<<grid, threads_per_block, 0, stream>>>(params.q_buf, static_cast<__nv_fp8_e4m3*>(params.quant_q_buf),
params.k_buf, static_cast<__nv_fp8_e4m3*>(params.quant_k_buf), params.v_buf,
static_cast<__nv_fp8_e4m3*>(params.quant_v_buf), params.acc_q_len, total_kv_len, params.quant_scale_qkv,
params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.dequant_scale_q,
params.dequant_scale_kv, params.host_bmm1_scale);
}
else
{
TLLM_LOG_WARNING("MLA RoPE Context: acc_q_len is 0, skipping quantization.");
}
}
@ -1017,12 +987,12 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, params.attention_input_buf, params.q_pe, params.latent_cache,
params.quant_attention_input_buf, kv_cache_buffer, params.cos_sin_cache, params.head_num,
params.meta.kv_lora_rank, params.acc_q_len, seq_len, params.seqQOffset, params.fmha_tile_counter,
params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld, params.q_pe_stride, params.cache_type,
params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_q, params.quant_scale_kv,
params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
cudaLaunchKernelEx(&config, kernel_instance, params.q_buf, params.q_pe, params.latent_cache, params.quant_q_buf,
kv_cache_buffer, params.cos_sin_cache, params.head_num, params.meta.kv_lora_rank, params.acc_q_len, seq_len,
params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
params.host_bmm1_scale);
}
template <typename T, typename TCache>
@ -1040,20 +1010,6 @@ void invokeMLALoadPagedKV(T* compressed_kv_ptr, T* k_pe_ptr, KVBlockArray& kv_ca
compressed_kv_ptr, k_pe_ptr, kv_cache, cu_ctx_cached_kv_lens, max_input_seq_len, kv_scale_quant_orig_ptr);
}
template <typename T>
void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, int const num_requests,
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream)
{
using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
TLLM_CHECK_WITH_INFO(kv_dim + rope_dim == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % KT::kCpTokenPerBlock == 0,
"kv_cache_tokens_per_block should be multiple of %d", KT::kCpTokenPerBlock);
dim3 grid(tensorrt_llm::common::divUp(max_input_seq_len, KT::kCpTokenPerBlock), num_requests, num_heads);
setPagedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>(output, k_ptr, v_ptr, k_pe_ptr, cu_seq_lens,
max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride);
}
template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
@ -1076,11 +1032,15 @@ INSTANTIATE_MLA_ROPE(float, KVBlockArray);
INSTANTIATE_MLA_ROPE(half, KVBlockArray);
INSTANTIATE_MLA_ROPE(float, KVLinearBuffer);
INSTANTIATE_MLA_ROPE(half, KVLinearBuffer);
#ifdef ENABLE_BF16
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVBlockArray);
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer);
#endif
#define INSTANTIATE_MLA_QUANTIZE(T) \
template void invokeMLAContextFp8Quantize<T>(MlaParams<T> & params, int total_kv_len, cudaStream_t stream);
INSTANTIATE_MLA_QUANTIZE(float);
INSTANTIATE_MLA_QUANTIZE(half);
INSTANTIATE_MLA_QUANTIZE(__nv_bfloat16);
#define INSTANTIATE_RW_KVCACHE_MLA(T, TCache) \
template void invokeMLALoadPagedKV<T, TCache>(T * compressed_kv_ptr, T * k_pe_ptr, KVBlockArray & kv_cache, \
@ -1099,26 +1059,6 @@ INSTANTIATE_RW_KVCACHE_MLA(half, __nv_fp8_e4m3);
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_bfloat16);
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_fp8_e4m3);
#define INSTANTIATE_SET_KVCACHE_MLA(T) \
template void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, \
int const num_requests, int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, \
int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);
INSTANTIATE_SET_KVCACHE_MLA(float);
INSTANTIATE_SET_KVCACHE_MLA(half);
INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);
template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
{
uint element_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (element_idx < num_total_elements)
{
float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor);
}
}
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -51,9 +51,25 @@ struct MlaMetaParams
template <typename T>
struct MlaParams
{
T const* latent_cache; // cKV + k_pe
T* attention_input_buf; // [b, s, 3, h, d_h + r]
void* quant_attention_input_buf;
T const* latent_cache; // cKV + k_pe
// Tensor Q for both context and generation MLA, contiguous. The pre-processing kernel applies RoPE and modifies it
// in place. For context MLA, shape: [total_q_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
T* q_buf;
// Separate tensor K for context MLA, contiguous. The pre-processing kernel applies RoPE and modifies it in place.
// shape: [total_kv_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
T* k_buf = nullptr;
// Separate tensor V for context MLA, NOT contiguous.
// shape: [total_kv_len, h * d_v], stride: [h * (d_nope + d_v), 1]
T const* v_buf = nullptr;
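// (v_buf typically aliases the value slice of a packed [total_kv_len, h, d_nope + d_v] tensor, which is why its
// row stride is h * (d_nope + d_v) rather than h * d_v.)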
// Tensor quantized Q for both context and generation MLA.
// For context MLA, shape: [total_q_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
void* quant_q_buf = nullptr;
// Tensor quantized K for context MLA, contiguous
// shape: [total_kv_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
void* quant_k_buf = nullptr;
// Tensor quantized V for context MLA, contiguous
// shape: [total_kv_len, h * d_v], stride: [h * d_v, 1]
void* quant_v_buf = nullptr;
T* context_buf;
T* q_pe; // [b, h, d_r], strided
@ -83,17 +99,16 @@ struct MlaParams
float const* dequant_scale_kv;
float host_bmm1_scale;
// for kv cache reuse/chunked context
void* context_paged_kv_ptr = nullptr;
void* context_kv_cache_block_offsets_ptr = nullptr;
int32_t context_paged_kv_max_blocks_per_seq = 0;
// for FP8 context qkv quantization
// For FP8 context qkv quantization
float const* quant_scale_qkv = nullptr;
};
template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);
template <typename T>
void invokeMLAContextFp8Quantize(MlaParams<T>& params, int total_kv_len, cudaStream_t stream);
template <typename T, typename KVCacheBuffer>
void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);
@ -102,20 +117,10 @@ void invokeMLALoadPagedKV(T* compressed_kv_ptr, T* k_pe_ptr, KVBlockArray& kv_ca
int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, int const lora_size, int const rope_size,
float const* kv_scale_quant_orig_ptr, cudaStream_t stream);
template <typename T>
void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, int const num_requests,
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);
template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
float const* kv_scale_orig_quant_ptr, cudaStream_t stream);
template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -238,6 +238,8 @@ struct TllmGenFmhaRunnerParams
int mHeadDimQk;
// Head dimension for V.
int mHeadDimV;
// Head dimension for Q/K non-RoPE part, only used for MLA now.
int mHeadDimQkNope;
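// For example, a typical MLA configuration uses mHeadDimQkNope = 128 with mHeadDimQk = 192 and mHeadDimV = 128.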
// Number of heads for Q and K/V.
int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
// The batch size.

View File

@ -329,7 +329,7 @@ struct KernelParams
// Compute the strides for K and V.
template <class FmhaOptions>
static auto makeStrideKv(FmhaOptions const& options, bool isK)
static auto makeStrideKv(FmhaOptions const& options, Data_type dtypeKv, bool isK)
{
// The maximum headDim of K and V.
@ -357,6 +357,12 @@ struct KernelParams
{
strideKeysVals = maxHeadDimKv;
}
else if (isSeparateQkv(options.mQkvLayout) && !isK && options.mHeadDimQkNope > 0 && dtypeKv != DATA_TYPE_E4M3)
{
// Non-FP8 context MLA: tensor V is not contiguous. The token stride is mNumHeadsKv * (mHeadDimQkNope +
// mHeadDimV).
strideKeysVals = options.mNumHeadsKv * (options.mHeadDimQkNope + options.mHeadDimV);
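// For example, with 128 KV heads, mHeadDimQkNope = 128 and mHeadDimV = 128, consecutive V rows are
// 128 * (128 + 128) = 32768 elements apart.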
}
// The stride between heads.
int32_t strideHeads{isK ? options.mHeadDimQk : options.mHeadDimV};
@ -397,7 +403,7 @@ struct KernelParams
// The shape elements.
auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params);
// The stride elements.
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, dtypeKv, isK);
// The headDim.
// Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
@ -430,12 +436,13 @@ struct KernelParams
// Create the TMA shape/stride for KV scaling factors.
template <class FmhaOptions>
static auto makeTmaShapeStrideKvSf(FmhaOptions const& options, KernelParams const& params, bool isK)
static auto makeTmaShapeStrideKvSf(
FmhaOptions const& options, KernelParams const& params, Data_type dtypeKv, bool isK)
{
// The shape elements.
auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params);
// The stride elements.
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, dtypeKv, isK);
// The headDim.
// Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
@ -685,7 +692,8 @@ struct KernelParams
int32_t NumEltsPerSf = 16;
// Compute the shape and stride for SF tensor.
// FIXME: assume K and V uses the same shape.
auto [shapeKvSf, strideKvSf] = makeTmaShapeStrideKvSf(options, params, /*isK*/ true);
auto [shapeKvSf, strideKvSf]
= makeTmaShapeStrideKvSf(options, params, kernelMeta.mDataTypeKv, /*isK*/ true);
// The tileShapes for K/V.
std::vector<uint32_t> tileShapeKvSf(shapeKvSf.size(), 1);

View File

@ -32,8 +32,8 @@ void initBindings(nb::module_& m)
// Parameters with default values using std::nullopt for optional arguments
nb::arg("q"), nb::arg("k") = std::nullopt, nb::arg("v") = std::nullopt, nb::arg("output"),
nb::arg("output_sf") = std::nullopt, nb::arg("out_dtype") = std::nullopt, nb::arg("workspace_") = std::nullopt,
nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("context_lengths"),
nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"),
nb::arg("context_lengths"), nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
@ -52,7 +52,6 @@ void initBindings(nb::module_& m)
nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt,
nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt,
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("mla_context_paged_kv") = std::nullopt, nb::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt,
nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}

View File

@ -1046,6 +1046,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = batch_size;
enqueue_params.mrope_rotary_cos_sin = mrope_rotary_cos_sin;
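// The plugin path keeps the fused QKV input, so the total KV length here is simply the number of input tokens.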
enqueue_params.total_kv_len = enqueue_params.num_tokens;
if (isCrossAttention())
{

View File

@ -32,8 +32,8 @@ void initBindings(pybind11::module_& m)
// Parameters with default values using std::nullopt for optional arguments
py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
py::arg("output_sf") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("workspace_") = std::nullopt,
py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("context_lengths"),
py::arg("host_context_lengths"), py::arg("host_request_types"),
py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"),
py::arg("context_lengths"), py::arg("host_context_lengths"), py::arg("host_request_types"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
@ -52,7 +52,6 @@ void initBindings(pybind11::module_& m)
py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("mla_context_paged_kv") = std::nullopt, py::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt,
py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}

View File

@ -64,10 +64,12 @@ public:
virtual int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
int const num_gen_tokens) const
= 0;
// Typically a single fused QKV input is used, but context MLA passes separate Q/K/V inputs.
virtual void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv,
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths,
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv_or_q,
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
@ -77,8 +79,6 @@ public:
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
torch::optional<torch::Tensor> mla_context_paged_kv,
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const
@ -121,8 +121,9 @@ public:
void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv,
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths,
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv_or_q,
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
@ -132,14 +133,14 @@ public:
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
torch::optional<torch::Tensor> mla_context_paged_kv,
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr());
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
T* k_ptr = nullptr;
T* v_ptr = nullptr;
AttentionOutT* context_buf = static_cast<AttentionOutT*>(output.slice(0, token_offset).data_ptr());
TORCH_CHECK(!op.mFuseFp4Quant || output_sf.has_value());
void* context_buf_sf = op.mFuseFp4Quant ? output_sf->data_ptr() : nullptr;
@ -164,23 +165,33 @@ public:
[[maybe_unused]] MlaParams<T> mla_params;
if (op.isMLAEnabled())
{
if (is_context && op.mPagedContextFMHA && op.mPagedKVCache)
if (is_context)
{
TORCH_CHECK(mla_context_paged_kv.has_value());
TORCH_CHECK(mla_context_kv_cache_block_offsets.has_value());
if (latent_cache.has_value())
{
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
}
else
{
// For KV cache reuse / chunked context, latent_cache is not used.
mla_params.latent_cache = nullptr;
}
TORCH_CHECK(k.has_value());
TORCH_CHECK(v.has_value());
TORCH_CHECK(k->dim() == 2);
TORCH_CHECK(v->dim() == 2);
TORCH_CHECK(k->strides()[1] == 1);
TORCH_CHECK(v->strides()[1] == 1);
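// Only the innermost stride must be 1; v's row stride may exceed h * d_v (see MlaParams::v_buf).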
mla_params.context_paged_kv_ptr = mla_context_paged_kv->data_ptr();
mla_params.context_kv_cache_block_offsets_ptr = mla_context_kv_cache_block_offsets->data_ptr();
mla_params.context_paged_kv_max_blocks_per_seq = mla_context_kv_cache_block_offsets->size(-1);
k_ptr = static_cast<T*>(k->slice(0, token_offset).data_ptr());
v_ptr = static_cast<T*>(v->slice(0, token_offset).data_ptr());
mla_params.k_buf = k_ptr;
mla_params.v_buf = v_ptr;
}
else
{
// Assume latent_cache has already been written to the paged KV cache by the PyTorch backend.
TORCH_CHECK(latent_cache.has_value());
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
}
if (!is_context)
{
TORCH_CHECK(q_pe.has_value());
TORCH_CHECK(q_pe->dim() == 3);
TORCH_CHECK(q_pe->strides()[2] == 1);
@ -189,7 +200,7 @@ public:
mla_params.q_pe_ld = q_pe->strides()[1];
mla_params.q_pe_stride = q_pe->strides()[0];
}
mla_params.attention_input_buf = attention_input;
mla_params.q_buf = attention_input;
mla_params.context_buf = reinterpret_cast<T*>(context_buf);
mla_params.cos_sin_cache = rotary_cos_sin_ptr;
@ -300,6 +311,7 @@ public:
common_enqueue_params.host_primary_pool_pointer = host_primary_pool_pointer;
common_enqueue_params.host_secondary_pool_pointer = host_secondary_pool_pointer;
common_enqueue_params.num_tokens = num_tokens;
common_enqueue_params.total_kv_len = total_kv_len;
common_enqueue_params.max_blocks_per_sequence = max_blocks_per_sequence;
common_enqueue_params.sequence_lengths = sequence_lengths_ptr;
common_enqueue_params.context_lengths = context_lengths_ptr;
@ -316,6 +328,8 @@ public:
{
enqueue_params.softmaxStatsPtr = static_cast<float2*>(softmax_stats_tensor.value().data_ptr());
}
enqueue_params.k_ptr = k_ptr;
enqueue_params.v_ptr = v_ptr;
if (op.isMLAEnabled())
{
@ -424,26 +438,26 @@ using torch_ext::trtllm::attention::AttentionInputType;
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
{
@ -452,9 +466,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();
TLLM_CHECK_WITH_INFO(is_fused_qkv, "Only fused QKV is supported now");
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
auto qkv = q;
auto qkv_or_q = q;
if (is_fused_qkv)
{
TLLM_CHECK_WITH_INFO(!k.has_value(), "The k tensor should be null if using fused QKV");
@ -466,7 +480,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V");
}
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv.scalar_type());
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv_or_q.scalar_type());
bool const is_fp8_out = out_dtype.has_value() && out_dtype.value() == torch::kFloat8_e4m3fn;
bool const is_fp4_out = out_dtype.has_value() && out_dtype.value() == torch::kUInt8;
@ -624,9 +638,11 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
++num_contexts;
}
int32_t const num_generations = num_seqs - num_contexts;
int32_t const num_tokens = qkv.size(0);
int32_t const num_tokens = qkv_or_q.size(0);
int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens;
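// host_total_kv_lens holds two entries: the total KV length over context requests (index 0) and over generation
// requests (index 1).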
auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item<int32_t>();
auto const gen_total_kv_len = host_total_kv_lens.index({1}).item<int32_t>();
for (int32_t idx = num_contexts; idx < num_seqs; idx++)
{
@ -661,7 +677,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
}
else
{
workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv.device()));
workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv_or_q.device()));
}
if ((num_contexts > 0) && (attn_input_type != AttentionInputType::GenerationOnly))
@ -671,12 +687,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
runner->run(*op,
/*is_context=*/true, seq_offset,
/*num_seqs=*/num_contexts, token_offset,
/*num_tokens=*/num_ctx_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv, sequence_length,
host_past_key_value_lengths, context_lengths, host_context_lengths, kv_cache_block_offsets,
host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
/*num_tokens=*/num_ctx_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
sequence_length, host_past_key_value_lengths, ctx_total_kv_len, context_lengths, host_context_lengths,
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
}
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@ -687,12 +703,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
runner->run(*op,
/*is_context=*/false, seq_offset,
/*num_seqs=*/num_generations, token_offset,
/*num_tokens=*/num_gen_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv, sequence_length,
host_past_key_value_lengths, context_lengths, host_context_lengths, kv_cache_block_offsets,
host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
/*num_tokens=*/num_gen_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
sequence_length, host_past_key_value_lengths, gen_total_kv_len, context_lengths, host_context_lengths,
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
}
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);

View File

@ -37,26 +37,26 @@ namespace torch_ext
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);


@ -62,39 +62,6 @@ void loadChunkedKVCacheForMLAHelper(torch::Tensor& output_kv, torch::Tensor& out
kv_scale_quant_orig_ptr, stream);
}
template <typename T>
void setPagedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v,
torch::Tensor const& k_pe, int const num_requests, torch::Tensor const& cu_seq_lens, int const max_input_seq_len,
int num_heads, int kv_dim, int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
auto stream = at::cuda::getCurrentCUDAStream(output.get_device());
auto* output_ptr = static_cast<T*>(output.data_ptr());
auto const* k_ptr = static_cast<T const*>(k.data_ptr());
auto const* v_ptr = static_cast<T const*>(v.data_ptr());
auto const* k_pe_ptr = static_cast<T const*>(k_pe.data_ptr());
auto const* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
// cudaMemset is faster than torch::zeros
TLLM_CUDA_CHECK(cudaMemsetAsync(output_ptr, 0, output.numel() * torch::elementSize(output.scalar_type()), stream));
tensorrt_llm::kernels::invokeMLASetPagedKV<T>(output_ptr, k_ptr, v_ptr, k_pe_ptr, num_requests, cu_seq_lens_ptr,
max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride, stream);
}
template <typename T>
void setChunkedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe,
int const num_requests, torch::Tensor const& cu_seq_lens, int num_heads, int kv_dim, int rope_dim,
int kv_cache_tokens_per_block, int max_seq_len)
{
auto stream = at::cuda::getCurrentCUDAStream(output.get_device());
T* output_ptr = static_cast<T*>(output.data_ptr());
T* kv_ptr = static_cast<T*>(kv.data_ptr());
T* k_pe_ptr = static_cast<T*>(k_pe.data_ptr());
auto* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
tensorrt_llm::kernels::invokeMLASetChunkedKV<T>(output_ptr, kv_ptr, k_pe_ptr, num_requests, max_seq_len, num_heads,
kv_dim, rope_dim, cu_seq_lens_ptr, kv_cache_tokens_per_block, stream);
}
template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQHelper(KVBlockArray& kv_cache, torch::Tensor& q, torch::Tensor& latent_cache,
int const num_requests, torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
@ -371,111 +338,6 @@ std::vector<torch::Tensor> loadChunkedKVCacheForMLA(torch::ScalarType out_dtype,
return outputs;
}
torch::Tensor setPagedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v,
torch::Tensor const& k_pe, int64_t const num_requests, torch::Tensor const& cu_seq_lens,
int64_t const max_input_seq_len, int64_t const num_heads, int64_t const kv_dim, int64_t const rope_dim,
int64_t const kv_cache_tokens_per_block)
{
TORCH_CHECK(output.numel() > 0);
auto output_dtype = output.scalar_type();
TORCH_CHECK(output_dtype == torch::kFloat16 || output_dtype == torch::kFloat32 || output_dtype == torch::kBFloat16);
CHECK_TH_CUDA(output);
CHECK_CONTIGUOUS(output);
// k and v can be non-contiguous
CHECK_TH_CUDA(k);
CHECK_TYPE(k, output_dtype);
CHECK_TH_CUDA(v);
CHECK_TYPE(v, output_dtype);
TORCH_CHECK(k.dim() == 3);
TORCH_CHECK(v.dim() == 3);
TORCH_CHECK(k.size(0) == v.size(0));
TORCH_CHECK(k.size(1) == v.size(1));
TORCH_CHECK(k.size(2) == v.size(2));
TORCH_CHECK(k.stride(1) == k.size(2));
TORCH_CHECK(v.stride(1) == v.size(2));
TORCH_CHECK(k.stride(2) == 1);
TORCH_CHECK(v.stride(2) == 1);
// k and v should have the same token stride
int64_t k_token_stride = k.stride(0);
int64_t v_token_stride = v.stride(0);
TORCH_CHECK(k_token_stride == v_token_stride);
// k_pe should be contiguous
CHECK_INPUT(k_pe, output_dtype);
CHECK_INPUT(cu_seq_lens, torch::kInt64);
TORCH_CHECK(cu_seq_lens.dim() == 1);
TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1);
if (output_dtype == torch::kFloat16)
{
setPagedKVCacheForMLAHelper<half>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len, num_heads,
kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
}
else if (output_dtype == torch::kFloat32)
{
setPagedKVCacheForMLAHelper<float>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len, num_heads,
kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
}
else if (output_dtype == torch::kBFloat16)
{
setPagedKVCacheForMLAHelper<__nv_bfloat16>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len,
num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
}
int64_t max_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
torch::Tensor faked_kv_cache_block_offsets = torch::arange(
0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device()));
faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num});
return faked_kv_cache_block_offsets;
}
torch::Tensor setChunkedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe,
int64_t const num_requests, torch::Tensor const& cu_seq_lens, int64_t const num_heads, int64_t const kv_dim,
int64_t const rope_dim, int64_t const kv_cache_tokens_per_block, int64_t const max_seq_len)
{
TORCH_CHECK(output.numel() > 0);
TORCH_CHECK(output.scalar_type() == torch::kFloat16 || output.scalar_type() == torch::kFloat32
|| output.scalar_type() == torch::kBFloat16);
CHECK_TH_CUDA(output);
CHECK_CONTIGUOUS(output);
CHECK_INPUT(kv, output.scalar_type());
CHECK_INPUT(k_pe, output.scalar_type());
CHECK_INPUT(cu_seq_lens, torch::kInt64);
TORCH_CHECK(cu_seq_lens.dim() == 1);
TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1);
if (output.scalar_type() == torch::kFloat16)
{
setChunkedKVCacheForMLAHelper<half>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim,
kv_cache_tokens_per_block, max_seq_len);
}
else if (output.scalar_type() == torch::kFloat32)
{
setChunkedKVCacheForMLAHelper<float>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim,
kv_cache_tokens_per_block, max_seq_len);
}
else if (output.scalar_type() == torch::kBFloat16)
{
setChunkedKVCacheForMLAHelper<__nv_bfloat16>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim,
rope_dim, kv_cache_tokens_per_block, max_seq_len);
}
int64_t max_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
// TODO: actually this offset is always the same for all requests and all layers.
torch::Tensor faked_kv_cache_block_offsets = torch::arange(
0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device()));
faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num});
return faked_kv_cache_block_offsets;
}
void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache, int64_t const num_contexts,
torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num,
@ -664,51 +526,6 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("load_chunked_kv_cache_for_mla", &torch_ext::loadChunkedKVCacheForMLA);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"set_paged_kv_cache_for_mla("
"Tensor output"
", Tensor k"
", Tensor v"
", Tensor k_pe"
", int num_requests"
", Tensor cu_seq_lens"
", int max_input_seq_len"
", int num_heads"
", int kv_dim"
", int rope_dim"
", int kv_cache_tokens_per_block"
") -> Tensor");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("set_paged_kv_cache_for_mla", &torch_ext::setPagedKVCacheForMLA);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"set_chunked_kv_cache_for_mla("
"Tensor output"
", Tensor kv"
", Tensor k_pe"
", int num_requests"
", Tensor cu_seq_lens"
", int num_heads"
", int kv_dim"
", int rope_dim"
", int kv_cache_tokens_per_block"
", int max_seq_len"
") -> Tensor");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("set_chunked_kv_cache_for_mla", &torch_ext::setChunkedKVCacheForMLA);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(


@ -63,59 +63,6 @@ void loadChunkedKVKernelRef(T* kv_output, T* k_pe_output, tensorrt_llm::kernels:
}
}
// kv {total_tokens, 2, h, nope_size}
// k_pe {total_tokens, h=1, rope_size}
// output {b, 2, ceil(max_seq / cache_tokens_per_block), h, cache_tokens_per_block, (nope_size + rope_size)}
// max_seq <= chunk_size
template <typename T>
void setChunkedKVCacheForMLAKernelRef(T* output, T* kv_ptr, T* k_pe_ptr, int num_contexts, int64_t const* cu_seq_len,
int const max_input_seq_len, int num_heads, int nope_size, int rope_size, int cache_tokens_per_block)
{
int head_size = nope_size + rope_size;
int const kv_cache_size_per_block = num_heads * cache_tokens_per_block * head_size;
int const kv_cache_block_num_per_seq = (max_input_seq_len + cache_tokens_per_block - 1) / cache_tokens_per_block;
for (int b = 0; b < num_contexts; b++)
{
int const global_token_offset = cu_seq_len[b];
int const current_seq_len = cu_seq_len[b + 1] - cu_seq_len[b];
for (int s = 0; s < current_seq_len; s++)
{
int const global_token_idx = global_token_offset + s;
int const kv_cache_block_offset_for_k
= (b * 2 * kv_cache_block_num_per_seq + s / cache_tokens_per_block) * kv_cache_size_per_block;
int const kv_cache_block_offset_for_v
= kv_cache_block_offset_for_k + (kv_cache_block_num_per_seq * kv_cache_size_per_block);
for (int h = 0; h < num_heads; h++)
{
int const ld_k_head_offset = (global_token_idx * 2 * num_heads * nope_size) + h * nope_size;
int const ld_v_head_offset = ld_k_head_offset + num_heads * nope_size;
int const ld_k_pe_head_offset = global_token_idx * rope_size;
// copy kv
for (int d = 0; d < nope_size; d++)
{
int const ld_k_idx = ld_k_head_offset + d;
int const ld_v_idx = ld_v_head_offset + d;
int const st_k_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size
+ (s % cache_tokens_per_block) * head_size + d;
int const st_v_idx = kv_cache_block_offset_for_v + h * cache_tokens_per_block * head_size
+ (s % cache_tokens_per_block) * head_size + d;
output[st_k_idx] = kv_ptr[ld_k_idx];
output[st_v_idx] = kv_ptr[ld_v_idx];
}
// copy k_pe
for (int d = 0; d < rope_size; d++)
{
int const ld_k_pe_idx = ld_k_pe_head_offset + d;
int const st_k_pe_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size
+ (s % cache_tokens_per_block) * head_size + (nope_size + d);
output[st_k_pe_idx] = k_pe_ptr[ld_k_pe_idx];
}
}
}
}
}
// Q {total_q, H, D}
// KV {total_kv, 2, H, D}
// softmax_sum {total_q, H, 2} // {max/sum}
@ -322,9 +269,6 @@ protected:
d_k_pe_output{nullptr}, h_compressed_kv_output_ref{nullptr}, h_k_pe_output_ref{nullptr},
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
// for kernel 2
h_kv_tensor{nullptr}, d_kv_tensor{nullptr}, h_k_pe_tensor{nullptr}, d_k_pe_tensor{nullptr},
// for merge attn {kv_full_tensor = kv + k_pe}
m_h_q_tensor{nullptr}, m_h_kv_full_tensor{nullptr}, m_h_chunked_kv_tensor{nullptr}, m_h_output_tensor{nullptr},
m_h_softmax_sum_tensor{nullptr}, m_h_softmax_sum_accum_tensor{nullptr}, m_h_output_tensor_ref{nullptr},
@ -621,27 +565,6 @@ protected:
cudaMemcpyHostToDevice);
}
// kv, k_pe for invokeMLASetChunkedKV (kernel 2)
this->h_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype);
this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype);
this->d_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_tensor->getShape(), dtype);
this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_tensor->getShape(), dtype);
{
auto* kv_ptr = bufferCast<DataType>(*(this->h_kv_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
fillArrayDataWithMod<DataType>(kv_ptr, h_kv_tensor->getSize());
fillArrayDataWithMod<DataType>(k_pe_ptr, h_k_pe_tensor->getSize());
cudaMemcpyAsync(d_kv_tensor->data(), h_kv_tensor->data(), h_kv_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(d_k_pe_tensor->data(), h_k_pe_tensor->data(), h_k_pe_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice, mStream->get());
cudaStreamSynchronize(mStream->get());
}
// invokeMergeAttnWithSoftmax, we just ignore rope_size here for simplicity
this->m_h_q_tensor = tensorrt_llm::runtime::BufferManager::pinned(
@ -915,39 +838,6 @@ protected:
cudaMemcpyDeviceToHost);
sync_check_cuda_error(this->mStream->get());
}
void PerformSetChunkedKVRef()
{
using tensorrt_llm::runtime::bufferCast;
auto* kv_ptr = bufferCast<DataType>(*(this->h_kv_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
auto* cu_chunked_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_chunk_lens));
this->PrepareChunkedLen(0);
setChunkedKVCacheForMLAKernelRef(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, cu_chunked_seq_lens_ptr,
this->mChunkSize, this->mNumHeads, this->mNopeSize, this->mRopeSize, this->mTokensPerBlock);
}
void PerformSetChunkedKV()
{
using tensorrt_llm::runtime::bufferCast;
auto* kv_ptr = bufferCast<DataType>(*(this->d_kv_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor));
auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
auto* cu_chunked_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_chunk_lens));
this->PrepareChunkedLen(0);
// copy cu chunk lens to device
cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(),
this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
tensorrt_llm::kernels::invokeMLASetChunkedKV(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, this->mChunkSize,
this->mNumHeads, this->mNopeSize, this->mRopeSize, cu_chunked_seq_lens_ptr, this->mTokensPerBlock,
mStream->get());
cudaStreamSynchronize(this->mStream->get());
// copy result back to host
cudaMemcpy(this->h_kv_cache_tensor->data(), kv_cache_ptr, this->h_kv_cache_tensor->getSizeInBytes(),
cudaMemcpyDeviceToHost);
sync_check_cuda_error(this->mStream->get());
}
};
using MLATypes
@ -1086,41 +976,3 @@ TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedLoad)
}
ASSERT_TRUE(allEqual);
}
TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedSet)
{
using tensorrt_llm::runtime::bufferCast;
using DataType = typename TestFixture::DataType;
using TCache = typename TestFixture::TCache;
if constexpr (std::is_same_v<DataType, TCache>)
{
this->setDefaultParams();
this->allocateBuffers();
sync_check_cuda_error(this->mStream->get());
bool allEqual{true};
this->PerformSetChunkedKVRef();
sync_check_cuda_error(this->mStream->get());
this->PerformSetChunkedKV();
sync_check_cuda_error(this->mStream->get());
// check result
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
auto* kv_cache_ptr_ref = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
for (int i = 0; i < this->h_kv_cache_tensor->getSize(); i++)
{
if (std::abs(static_cast<float>(kv_cache_ptr[i]) - static_cast<float>(kv_cache_ptr_ref[i]))
> getTolerance<DataType>(kv_cache_ptr[i]))
{
std::cout << "KV cache mismatch at index " << i << ": "
<< "expected " << static_cast<float>(kv_cache_ptr_ref[i]) << ", got "
<< static_cast<float>(kv_cache_ptr[i]) << std::endl;
allEqual = false;
break;
}
}
ASSERT_TRUE(allEqual);
}
}


@ -79,60 +79,6 @@ void loadPagedKvKernelRef(T* compressed_kv_output, T* k_pe_output,
}
}
// k {total_token, h, uncompressed_h=128}, v {total_token, h, uncompressed_h}, k_pe {total_token, h=1, rope_h}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, (uncompressed_h + rope_h)}
// copy k, v, k_pe to a continuous memory space (then it will be packed to kv_cache)
template <typename T>
void setPagedKvCacheForMLAKernelRef(T* output, T* const k_ptr, T* const v_ptr, T* const k_pe_ptr, int num_requests,
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int uncompressed_head_size, int rope_size,
int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
int const kv_cache_size_per_block = num_heads * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size);
int const kv_cache_block_num_per_seq
= (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
for (int b = 0; b < num_requests; b++)
{
int const global_token_offset = cu_seq_lens[b];
int const current_token_len = cu_seq_lens[b + 1] - cu_seq_lens[b];
for (int s = 0; s < current_token_len; s++)
{
int const global_token_idx = global_token_offset + s;
int const kv_cache_block_offset_for_k
= ((b * 2 * kv_cache_block_num_per_seq) + (s / kv_cache_tokens_per_block)) * kv_cache_size_per_block;
int const kv_cache_block_offset_for_v
= kv_cache_block_offset_for_k + (kv_cache_block_num_per_seq * kv_cache_size_per_block);
for (int h = 0; h < num_heads; h++)
{
// copy k, v
int const ld_kv_head_offset = (global_token_idx * kv_token_stride) + (h * uncompressed_head_size);
int const ld_k_pe_head_offset = (global_token_idx * rope_size);
for (int d = 0; d < uncompressed_head_size; d++)
{
int const ld_kv_idx = ld_kv_head_offset + d;
int const st_k_idx = kv_cache_block_offset_for_k
+ h * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
+ (s % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size) + d;
int const st_v_idx = kv_cache_block_offset_for_v
+ h * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
+ (s % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size) + d;
output[st_k_idx] = k_ptr[ld_kv_idx];
output[st_v_idx] = v_ptr[ld_kv_idx];
}
// copy k_pe, head_num = 1
for (int d = 0; d < rope_size; d++)
{
int const ld_k_pe_idx = ld_k_pe_head_offset + d;
int const st_k_pe_idx = kv_cache_block_offset_for_k
+ h * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
+ (s % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size) + d
+ uncompressed_head_size;
output[st_k_pe_idx] = k_pe_ptr[ld_k_pe_idx];
}
}
}
}
}
inline bool almostEqual(float a, float b, float atol = 1e-2, float rtol = 1e-3)
{
if (isnan(a) || isnan(b))
@ -168,10 +114,7 @@ protected:
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
// for kernel 1
d_compressed_kv_output{nullptr}, h_compressed_kv_output{nullptr}, h_compressed_kv_output_ref{nullptr},
d_k_pe_output{nullptr}, h_k_pe_output{nullptr}, h_k_pe_output_ref{nullptr},
// for kernel 2
d_k_tensor{nullptr}, d_v_tensor{nullptr}, d_k_pe_tensor{nullptr}, h_k_tensor{nullptr}, h_v_tensor{nullptr},
h_k_pe_tensor{nullptr};
d_k_pe_output{nullptr}, h_k_pe_output{nullptr}, h_k_pe_output_ref{nullptr};
int mNumRequests{};
int mMaxSeqLen{};
@ -463,33 +406,6 @@ protected:
cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(),
cudaMemcpyHostToDevice);
}
// k, v, k_pe for setPagedKvCacheForMLAKernel (kernel 2)
this->h_k_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->h_v_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
this->d_k_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->d_v_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
{
auto* k_ptr = bufferCast<DataType>(*(this->h_k_tensor));
auto* v_ptr = bufferCast<DataType>(*(this->h_v_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
fillArrayDataWithMod(k_ptr, this->h_k_tensor->getSize());
fillArrayDataWithMod(v_ptr, this->h_v_tensor->getSize());
fillArrayDataWithMod(k_pe_ptr, this->h_k_pe_tensor->getSize());
cudaMemcpy(this->d_k_tensor->data(), this->h_k_tensor->data(), this->h_k_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(this->d_v_tensor->data(), this->h_v_tensor->data(), this->h_v_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_pe_tensor->data(), this->h_k_pe_tensor->data(), this->h_k_pe_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
}
return true;
}
@ -539,35 +455,6 @@ protected:
cu_ctx_cached_kv_lens_ptr, this->mLoraSize, this->mRopeSize, kv_scale_quant_orig_ptr);
}
void PerformSetPagedKV()
{
using tensorrt_llm::runtime::bufferCast;
auto* k_ptr = bufferCast<DataType>(*(this->d_k_tensor));
auto* v_ptr = bufferCast<DataType>(*(this->d_v_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor));
auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
tensorrt_llm::kernels::invokeMLASetPagedKV<DataType>(kv_cache_ptr, k_ptr, v_ptr, k_pe_ptr, this->mNumRequests,
cu_seq_lens_ptr, this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize,
this->mRopeSize, this->mTokensPerBlock, this->mKvTokenStride, this->mStream->get());
cudaStreamSynchronize(this->mStream->get());
cudaMemcpy(this->h_kv_cache_tensor->data(), this->d_kv_cache_tensor->data(),
this->d_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost);
}
void PerformSetPagedKVRef()
{
using tensorrt_llm::runtime::bufferCast;
auto* k_ptr = bufferCast<DataType>(*(this->h_k_tensor));
auto* v_ptr = bufferCast<DataType>(*(this->h_v_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_seq_lens));
setPagedKvCacheForMLAKernelRef(kv_cache_ptr, k_ptr, v_ptr, k_pe_ptr, this->mNumRequests, cu_seq_lens_ptr,
this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize, this->mRopeSize,
this->mTokensPerBlock, this->mKvTokenStride);
}
template <typename T>
bool CheckEqual(T const* expected, T const* output, size_t size)
{
@ -617,14 +504,4 @@ TYPED_TEST(MlaPreprocessTest, MLAPreprocessDefault)
allEqual = this->CheckEqual(k_pe_output_ref_ptr, k_pe_output_ptr, this->h_k_pe_output->getSize());
EXPECT_TRUE(allEqual);
}
{
this->PerformSetPagedKV();
sync_check_cuda_error(this->mStream->get());
this->PerformSetPagedKVRef();
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
allEqual = this->CheckEqual(kv_cache_ref_ptr, kv_cache_ptr, this->h_kv_cache_tensor->getSize());
EXPECT_TRUE(allEqual);
}
}


@ -783,7 +783,7 @@ echo "All processes completed!"
The converted checkpoint could be used as `<YOUR_MODEL_DIR>` and consumed by other commands.
### KV Cache Reuse
KV cache reuse is supported for MLA on SM90 and SM100. It is enabled by default. Due to extra operations such as memcpy and GEMMs, GPU memory consumption may be higher and end-to-end performance may regress in some cases. Users can pass `KvCacheConfig(enable_block_reuse=False)` to the LLM API to disable it.
KV cache reuse is supported for MLA on SM90, SM100, and SM120. It is enabled by default. Due to extra operations such as memcpy and GEMMs, GPU memory consumption may be higher and end-to-end performance may regress in some cases. Users can pass `KvCacheConfig(enable_block_reuse=False)` to the LLM API to disable it.
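A minimal sketch of disabling block reuse through the LLM API, assuming the `tensorrt_llm` package exposes `LLM` and `KvCacheConfig` as named below:
```python
# Sketch: disable KV cache block reuse for MLA via the LLM API.
# Assumes tensorrt_llm provides LLM and llmapi.KvCacheConfig under these names.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(enable_block_reuse=False)
llm = LLM(model="<YOUR_MODEL_DIR>", kv_cache_config=kv_cache_config)
```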
### Chunked Prefill
Chunked Prefill is currently supported for MLA only on SM90 and SM100. Add `--enable_chunked_prefill` to enable it. GPU memory consumption is highly correlated with `max_num_tokens` and `max_batch_size`; if you encounter out-of-memory errors, reduce these values. (`max_num_tokens` must be divisible by the KV cache's `tokens_per_block`.)
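A minimal sketch of enabling chunked prefill through the LLM API, under the assumption that `enable_chunked_prefill`, `max_num_tokens`, and `max_batch_size` are accepted as keyword arguments:
```python
# Sketch: enable chunked prefill for MLA and bound GPU memory usage.
# The keyword arguments below are assumptions about the LLM API surface.
from tensorrt_llm import LLM

llm = LLM(
    model="<YOUR_MODEL_DIR>",
    enable_chunked_prefill=True,
    # Reduce these if you hit out-of-memory errors; max_num_tokens must remain
    # divisible by the KV cache's tokens_per_block.
    max_num_tokens=8192,
    max_batch_size=32,
)
```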


@ -52,7 +52,7 @@ class AttentionMetadata:
mapping: Optional[Mapping] = None
enable_flash_mla: bool = False
enable_paged_context_mla: bool = False
enable_context_mla_with_cached_kv: bool = False
# Whether CUDA graph is enabled.
is_cuda_graph: bool = field(default=False, repr=False)


@ -24,6 +24,7 @@ from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
class TrtllmAttentionWrapper:
sequence_length: torch.Tensor
host_past_key_value_lengths: torch.Tensor
host_total_kv_lens: torch.Tensor
context_lengths: torch.Tensor
host_context_lengths: torch.Tensor
host_request_types: torch.Tensor
@ -67,6 +68,7 @@ class TrtllmAttentionWrapper:
qk_nope_head_dim: Optional[int]
v_head_dim: Optional[int]
attention_chunk_size: Optional[int]
softmax_stats_tensor: Optional[torch.Tensor]
use_spec_decoding: bool
is_spec_dec_tree: bool
spec_decoding_position_offsets: Optional[torch.Tensor]
@ -154,6 +156,7 @@ class TrtllmAttentionWrapper:
beam_width: int = 1,
sequence_length: torch.Tensor = ...,
host_past_key_value_lengths: torch.Tensor = ...,
host_total_kv_lens: torch.Tensor = ...,
context_lengths: torch.Tensor = ...,
host_context_lengths: torch.Tensor = ...,
host_request_types: torch.Tensor = ...,
@ -174,8 +177,6 @@ class TrtllmAttentionWrapper:
latent_cache: Optional[torch.Tensor] = None,
q_pe: Optional[torch.Tensor] = None,
mrope_config: Optional[dict] = None,
mla_context_paged_kv: Optional[torch.Tensor] = None,
mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
softmax_stats_tensor: Optional[torch.Tensor] = None,
is_spec_decoding_enabled: bool = False,
use_spec_decoding: bool = False,
@ -201,6 +202,7 @@ class TrtllmAttentionWrapper:
beam_width (int): Beam width in beam search.
sequence_length (torch.Tensor): The length of each sequence with shape (batch_size) on GPU.
host_past_key_value_lengths (torch.Tensor): Same as sequence_length, but on CPU.
host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
context_lengths (torch.Tensor): The context-phase sequence length of each request with shape (batch_size) on GPU.
host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU.
host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU.
@ -216,8 +218,6 @@ class TrtllmAttentionWrapper:
out_scale_sf (torch.Tensor): The tensor to store the global scale for NVFP4 scaling factors, with shape (1) on GPU.
use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner.
mrope_config (dict): The dictionary containing the mRope configuration.
mla_context_paged_kv (torch.Tensor): The paged KV cache for MLA context, for kv cache reuse/chunked context.
mla_context_kv_cache_block_offsets (torch.Tensor): The block offsets for the paged KV cache for MLA context, for kv cache reuse/chunked context.
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum)
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
"""
@ -230,6 +230,7 @@ class TrtllmAttentionWrapper:
self.beam_width = beam_width
self.sequence_length = sequence_length
self.host_past_key_value_lengths = host_past_key_value_lengths
self.host_total_kv_lens = host_total_kv_lens
self.context_lengths = context_lengths
self.host_context_lengths = host_context_lengths
self.host_request_types = host_request_types
@ -253,8 +254,6 @@ class TrtllmAttentionWrapper:
self.mrope_position_deltas = mrope_config.get(
'mrope_position_deltas') if mrope_config is not None else None
self.block_ids_per_seq = block_ids_per_seq
self.mla_context_paged_kv = mla_context_paged_kv
self.mla_context_kv_cache_block_offsets = mla_context_kv_cache_block_offsets
self.softmax_stats_tensor = softmax_stats_tensor
self.attention_sinks = attention_sinks
@ -361,18 +360,12 @@ class TrtllmAttentionWrapper:
else:
raise ValueError("Unexpected attention mask type")
else:
assert is_fused_qkv
if self.attention_input_type == AttentionInputType.context_only:
if self.use_paged_context_fmha:
assert self.mla_context_paged_kv is not None
assert self.mla_context_kv_cache_block_offsets is not None
qkv_hidden_size = self.num_heads * (self.qk_nope_head_dim +
self.qk_rope_head_dim)
else:
qkv_hidden_size = self.num_heads * (
2 * (self.qk_nope_head_dim + self.qk_rope_head_dim)
) + self.num_kv_heads * self.v_head_dim
assert not is_fused_qkv
qkv_hidden_size = self.num_heads * (self.qk_nope_head_dim +
self.qk_rope_head_dim)
elif self.attention_input_type == AttentionInputType.generation_only:
assert is_fused_qkv
qkv_hidden_size = self.num_heads * (self.kv_lora_rank +
self.qk_rope_head_dim)
else:
@ -430,6 +423,7 @@ class TrtllmAttentionWrapper:
self.workspace,
self.sequence_length,
self.host_past_key_value_lengths,
self.host_total_kv_lens,
self.context_lengths,
self.host_context_lengths,
self.host_request_types,
@ -479,8 +473,6 @@ class TrtllmAttentionWrapper:
self.v_head_dim,
self.mrope_rotary_cos_sin,
self.mrope_position_deltas,
self.mla_context_paged_kv,
self.mla_context_kv_cache_block_offsets,
self.attention_chunk_size,
self.softmax_stats_tensor,
spec_decoding_bool_params,
@ -623,6 +615,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
device='cpu',
pin_memory=True)
self.host_total_kv_lens = torch.empty(2, device='cpu', dtype=torch.int)
self.host_request_types = torch.empty_like(self.prompt_lens_cpu)
# For debugging, can use it to call the wrapper's plan function
@ -665,7 +658,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
dtype=torch.int32,
device='cuda',
)
if self.enable_paged_context_mla:
if self.enable_context_mla_with_cached_kv:
# for kv cache reuse/chunked context in MLA
self.ctx_cached_token_indptr = torch.zeros(
(self.max_num_requests + 1, ),
@ -746,12 +739,16 @@ class TrtllmAttentionMetadata(AttentionMetadata):
kv_lens + self.kv_cache_params.num_extra_kv_tokens)
self.kv_lens_cuda[:self.num_seqs].copy_(
kv_lens[:self.num_seqs].pin_memory(), non_blocking=True)
# total kv lens for context requests and generation requests, without extra tokens
self.host_total_kv_lens[0] = kv_lens[:self.num_contexts].sum().item()
self.host_total_kv_lens[1] = kv_lens[self.num_contexts:self.
num_seqs].sum().item()
self.host_request_types[:self.num_contexts].fill_(0)
self.host_request_types[self.num_contexts:self.num_seqs].fill_(1)
# prepare for kv cache reuse/chunked context in MLA
if self.enable_paged_context_mla:
self.prepare_paged_context_mla(cached_token_lens, kv_lens)
if self.enable_context_mla_with_cached_kv:
self.prepare_context_mla_with_cached_kv(cached_token_lens, kv_lens)
# kv block offsets
assert self.request_ids is not None
@ -838,8 +835,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
else:
merge_op_tensor[chunked_loop_num, s] = 1 # merge
def prepare_paged_context_mla(self, cached_token_lens: torch.Tensor,
kv_lens: torch.Tensor) -> None:
def prepare_context_mla_with_cached_kv(self,
cached_token_lens: torch.Tensor,
kv_lens: torch.Tensor) -> None:
if self.num_contexts > 0:
self.num_ctx_cached_tokens = cached_token_lens[:self.
num_contexts].sum(
@ -1101,8 +1099,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
q_pe: Optional[torch.Tensor] = None,
mrope_config: Optional[dict] = None,
attention_window_size: Optional[int] = None,
mla_context_paged_kv: Optional[torch.Tensor] = None,
mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
softmax_stats_tensor: Optional[torch.Tensor] = None,
enable_attn_nvfp4_output: bool = True,
output: Optional[torch.Tensor] = None,
@ -1123,9 +1119,8 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
) if metadata.runtime_features else False
if self.is_mla_enable:
# for MLA, we only use paged_context_fmha when there is cached kv
use_paged_context_fmha = use_paged_context_fmha and self.has_cached_kv_for_mla_context(
metadata)
# Context MLA uses separate qkv instead of paged_context_fmha
use_paged_context_fmha = False
use_nvfp4_output = False
if enable_attn_nvfp4_output and self.has_nvfp4 and self.support_nvfp4_output(
@ -1149,6 +1144,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
beam_width=metadata.beam_width,
sequence_length=metadata.kv_lens_cuda_runtime,
host_past_key_value_lengths=metadata.kv_lens_runtime,
host_total_kv_lens=metadata.host_total_kv_lens,
context_lengths=metadata.prompt_lens_cuda_runtime,
host_context_lengths=metadata.prompt_lens_cpu_runtime,
host_request_types=metadata.host_request_types_runtime,
@ -1169,9 +1165,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
latent_cache=latent_cache,
q_pe=q_pe,
mrope_config=mrope_config,
mla_context_paged_kv=mla_context_paged_kv,
mla_context_kv_cache_block_offsets=
mla_context_kv_cache_block_offsets,
softmax_stats_tensor=softmax_stats_tensor,
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
use_spec_decoding=metadata.use_spec_decoding,
@ -1232,7 +1225,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata: TrtllmAttentionMetadata,
) -> bool:
return (self.is_mla_enable and metadata.kv_cache_manager is not None
and metadata.enable_paged_context_mla
and metadata.enable_context_mla_with_cached_kv
and metadata.num_ctx_cached_tokens > 0)
def is_chunked_prefill_for_mla_context(
@ -1240,7 +1233,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata: TrtllmAttentionMetadata,
) -> bool:
return (self.is_mla_enable and metadata.kv_cache_manager is not None
and metadata.enable_paged_context_mla
and metadata.enable_context_mla_with_cached_kv
and metadata.num_ctx_cached_tokens > 0
and metadata.runtime_features.chunked_prefill)
@ -1248,7 +1241,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self,
metadata: TrtllmAttentionMetadata,
out_dtype: torch.dtype,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert out_dtype in [torch.float16, torch.bfloat16, torch.float32]
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None
@ -1290,15 +1283,19 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
num_ctx_cached_tokens: int,
cu_chunked_seq_len: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert out_dtype in [torch.float16, torch.bfloat16, torch.float32]
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None
if metadata.max_ctx_cached_token_len == 0:
return torch.empty((0, metadata.kv_cache_manager.head_dim),
dtype=out_dtype,
device=cu_chunked_seq_len.device)
empty_kv = torch.empty((0, self.mla_params.kv_lora_rank),
dtype=out_dtype,
device=cu_chunked_seq_len.device)
empty_k_pe = torch.empty((0, self.mla_params.qk_rope_head_dim),
dtype=out_dtype,
device=cu_chunked_seq_len.device)
return empty_kv, empty_k_pe
sink_token_length = 0
beam_width = 1
@ -1326,86 +1323,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
)
return output_kv, output_k_pe
def set_paged_kv_cache_for_mla(
self,
paged_kv: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k_pe: torch.Tensor,
metadata: TrtllmAttentionMetadata,
) -> torch.Tensor:
assert self.is_mla_enable and self.mla_params is not None
assert self.mla_params.qk_nope_head_dim == self.mla_params.v_head_dim
assert metadata.kv_cache_manager is not None
assert paged_kv.shape[0] == metadata.num_contexts
assert paged_kv.is_contiguous()
num_contexts = metadata.num_contexts
max_seq_len = metadata.max_ctx_kv_len
tokens_per_block = metadata.kv_cache_manager.tokens_per_block
paged_kv_offsets = torch.ops.trtllm.set_paged_kv_cache_for_mla(
paged_kv,
k,
v,
k_pe,
num_contexts,
metadata.ctx_kv_indptr,
max_seq_len,
self.num_heads,
self.mla_params.qk_nope_head_dim,
self.mla_params.qk_rope_head_dim,
tokens_per_block,
)
max_block_num = (max_seq_len + tokens_per_block - 1) // tokens_per_block
assert paged_kv_offsets.shape == (num_contexts, 2, max_block_num)
return paged_kv_offsets
def set_chunked_kv_cache_for_mla(
self,
paged_kv: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
cu_chunked_seq_len: torch.Tensor,
cached: bool,
metadata: TrtllmAttentionMetadata,
) -> torch.Tensor:
assert self.is_mla_enable and self.mla_params is not None
assert self.mla_params.qk_nope_head_dim == self.mla_params.v_head_dim
assert metadata.kv_cache_manager is not None
assert paged_kv.shape[0] == metadata.num_contexts
assert paged_kv.is_contiguous()
kv = kv.contiguous()
k_pe = k_pe.contiguous()
num_contexts = metadata.num_contexts
tokens_per_block = metadata.kv_cache_manager.tokens_per_block
if cached:
# this indptr is the fake.
cu_seq_len = cu_chunked_seq_len
max_seq_len = metadata.runtime_features.chunk_size
else:
cu_seq_len = metadata.ctx_uncached_token_indptr
max_seq_len = metadata.max_ctx_seq_len
paged_kv_offsets = torch.ops.trtllm.set_chunked_kv_cache_for_mla(
paged_kv,
kv,
k_pe,
num_contexts,
cu_seq_len,
self.num_heads,
self.mla_params.qk_nope_head_dim,
self.mla_params.qk_rope_head_dim,
metadata.kv_cache_manager.tokens_per_block,
max_seq_len,
)
max_block_num = (max_seq_len + tokens_per_block - 1) // tokens_per_block
assert paged_kv_offsets.shape == (num_contexts, 2, max_block_num)
return paged_kv_offsets
def mla_rope_append_paged_kv_assign_q(
self,
q: torch.Tensor,


@ -28,7 +28,7 @@ class KVCacheParams:
# The number of sink tokens for each layer.
host_sink_token_length: Optional[torch.Tensor] = None
# The number of extra kv for draft tokens
num_extra_kv_tokens: Optional[List[int]] = 0
num_extra_kv_tokens: Optional[int] = 0
class CacheType(Enum):


@ -764,7 +764,6 @@ class DeepseekV3DecoderLayer(DecoderLayer):
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
if isinstance(self.mlp, Deepseekv3MoE):
return self.forward_MoE(
hidden_states=hidden_states,


@ -835,7 +835,6 @@ class MLA(nn.Module):
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.rope_fusion = self.mha.support_fused_rope()
self.support_fused_qkv = self.mha.support_fused_qkv()
self.rotary_emb = None
self.apply_rotary_emb = not self.rope_fusion
if self.apply_rotary_emb:
@ -1054,12 +1053,6 @@ class MLA(nn.Module):
attn_output_gen = None
return output
def _maybe_concat_qkv(self, q, k, v):
if k is not None and v is not None and self.support_fused_qkv:
qkv = torch.concat([q, k, v], dim=-1)
q, k, v = qkv, None, None
return q, k, v
def forward_context_default(
self,
q: torch.Tensor,
@ -1085,9 +1078,6 @@ class MLA(nn.Module):
self.qk_rope_head_dim)
k = k.view(-1, self.num_heads * self.qk_head_dim)
# May concat q(including q_pe), k + k_pe, v together
q, k, v = self._maybe_concat_qkv(q, k, v)
# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase
@ -1139,50 +1129,33 @@ class MLA(nn.Module):
],
-1,
)
full_k_nope = full_k_nope.view(-1, self.num_heads,
self.qk_nope_head_dim)
full_v = full_v.view(-1, self.num_heads, self.v_head_dim)
# build paged_full_kv
tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
# paged_full_kv will be initialized to 0 in the kernel to avoid NaN
paged_full_kv = torch.empty([
attn_metadata.num_contexts, 2,
(attn_metadata.max_ctx_kv_len + tokens_per_block - 1) //
tokens_per_block, self.num_heads, tokens_per_block,
max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim)
],
dtype=q.dtype,
device=q.device)
mla_context_kv_cache_block_offsets = trtllm_attention.set_paged_kv_cache_for_mla(
paged_full_kv,
full_k_nope,
full_v,
full_k_pe,
attn_metadata,
)
full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
full_k = torch.cat(
(full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
# release pytorch activation memory
full_compressed_kv = None
full_k_pe = None
full_kv = None
full_k_nope = None
full_v = None
# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
attn_output = self.mha.forward(
q,
None,
None,
full_k,
full_v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
mla_context_paged_kv=paged_full_kv,
mla_context_kv_cache_block_offsets=
mla_context_kv_cache_block_offsets,
output=output,
)
@ -1204,7 +1177,6 @@ class MLA(nn.Module):
# determine the number of loop
# currently we assume that the chunk size is the same as the max_num_tokens
chunk_size = attn_metadata.runtime_features.chunk_size
chunked_loop_num = attn_metadata.chunked_loop_num
# [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2
@ -1229,6 +1201,7 @@ class MLA(nn.Module):
# use fake cached_cu_seq_len for chunked loop
origin_kv_lens_cuda_runtime = attn_metadata.kv_lens_cuda_runtime
origin_kv_lens_runtime = attn_metadata.kv_lens_runtime
origin_ctx_total_kv_len = attn_metadata.host_total_kv_lens[0]
for loop_idx in range(chunked_loop_num):
# {b, chunked_unit_size, h, kv_lora_rank + qk_rope_head_dim} zero padded
@ -1246,45 +1219,48 @@ class MLA(nn.Module):
# up proj to uncompressed kv
# [tokens, 2, h, kv_dim], without rope_dim
chunked_kv = self.kv_b_proj(chunked_compressed_kv)
chunked_k_nope, chunked_v = chunked_kv.split(
[
self.num_heads * self.qk_nope_head_dim,
self.num_heads * self.v_head_dim
],
-1,
)
# build full_kv
# full_kv {B, 2, chunk_size / tokens_per_block, h, tokens_per_block, kv_dim + rope_dim}
tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
full_kv = torch.zeros([
attn_metadata.num_contexts, 2,
(chunk_size + tokens_per_block - 1) // tokens_per_block,
self.num_heads, tokens_per_block,
max(self.qk_nope_head_dim + self.qk_rope_head_dim,
self.v_head_dim)
],
dtype=q.dtype,
device=q.device)
mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla(
full_kv,
chunked_kv,
chunked_k_pe,
cu_chunked_seq_len=temp_cu_chunked_seq_len,
cached=True,
metadata=attn_metadata)
chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
self.qk_nope_head_dim)
chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
chunked_k = torch.cat(
(chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
dim=-1)
chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
# release pytorch activation memory
chunked_compressed_kv = None
chunked_k_pe = None
chunked_kv = None
chunked_k_nope = None
# copy chunked_seq_len to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.host_chunked_seq_len[
loop_idx]
attn_metadata.kv_lens_cuda_runtime = attn_metadata.chunked_seq_len[
loop_idx]
attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
out_scale = None
# do not apply mask for attention within loop
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
q,
None,
None,
chunked_k,
chunked_v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
attention_mask=PredefinedAttentionMask.FULL,
mla_context_paged_kv=full_kv,
mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
output=temp_attn_output,
)
@ -1299,41 +1275,42 @@ class MLA(nn.Module):
_, k_pe = latent_cache.view([
-1, self.kv_lora_rank + self.qk_rope_head_dim
]).split([self.kv_lora_rank, self.qk_rope_head_dim], -1)
k_pe = k_pe.contiguous()
# final round of attention
k_nope, v = kv.split(
[
self.num_heads * self.qk_nope_head_dim,
self.num_heads * self.v_head_dim
],
-1,
)
k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
k = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
k = k.view(-1, self.num_heads * self.qk_head_dim)
# copy q_lens to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime
attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime
attn_metadata.host_total_kv_lens[
0] = attn_metadata.prompt_lens_cpu_runtime[:attn_metadata.
num_contexts].sum().item(
)
# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase
tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
full_kv = torch.zeros([
attn_metadata.num_contexts, 2,
(attn_metadata.max_ctx_seq_len + tokens_per_block - 1) //
tokens_per_block, self.num_heads, tokens_per_block,
max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim)
],
dtype=q.dtype,
device=q.device)
mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla(
full_kv,
kv,
k_pe,
cu_chunked_seq_len=None,
cached=False,
metadata=attn_metadata)
# copy q_lens to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime
attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
q,
None,
None,
k,
v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
mla_context_paged_kv=full_kv,
mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
output=temp_attn_output,
)
@ -1345,6 +1322,7 @@ class MLA(nn.Module):
# copy back kv_lens_runtime and kv_lens_cuda_runtime
attn_metadata.kv_lens_runtime = origin_kv_lens_runtime
attn_metadata.kv_lens_cuda_runtime = origin_kv_lens_cuda_runtime
attn_metadata.host_total_kv_lens[0] = origin_ctx_total_kv_len
return attn_output


@ -823,7 +823,7 @@ class PyTorchModelEngine(ModelEngine):
self.enable_spec_decode = self.is_spec_decode
def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
enable_paged_context_mla = is_mla(
enable_context_mla_with_cached_kv = is_mla(
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse
or self.attn_runtime_features.chunked_prefill)
@ -837,7 +837,8 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_paged_context_mla=enable_paged_context_mla,
enable_context_mla_with_cached_kv=
enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)
if self.attn_metadata is not None:
@ -854,7 +855,7 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_paged_context_mla=enable_paged_context_mla,
enable_context_mla_with_cached_kv=enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)
return self.attn_metadata


@ -290,11 +290,13 @@ def create_py_executor(
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
)
if executor_config.kv_cache_config.enable_block_reuse and not (
get_sm_version() >= 90 and get_sm_version() <= 100):
sm_version = get_sm_version()
if executor_config.kv_cache_config.enable_block_reuse and sm_version not in [
90, 100, 120
]:
logger.warning(
f"KV cache reuse for MLA can only be enabled on SM90/SM100, "
f"disable enable_block_reuse for SM{get_sm_version()}")
f"KV cache reuse for MLA can only be enabled on SM90/SM100/SM120, "
f"disable enable_block_reuse for SM{sm_version}")
executor_config.kv_cache_config.enable_block_reuse = False
kv_cache_quant_algo = model_engine.model.model_config.quant_config.kv_cache_quant_algo
@ -306,11 +308,12 @@ def create_py_executor(
f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}"
)
executor_config.kv_cache_config.enable_block_reuse = False
if executor_config.enable_chunked_context and not (
get_sm_version() == 100 or get_sm_version() == 90):
if executor_config.enable_chunked_context and sm_version not in [
90, 100
]:
logger.warning(
"Chunked Prefill for MLA can only be enabled on SM90/100, "
f"disable enable_block_reuse for SM{get_sm_version()}")
"Chunked Prefill for MLA can only be enabled on SM90/SM100, "
f"disable enable_chunked_context for SM{sm_version}")
executor_config.enable_chunked_context = False
model_engine.attn_runtime_features.chunked_prefill = False
if draft_model_engine is not None:

Some files were not shown because too many files have changed in this diff Show More