[None][feat] Use Separate QKV Input Layout for Context MLA (#6538)
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
This commit is contained in:
parent 8f95f35503
commit 7e135d2ea7
@ -155,50 +155,41 @@ def test_trtllm_sage_attention_fmha(d, s):
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
ids=["bf16", "e4m3", "e4m3-bf16"])
@pytest.mark.parametrize('s', [1024, 4096], ids=["seqlen-1024", "seqlen-4096"])
@pytest.mark.parametrize(
'input_layout', ["", "-paged-kv", "-contiguous-q-kv", "-separate-q-k-v"],
ids=["packed-qkv", "paged-kv", "q-contiguous-kv", "separate-q-k-v"])
def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
def test_trtllm_context_mla_attention_fmha(dtype, s):
sm_version = getSMVersion()
if sm_version < 90:
pytest.skip("MLA kernels are only tested on sm90 and above currently.")

# use higher error tolerance for bf16 and s = 4096.
epsilon = ''
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'

sm_version = getSMVersion()
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
pytest.skip("FP8 MLAs only supported on sm89 currently.")
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
pytest.skip("FP8 MLAs are only supported on sm120 currently.")

# Context phase kernels.
# Context phase kernels, always use separate-q-k-v layout.
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-force-non-warp-specialization -causal-mask {epsilon}",
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"-causal-mask {epsilon} -separate-q-k-v",
shell=True,
check=True)

if sm_version == 90:
# Now only hopper-style supports separate-q-k-v
# For chunked prefill, we need to enable -save-softmax (dtype: bf16, layout: separate-q-k-v).
# Currently fp8 kernel doesn't support saving softmax.
if dtype == "-bf16":
# padding mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-causal-mask {epsilon} {input_layout}",
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"{epsilon} -separate-q-k-v -save-softmax",
shell=True,
check=True)
# causal mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} "
f"-causal-mask {epsilon} -separate-q-k-v -save-softmax",
shell=True,
check=True)

# For chunked prefill, we need to enable -save-softmax (dtype: bf16, sm90, layout: paged-kv or separate-q-k-v).
if dtype == "-bf16" and input_layout in [
"-paged-kv", "-separate-q-k-v"
]:
# padding mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
{epsilon} {input_layout} -save-softmax",
shell=True,
check=True)
# causal mask
subprocess.run(
f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype} \
-causal-mask {epsilon} {input_layout} -save-softmax",
shell=True,
check=True)
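To make the new invocation concrete, the following is a minimal sketch of how a single case of this test drives the standalone FMHA harness with the separate-q-k-v layout. The flags are taken from the commands above; the helper name and the way the options are combined are illustrative and not part of the commit.

import subprocess

def run_context_mla_case(dtype="-bf16", s=4096, causal=True, save_softmax=False):
    # Higher error tolerance for bf16 at long sequence lengths, as in the test above.
    epsilon = " -epsilon 0.03" if dtype == "-bf16" and s == 4096 else ""
    cmd = (
        f"bin/fmha.exe -v 0 -runs 1 -min-s 1024 -s {s} -b 8 -h 8 -d 192 -dv 128 {dtype}"
        f"{' -causal-mask' if causal else ''}{epsilon} -separate-q-k-v"
        f"{' -save-softmax' if save_softmax else ''}"
    )
    subprocess.run(cmd, shell=True, check=True)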
@pytest.mark.parametrize('dtype', ["-bf16", "-e4m3", "-e4m3 -bf16-output"],
@ -210,14 +201,17 @@ def test_trtllm_context_mla_attention_fmha(dtype, s, input_layout):
"num-grouped-heads-64", "num-grouped-heads-128"
])
def test_trtllm_gen_mla_attention_fmha(dtype, s, num_grouped_heads):
sm_version = getSMVersion()
if sm_version < 90:
pytest.skip("MLA kernels are only tested on sm90 and above currently.")

# use higher error tolerance for bf16 and s = 4096.
epsilon = ''
if dtype == "-bf16" and s == 4096:
epsilon += ' -epsilon 0.03'

sm_version = getSMVersion()
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 89:
pytest.skip("FP8 MLAs only supported on sm89 currently.")
if dtype in ["-e4m3", "-e4m3 -bf16-output"] and sm_version != 120:
pytest.skip("FP8 MLAs are only supported on sm120 currently.")

# Generation phase kernels.
subprocess.run(
@ -2075,6 +2075,8 @@ def get_kernel_code(kspec, kname, lname):
kernel_traits += '_paged_kv_cache'
elif kspec.input_layout == InputLayout.CONTIGUOUS_Q_KV:
kernel_traits += '_contiguous_kv_cache'
elif kspec.input_layout == InputLayout.SEPARATE_Q_K_V:
kernel_traits += '_q_k_v'

flags = 0
if kspec.ldgsts_q:
@ -3183,7 +3185,7 @@ def get_cubin_header(kernel_traits, specs_names):
attention_mask_type_value = attention_mask_type.value

# Attention input layout:
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2).
# packed_qkv (0), contiguous_q_kv (1), q_paged_kv (2), separate_q_k_v (3).
attention_input_layout = InputLayout.PACKED_QKV
if '_q_kv' in kname:
attention_input_layout = InputLayout.CONTIGUOUS_Q_KV
@ -3652,12 +3654,9 @@ def enumerate_hgmma_flash_warpspec_kernels(specs, sm=90, dtype='fp16'):
|
||||
if alibi and enable_attn_logit_softcapping:
|
||||
continue
|
||||
# for normal attention, we only need contiguous kv as input layout when returning softmax.
|
||||
skip_combination = return_softmax and (input_layout
|
||||
!= InputLayout.CONTIGUOUS_Q_KV)
|
||||
# for context mla, we need paged kv or separate qkv as input layout when returning softmax.
|
||||
skip_mla_combination = return_softmax and (
|
||||
input_layout != InputLayout.Q_PAGED_KV
|
||||
and input_layout != InputLayout.SEPARATE_Q_K_V)
|
||||
skip_combination = return_softmax and input_layout != InputLayout.CONTIGUOUS_Q_KV
|
||||
# for context mla, we need separate qkv as input layout when returning softmax.
|
||||
skip_mla_combination = return_softmax and input_layout != InputLayout.SEPARATE_Q_K_V
|
||||
if not skip_combination:
|
||||
# only specify
|
||||
specs.append(
|
||||
@ -4702,9 +4701,16 @@ def enumerate_hmma_paged_kv_flash_kernels(specs, sm=80, dtype='fp16'):
|
||||
|
||||
|
||||
def enumerate_hmma_flash_kernels(specs, sm=80, dtype='fp16', head_size_v=0):
|
||||
for (input_layout, enable_attn_logit_softcapping) in \
|
||||
product([InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV, InputLayout.Q_PAGED_KV], \
|
||||
[False, True]):
|
||||
input_layouts = [
|
||||
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
|
||||
InputLayout.Q_PAGED_KV
|
||||
]
|
||||
# Deepseek MLA (context 192/128 separate-q-k-v)
|
||||
if head_size_v == 128:
|
||||
input_layouts.append(InputLayout.SEPARATE_Q_K_V)
|
||||
for (input_layout,
|
||||
enable_attn_logit_softcapping) in product(input_layouts,
|
||||
[False, True]):
|
||||
enumerate_hmma_flash_kernels_base(specs, sm, dtype, input_layout,
|
||||
enable_attn_logit_softcapping,
|
||||
head_size_v)
|
||||
@ -5080,7 +5086,7 @@ def enumerate_qmma_flash_kernels(specs,
|
||||
]
|
||||
input_layouts = [
|
||||
InputLayout.PACKED_QKV, InputLayout.CONTIGUOUS_Q_KV,
|
||||
InputLayout.Q_PAGED_KV
|
||||
InputLayout.Q_PAGED_KV, InputLayout.SEPARATE_Q_K_V
|
||||
]
|
||||
for (head_size_params, (q_loop_step, kv_loop_step), tiled), input_layout in \
|
||||
product(params_q_kv_step, input_layouts):
|
||||
@ -5094,6 +5100,9 @@ def enumerate_qmma_flash_kernels(specs,
|
||||
# skip if head_size is not in head_sizes
|
||||
if head_sizes is not None and head_size not in head_sizes:
|
||||
continue
|
||||
# skip if head_size_v is not 128 for separate-q-k-v
|
||||
if input_layout == InputLayout.SEPARATE_Q_K_V and head_size_v != 128:
|
||||
continue
|
||||
specs.append(
|
||||
kernel_spec(sm=sm,
|
||||
sm_mma=89,
|
||||
@ -6354,28 +6363,30 @@ def enumerate_kernels():
|
||||
and kspec.version == 2
|
||||
and kspec.cross_mha == False
|
||||
and kspec.flash_attention == False)
|
||||
# Deepseek MLA (192/128 packed + 576/512 paged)
|
||||
or (kspec.sm in [80, 86, 89, 90, 100, 120]
|
||||
# Deepseek MLA (generation 576/512 paged)
|
||||
or (kspec.sm in [90, 100, 120]
|
||||
and kspec.dtype in ['bf16', 'e4m3_fp32']
|
||||
and (((kspec.head_size, kspec.head_size_v) == (192, 128) and kspec.input_layout in [InputLayout.PACKED_QKV, InputLayout.Q_PAGED_KV])
|
||||
or ((kspec.head_size, kspec.head_size_v) == (576, 512) and kspec.input_layout == InputLayout.Q_PAGED_KV))
|
||||
and kspec.head_size == 576
|
||||
and kspec.head_size_v == 512
|
||||
and kspec.input_layout == InputLayout.Q_PAGED_KV
|
||||
and kspec.sage_block_sizes is None
|
||||
and kspec.version == 2
|
||||
and kspec.cross_mha == False
|
||||
and kspec.flash_attention == True
|
||||
and kspec.warp_specialization == False
|
||||
and kspec.tiled == True)
|
||||
# Deepseek MLA (hopper-style context 192/128)
|
||||
or (kspec.sm == 90
|
||||
and kspec.dtype == 'bf16'
|
||||
# Deepseek MLA (context 192/128 separate-q-k-v)
|
||||
or (kspec.sm in [90, 100, 120]
|
||||
and kspec.dtype in ['bf16', 'e4m3_fp32']
|
||||
and kspec.head_size == 192
|
||||
and kspec.head_size_v == 128
|
||||
and kspec.input_layout == InputLayout.SEPARATE_Q_K_V
|
||||
and kspec.sage_block_sizes is None
|
||||
and kspec.version == 2
|
||||
and kspec.cross_mha == False
|
||||
and kspec.flash_attention == True
|
||||
and kspec.warp_specialization == True
|
||||
and kspec.alibi == False
|
||||
and ((kspec.warp_specialization == True and kspec.alibi == False) # sm90
|
||||
or (kspec.warp_specialization == False and kspec.tiled == True)) # non-sm90
|
||||
and kspec.enable_attn_logit_softcapping == False)
|
||||
# SageAttention (warp_spec, head_size in (80, 128), packed QKV, padding mask)
|
||||
or (kspec.sm == 90
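Restated compactly, the two Deepseek MLA filters that enumerate_kernels() keeps after this change look like the sketch below (illustrative Python; it uses plain strings instead of the script's InputLayout enum and omits the common bookkeeping predicates such as sage_block_sizes, version, cross_mha, flash_attention and the softcapping check).

def keeps_deepseek_mla(sm, dtype, head_size, head_size_v, input_layout,
                       warp_specialization, tiled, alibi):
    common = sm in (90, 100, 120) and dtype in ('bf16', 'e4m3_fp32')
    # Generation phase: 576/512 heads reading from the paged KV cache.
    gen_paged_kv = (common and (head_size, head_size_v) == (576, 512)
                    and input_layout == 'q_paged_kv'
                    and not warp_specialization and tiled)
    # Context phase: 192/128 heads with the new separate-q-k-v layout.
    ctx_separate_qkv = (common and (head_size, head_size_v) == (192, 128)
                        and input_layout == 'separate_q_k_v'
                        and ((warp_specialization and not alibi)       # sm90 warp-specialized
                             or (not warp_specialization and tiled)))  # non-sm90 tiled
    return gen_paged_kv or ctx_separate_qkv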
|
||||
|
||||
@ -418,7 +418,7 @@ struct Gmem_tile_qkv
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// We expect the Q layout to be [B, S, H, D] with variable sequence length support.
|
||||
// We expect the Q/K/V layout to be [B, S, H, D] with variable sequence length support.
|
||||
template <
|
||||
// The instruction traits.
|
||||
typename Traits,
|
||||
@ -440,7 +440,7 @@ template <
|
||||
int NUM_MATS = 1,
|
||||
// Is sliding window attention used ?
|
||||
bool SLIDING_WINDOW_ATTENTION = false>
|
||||
struct Gmem_tile_q
|
||||
struct Gmem_tile_q_k_v
|
||||
{
|
||||
|
||||
// The size of each LDG.
|
||||
@ -523,22 +523,38 @@ struct Gmem_tile_q
|
||||
USE_LDGSTS = USE_LDGSTS_
|
||||
};
|
||||
|
||||
// Ctor (keep qkv_offset for compatibility)
|
||||
// Ctor
|
||||
// qkv_offset: 0 for Q, 1 for K, 2 for V
|
||||
template <typename Block_info>
|
||||
inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
|
||||
inline __device__ Gmem_tile_q_k_v(bert::Fused_multihead_attention_params_v2 const& params, int qkv_offset,
|
||||
Block_info const& binfo, int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
|
||||
: Gmem_tile_q(params, binfo, tidx, cta_row_offset, cta_col_offset_in_bytes)
|
||||
{
|
||||
}
|
||||
|
||||
// Ctor.
|
||||
template <typename Block_info>
|
||||
inline __device__ Gmem_tile_q(bert::Fused_multihead_attention_params_v2 const& params, Block_info const& binfo,
|
||||
int tidx, int cta_row_offset = 0, int cta_col_offset_in_bytes = 0)
|
||||
: params_q_stride_in_bytes_(params.q_stride_in_bytes)
|
||||
, actual_seqlen_(binfo.actual_q_seqlen)
|
||||
, q_ptr_(reinterpret_cast<char*>(params.q_ptr))
|
||||
{
|
||||
int seq_offset = 0;
|
||||
if (qkv_offset == 0)
|
||||
{
|
||||
// Q tensor
|
||||
params_q_k_v_stride_in_bytes_ = params.q_stride_in_bytes;
|
||||
q_k_v_ptr_ = reinterpret_cast<char*>(params.q_ptr);
|
||||
actual_seqlen_ = binfo.actual_q_seqlen;
|
||||
seq_offset = binfo.sum_s;
|
||||
}
|
||||
else if (qkv_offset == 1)
|
||||
{
|
||||
// K tensor
|
||||
params_q_k_v_stride_in_bytes_ = params.k_stride_in_bytes;
|
||||
q_k_v_ptr_ = reinterpret_cast<char*>(params.k_ptr);
|
||||
actual_seqlen_ = binfo.actual_kv_seqlen;
|
||||
seq_offset = binfo.sum_s_kv;
|
||||
}
|
||||
else if (qkv_offset == 2)
|
||||
{
|
||||
// V tensor
|
||||
params_q_k_v_stride_in_bytes_ = params.v_stride_in_bytes;
|
||||
q_k_v_ptr_ = reinterpret_cast<char*>(params.v_ptr);
|
||||
actual_seqlen_ = binfo.actual_kv_seqlen;
|
||||
seq_offset = binfo.sum_s_kv;
|
||||
}
|
||||
|
||||
// Compute the position in the sequence (within the CTA for the moment).
|
||||
int row = tidx / THREADS_PER_ROW;
|
||||
@ -550,17 +566,20 @@ struct Gmem_tile_q
|
||||
// Do not load/store if the thread is in the padded area
|
||||
col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG;
|
||||
|
||||
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
|
||||
// We won't consider past_q_length when loading from gmem_q.
|
||||
int64_t row_offset = (int64_t) (row + cta_row_offset) * params_q_stride_in_bytes_;
|
||||
// Add the block index. (sum_s * h + hidx).
|
||||
int64_t idx = binfo.bidx;
|
||||
// The row offset in the batched GEMM, including the sequence offset.
|
||||
int64_t row_offset = (int64_t) (row + cta_row_offset + seq_offset) * params_q_k_v_stride_in_bytes_;
|
||||
// Add the head index.
|
||||
int64_t idx = binfo.bidh;
|
||||
|
||||
// Assemble the final pointer.
|
||||
q_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
|
||||
q_k_v_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col_in_bytes_;
|
||||
|
||||
// Take the CTA offset to modify the sequence length.
|
||||
actual_seqlen_ -= cta_row_offset;
|
||||
|
||||
// Set the initial seq_len and qkv_offset in case of reinterating
|
||||
actual_seqlen_init_ = actual_seqlen_;
|
||||
q_k_v_ptr_init_ = q_k_v_ptr_;
|
||||
}
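The pointer arithmetic in this constructor can be read as the plain-Python paraphrase below; VALID_BYTES_PER_ROW, BYTES_PER_LDG and the other constants stand in for the tile's compile-time values, so this is only an illustration of the indexing, not the device code.

def gmem_q_k_v_offset(row, cta_row_offset, seq_offset, bidh,
                      stride_in_bytes, valid_bytes_per_row,
                      cta_col_offset_in_bytes, col, bytes_per_ldg):
    col_in_bytes = cta_col_offset_in_bytes + col * bytes_per_ldg
    # One row per token; seq_offset (binfo.sum_s or binfo.sum_s_kv) folds the packed
    # variable-length batch into the row index.
    row_offset = (row + cta_row_offset + seq_offset) * stride_in_bytes
    # The head index (binfo.bidh) selects the head's contiguous slice inside the row.
    return row_offset + bidh * valid_bytes_per_row + col_in_bytes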
|
||||
|
||||
// Store data to shared memory.
|
||||
@ -590,7 +609,7 @@ struct Gmem_tile_q
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDGS; ++ii)
|
||||
{
|
||||
ptrs[ii] = q_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_stride_in_bytes_;
|
||||
ptrs[ii] = q_k_v_ptr_ + (int64_t) ii * ROWS_PER_LDG * params_q_k_v_stride_in_bytes_;
|
||||
}
|
||||
|
||||
// Trigger LDGSTS or the LDGs.
|
||||
@ -598,10 +617,24 @@ struct Gmem_tile_q
|
||||
Ldgsts_helper<USE_LDGSTS>::load(this, smem_tile, ptrs, preds);
|
||||
}
|
||||
|
||||
// Move the pointer to the next row location.
|
||||
inline __device__ void move(int const steps = 1)
|
||||
{
|
||||
q_k_v_ptr_ += (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * steps;
|
||||
actual_seqlen_ -= (int) ROWS * steps;
|
||||
}
|
||||
|
||||
// Move the pointer to the next row location by the offset (not step).
|
||||
inline __device__ void move_by_offset(int const offset)
|
||||
{
|
||||
q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) offset * params_q_k_v_stride_in_bytes_;
|
||||
actual_seqlen_ = actual_seqlen_init_ - (int) offset;
|
||||
}
|
||||
|
||||
// Move the pointer to the next column location
|
||||
inline __device__ void move_col()
|
||||
{
|
||||
q_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
|
||||
q_k_v_ptr_ += (int64_t) COLS * (BITS_PER_ELEMENT / 8);
|
||||
// Update col_in_bytes_ to ensure load predicates work
|
||||
col_in_bytes_ += THREADS_PER_ROW * BYTES_PER_LDG;
|
||||
}
|
||||
@ -609,15 +642,29 @@ struct Gmem_tile_q
|
||||
// Rewind the pointer back to previous column location
|
||||
inline __device__ void rewind_col(int const steps)
|
||||
{
|
||||
q_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
|
||||
q_k_v_ptr_ -= COLS * (BITS_PER_ELEMENT / 8) * steps;
|
||||
// Update col_in_bytes_ to ensure load predicates work
|
||||
col_in_bytes_ -= THREADS_PER_ROW * BYTES_PER_LDG * steps;
|
||||
}
|
||||
|
||||
// The stride between rows for the QKV matrice.
|
||||
int64_t params_q_stride_in_bytes_;
|
||||
// Move the pointer to the specified step.
|
||||
inline __device__ void move_to(int const step)
|
||||
{
|
||||
q_k_v_ptr_ = q_k_v_ptr_init_ + (int64_t) ROWS * params_q_k_v_stride_in_bytes_ * step;
|
||||
actual_seqlen_ = actual_seqlen_init_ - (int) ROWS * step;
|
||||
}
|
||||
|
||||
inline __device__ void reset()
|
||||
{
|
||||
q_k_v_ptr_ = q_k_v_ptr_init_;
|
||||
actual_seqlen_ = actual_seqlen_init_;
|
||||
}
|
||||
|
||||
// The stride between rows for the Q/K/V matrice.
|
||||
int64_t params_q_k_v_stride_in_bytes_;
|
||||
// The pointer.
|
||||
char* q_ptr_;
|
||||
char* q_k_v_ptr_;
|
||||
char* q_k_v_ptr_init_;
|
||||
// The register to store predicates.
|
||||
uint32_t preds_[PRED_REGS];
|
||||
// The fetch registers.
|
||||
@ -627,6 +674,7 @@ struct Gmem_tile_q
|
||||
int64_t col_in_bytes_;
|
||||
// The sequence length.
|
||||
int actual_seqlen_;
|
||||
int actual_seqlen_init_;
|
||||
};
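The new stepping helpers (move, move_by_offset, move_to, reset) behave like this small Python model, where rows plays the role of the ROWS template constant and stride the role of params_q_k_v_stride_in_bytes_ (illustration only).

class TilePointer:
    def __init__(self, base_ptr, stride, seqlen, rows):
        self.init_ptr, self.init_seqlen = base_ptr, seqlen
        self.ptr, self.seqlen = base_ptr, seqlen
        self.stride, self.rows = stride, rows

    def move(self, steps=1):           # advance by whole tiles
        self.ptr += self.rows * self.stride * steps
        self.seqlen -= self.rows * steps

    def move_by_offset(self, offset):  # jump to an absolute row offset
        self.ptr = self.init_ptr + offset * self.stride
        self.seqlen = self.init_seqlen - offset

    def move_to(self, step):           # jump to an absolute tile index
        self.ptr = self.init_ptr + self.rows * self.stride * step
        self.seqlen = self.init_seqlen - self.rows * step

    def reset(self):
        self.ptr, self.seqlen = self.init_ptr, self.init_seqlen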
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1015,8 +1015,8 @@ template <
|
||||
typename OutputType = typename Traits::A_type,
|
||||
// The sage attention block size for Q, K and V
|
||||
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
|
||||
using Kernel_traits_v2_paged_kv_cache
|
||||
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q, fmha::v2::Gmem_tile_paged_kv, fmha::v2::Gmem_tile_paged_kv,
|
||||
using Kernel_traits_v2_q_k_v
|
||||
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_q_k_v,
|
||||
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
|
||||
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
|
||||
|
||||
@ -1049,7 +1049,41 @@ template <
|
||||
typename OutputType = typename Traits::A_type,
|
||||
// The sage attention block size for Q, K and V
|
||||
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
|
||||
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q,
|
||||
using Kernel_traits_v2_paged_kv_cache
|
||||
= Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v, fmha::v2::Gmem_tile_paged_kv, fmha::v2::Gmem_tile_paged_kv,
|
||||
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, DV, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS,
|
||||
2, MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
// The instruction traits.
|
||||
typename Traits,
|
||||
// The sequence length.
|
||||
int S,
|
||||
// The hidden size per head.
|
||||
int D,
|
||||
// The hidden dimension of V.
|
||||
int DV,
|
||||
// The number of timesteps per iteration of the main loop.
|
||||
int STEP,
|
||||
// The number of vertical warps.
|
||||
int WARPS_M,
|
||||
// The number of horizontal warps.
|
||||
int WARPS_N,
|
||||
// The number of CTAs per head for Cta_tile_p; equivalent to BMM1 split-K
|
||||
int CTAS_PER_HEAD,
|
||||
// The flags.
|
||||
uint32_t FLAGS = 0x8,
|
||||
// The attention mask version (see src/mask.h).
|
||||
int MASK_VERSION = 2,
|
||||
// Do we use half epilogue for the 2nd GEMM (hmma_fp32)
|
||||
bool BMM2_FP16_EPILOGUE = true,
|
||||
// The output type.
|
||||
typename OutputType = typename Traits::A_type,
|
||||
// The sage attention block size for Q, K and V
|
||||
int SAGE_BLOCK_SIZE_Q = 0, int SAGE_BLOCK_SIZE_K = 0, int SAGE_BLOCK_SIZE_V = 0>
|
||||
using Kernel_traits_v2_contiguous_kv_cache = Kernel_traits_<Traits, fmha::v2::Gmem_tile_q_k_v,
|
||||
fmha::v2::Gmem_tile_contiguous_kv, fmha::v2::Gmem_tile_contiguous_kv,
|
||||
Gmem_tile_o_dispatcher<Traits, OutputType>::Gmem_tile_o, S, D, 0, STEP, WARPS_M, WARPS_N, CTAS_PER_HEAD, FLAGS, 2,
|
||||
MASK_VERSION, BMM2_FP16_EPILOGUE, SAGE_BLOCK_SIZE_Q, SAGE_BLOCK_SIZE_K, SAGE_BLOCK_SIZE_V>;
|
||||
|
||||
@ -890,18 +890,17 @@ int main(int argc, char** argv)
{
bool is_MLA = (d == 192 && dv == 128);
if (((!is_MLA) && input_layout != Attention_input_layout::CONTIGUOUS_Q_KV)
|| (is_MLA && input_layout != Attention_input_layout::Q_PAGED_KV
&& input_layout != Attention_input_layout::SEPARATE_Q_K_V))
|| (is_MLA && input_layout != Attention_input_layout::SEPARATE_Q_K_V))
{
fprintf(stderr,
"For normal attention, Only '--contiguous-q-kv' layout supports "
"'-save-softmax'. For MLA only '-paged-kv' and '-separate-q-k-v' layout supports "
"'-save-softmax'. For MLA only '-separate-q-k-v' layout supports "
"'-save-softmax'.\n");
exit(1);
}
if (data_type == DATA_TYPE_E4M3)
{
fprintf(stderr, "Currently fp8 kernel doesn't support fp8.\n");
fprintf(stderr, "Currently fp8 kernel doesn't support saving softmax.\n");
exit(1);
}
}
@ -747,17 +747,19 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
|
||||
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const total_v_dim_all_heads
|
||||
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
|
||||
int const num_total_qkv_elements
|
||||
= max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
|
||||
|
||||
// Packed fp8 qkv buffer size for normal fp8 context FMHA
|
||||
size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
|
||||
? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
|
||||
: 0;
|
||||
if (mFP8ContextMLA)
|
||||
// Separate fp8 q/k/v buffer size for fp8 context MLA
|
||||
size_t fp8_q_buf_size = 0;
|
||||
size_t fp8_k_buf_size = 0;
|
||||
size_t fp8_v_buf_size = 0;
|
||||
if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
|
||||
{
|
||||
fp8_qkv_buffer_size
|
||||
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
|
||||
fp8_q_buf_size = max_num_tokens * static_cast<size_t>(total_q_dim_all_heads);
|
||||
fp8_k_buf_size = max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
|
||||
fp8_v_buf_size = max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
|
||||
}
|
||||
|
||||
size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
|
||||
@ -774,7 +776,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
|
||||
? 0
|
||||
: (2 * size * cpMaxPaddedSequenceLength * getHeadSize() * (mNumHeads + 2 * mNumKVHeads) + cu_seqlens_size);
|
||||
|
||||
int const NUM_BUFFERS = 20;
|
||||
int const NUM_BUFFERS = 23;
|
||||
size_t workspaces[NUM_BUFFERS];
|
||||
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
|
||||
workspaces[1] = attention_mask_size;
|
||||
@ -789,13 +791,16 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
|
||||
workspaces[10] = qkv_buf_2_size;
|
||||
workspaces[11] = qk_buf_float_size;
|
||||
workspaces[12] = fp8_qkv_buffer_size;
|
||||
workspaces[13] = padding_offset_size;
|
||||
workspaces[14] = encoder_padding_offset_size;
|
||||
workspaces[15] = tokens_info_size;
|
||||
workspaces[16] = fmha_scheduler_counter;
|
||||
workspaces[17] = fmha_bmm1_scale_size;
|
||||
workspaces[18] = fmha_bmm2_scale_size;
|
||||
workspaces[19] = cpWorkspaceSize;
|
||||
workspaces[13] = fp8_q_buf_size;
|
||||
workspaces[14] = fp8_k_buf_size;
|
||||
workspaces[15] = fp8_v_buf_size;
|
||||
workspaces[16] = padding_offset_size;
|
||||
workspaces[17] = encoder_padding_offset_size;
|
||||
workspaces[18] = tokens_info_size;
|
||||
workspaces[19] = fmha_scheduler_counter;
|
||||
workspaces[20] = fmha_bmm1_scale_size;
|
||||
workspaces[21] = fmha_bmm2_scale_size;
|
||||
workspaces[22] = cpWorkspaceSize;
|
||||
context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
|
||||
|
||||
return context_workspace_size;
|
||||
@ -979,7 +984,7 @@ int AttentionOp::mlaGeneration(
|
||||
params.fmha_tile_counter = fmha_tile_counter_ptr;
|
||||
params.bmm1_scale = mla_bmm1_scale_ptr;
|
||||
params.bmm2_scale = mla_bmm2_scale_ptr;
|
||||
params.quant_attention_input_buf = quant_q_buffer_ptr;
|
||||
params.quant_q_buf = quant_q_buffer_ptr;
|
||||
|
||||
params.quant_scale_o = generation_params.attention_output_orig_quant;
|
||||
params.quant_scale_q = generation_params.kv_scale_orig_quant;
|
||||
@ -1020,8 +1025,8 @@ int AttentionOp::mlaGeneration(
|
||||
tllmRunnerParams.mTileScheduler = mMultiBlockMode ? TileScheduler::Static : TileScheduler::Persistent;
|
||||
|
||||
// Q buffer.
|
||||
tllmRunnerParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_attention_input_buf)
|
||||
: reinterpret_cast<void const*>(params.attention_input_buf);
|
||||
tllmRunnerParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_q_buf)
|
||||
: reinterpret_cast<void const*>(params.q_buf);
|
||||
|
||||
// KV buffer
|
||||
// Paged KV
|
||||
@ -1146,9 +1151,8 @@ int AttentionOp::mlaGeneration(
|
||||
flashMlaParams.scale_softmax = softmax_scale;
|
||||
flashMlaParams.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
|
||||
|
||||
flashMlaParams.q_ptr = mFP8GenerationMLA
|
||||
? const_cast<void*>(reinterpret_cast<void const*>(params.quant_attention_input_buf))
|
||||
: const_cast<void*>(reinterpret_cast<void const*>(params.attention_input_buf));
|
||||
flashMlaParams.q_ptr = mFP8GenerationMLA ? const_cast<void*>(reinterpret_cast<void const*>(params.quant_q_buf))
|
||||
: const_cast<void*>(reinterpret_cast<void const*>(params.q_buf));
|
||||
flashMlaParams.k_ptr = kv_cache_buffer.mPrimaryPoolPtr;
|
||||
flashMlaParams.v_ptr = flashMlaParams.k_ptr;
|
||||
flashMlaParams.o_ptr = reinterpret_cast<void*>(params.context_buf);
|
||||
@ -1253,8 +1257,8 @@ int AttentionOp::mlaGeneration(
|
||||
// fmhaParams.totalKvSeqLen = params.num_tokens;
|
||||
// Device buffer pointers.
|
||||
// fmhaParams.qkvPtr = reinterpret_cast<void const*>(params.attention_input);
|
||||
fmhaParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_attention_input_buf)
|
||||
: reinterpret_cast<void const*>(params.attention_input_buf);
|
||||
fmhaParams.qPtr = mFP8GenerationMLA ? reinterpret_cast<void const*>(params.quant_q_buf)
|
||||
: reinterpret_cast<void const*>(params.q_buf);
|
||||
// TODO: add contiguous kv buffer (cross-attention).
|
||||
fmhaParams.kvPtr = nullptr;
|
||||
|
||||
@ -1379,15 +1383,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const total_v_dim_all_heads
|
||||
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const num_total_qkv_elements
|
||||
= params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
|
||||
// Packed fp8 qkv buffer size for normal fp8 context FMHA
|
||||
size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
|
||||
? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
|
||||
: 0;
|
||||
if (mFP8ContextMLA)
|
||||
// Separate fp8 q/k/v buffer size for fp8 context MLA
|
||||
size_t fp8_q_buf_size = 0;
|
||||
size_t fp8_k_buf_size = 0;
|
||||
size_t fp8_v_buf_size = 0;
|
||||
if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
|
||||
{
|
||||
fp8_qkv_buffer_size
|
||||
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
|
||||
fp8_q_buf_size = params.num_tokens * static_cast<size_t>(total_q_dim_all_heads);
|
||||
fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
|
||||
fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
|
||||
}
|
||||
size_t const padding_offset_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
|
||||
@ -1424,6 +1432,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
float* qk_buf_float_ = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_float_size));
|
||||
__nv_fp8_e4m3* fp8_qkv_buffer
|
||||
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_qkv_buffer_size));
|
||||
__nv_fp8_e4m3* fp8_q_buf
|
||||
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_q_buf_size));
|
||||
__nv_fp8_e4m3* fp8_k_buf
|
||||
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_k_buf_size));
|
||||
__nv_fp8_e4m3* fp8_v_buf
|
||||
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_v_buf_size));
|
||||
int* padding_offset = mEnableContextFMHA
|
||||
? nullptr
|
||||
: reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
|
||||
@ -1638,16 +1652,18 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
"Found invalid number (NaN or Inf) in " + beforeRopeStr);
|
||||
}
|
||||
|
||||
KVBlockArray mla_context_paged_kv_cache_buffer;
|
||||
if (mIsMLAEnabled)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(params.mla_param != nullptr, "MLA param is nullptr");
|
||||
params.mla_param->cache_type = cache_type;
|
||||
params.mla_param->cu_q_seqlens = cu_q_seqlens;
|
||||
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
|
||||
// Set BMM scales for FP8 context computation
|
||||
params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
|
||||
params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
|
||||
params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
|
||||
params.mla_param->quant_q_buf = mFP8ContextMLA ? fp8_q_buf : nullptr;
|
||||
params.mla_param->quant_k_buf = mFP8ContextMLA ? fp8_k_buf : nullptr;
|
||||
params.mla_param->quant_v_buf = mFP8ContextMLA ? fp8_v_buf : nullptr;
|
||||
// Set additional scales for context phase
|
||||
params.mla_param->quant_scale_o = params.attention_output_orig_quant;
|
||||
params.mla_param->quant_scale_q = params.kv_scale_orig_quant;
|
||||
@ -1656,37 +1672,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
|
||||
params.mla_param->host_bmm1_scale
|
||||
= 1 / (mQScaling * sqrt((float) (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim)));
|
||||
if (mPagedContextFMHA && mPagedKVCache)
|
||||
if (params.mla_param->latent_cache != nullptr)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
|
||||
"Paged kv cache is not set for MLA context kernel");
|
||||
TLLM_CHECK_WITH_INFO(params.mla_param->context_kv_cache_block_offsets_ptr != nullptr,
|
||||
"Paged kv cache block offsets is not set for MLA context kernel");
|
||||
// build another KVBlockArray for MLA context kernel to read paged kv cache, which is built by the
|
||||
// PyTorch backend assume the dtype of paged kv cache is the same as the T
|
||||
auto const elemSize = sizeof(T);
|
||||
auto const headSize = params.mla_param->meta.qk_nope_head_dim + params.mla_param->meta.qk_rope_head_dim;
|
||||
// mNumKVHeads is 1 for writing, we use mNumHeads for reading paged kv cache
|
||||
auto sizePerToken = mNumHeads * headSize * elemSize;
|
||||
auto maxBlocksPerSeq = params.mla_param->context_paged_kv_max_blocks_per_seq;
|
||||
TLLM_LOG_DEBUG(
|
||||
"AttentionOp building KVBlockArray for MLA context kernel, elemSize: %d, headSize: %d, mNumHeads: "
|
||||
"%d, sizePerToken: %d, batchSize: %d, maxBlocksPerSeq: %d, tokensPerBlock: %d, maxAttentionWindow: "
|
||||
"%d, "
|
||||
"sinkTokenLen: %d, canUseOneMoreBlock: %d",
|
||||
elemSize, headSize, mNumHeads, sizePerToken, params.batch_size, maxBlocksPerSeq, mTokensPerBlock,
|
||||
params.cyclic_attention_window_size, params.sink_token_length, params.can_use_one_more_block);
|
||||
mla_context_paged_kv_cache_buffer = KVBlockArray(params.batch_size, maxBlocksPerSeq, mTokensPerBlock,
|
||||
sizePerToken, params.cyclic_attention_window_size, params.max_cyclic_attention_window_size,
|
||||
params.sink_token_length, params.can_use_one_more_block, params.mla_param->context_paged_kv_ptr,
|
||||
nullptr,
|
||||
static_cast<KVBlockArray::DataType*>(params.mla_param->context_kv_cache_block_offsets_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
// compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if not using paged context FMHA
|
||||
// compute RoPE and set compressed_kv + k_pe by invokeMLARopeContext if latent_cache is not nullptr
|
||||
invokeMLARopeContext<T, KVCacheBuffer>(*params.mla_param, kv_cache_buffer, stream);
|
||||
}
|
||||
if (mFP8ContextMLA)
|
||||
{
|
||||
invokeMLAContextFp8Quantize(*params.mla_param, params.total_kv_len, stream);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -1709,7 +1703,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
mFMHAForceFP32Acc = mFMHAForceFP32Acc || enable_context_fmha_fp32_acc_val == 1;
}

// Unified FMHA runner interface for both packed QKV FMHA, contiguous Q_KV and paged KV FMHA.
// Unified FMHA runner interface for both packed QKV FMHA, contiguous Q_KV, paged KV FMHA, and separate QKV
// FMHA.
// Page KV input layout:
// - q_ptr: [B, S, H, D], which supports variable sequence length
// - paged_kv_cache: paged kv buffer
@ -1721,6 +1716,14 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
// - kv_ptr: [B, S, 2, H, D], which supports variable sequence length
// - cu_q_seqlens: the cumulative query sequence lengths, needed for variable sequence length.
// - cu_kv_seqlens: the cumulative kv sequence lengths, needed for variable sequence length.
//
// Separate QKV input layout (only for context MLA now):
// - q_ptr: [B, S, H, D], which supports variable sequence length
// - k_ptr: [B, S, H_kv, D], which supports variable sequence length
// - v_ptr: [B, S, H_kv, D_v], which supports variable sequence length
// - cu_q_seqlens: the cumulative query sequence lengths, needed for variable sequence length.
// - cu_kv_seqlens: the cumulative kv sequence lengths, needed for variable sequence length.
// - total_kv_len: the total kv sequence length, needed for variable sequence length.
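As a concrete picture of this layout, the sketch below packs variable-length Q/K/V tensors and derives cu_q_seqlens, cu_kv_seqlens and total_kv_len with NumPy; the shapes, head counts and helper are illustrative rather than the exact runner API.

import numpy as np

H, H_kv, D, D_v = 128, 128, 192, 128        # illustrative context-MLA head counts and dims
q_seqlens = np.array([3, 5])                # per-request query lengths
kv_seqlens = np.array([7, 9])               # per-request kv lengths (can exceed q, e.g. chunked prefill)
cu_q_seqlens = np.concatenate(([0], np.cumsum(q_seqlens)))
cu_kv_seqlens = np.concatenate(([0], np.cumsum(kv_seqlens)))
total_kv_len = int(cu_kv_seqlens[-1])

q = np.zeros((int(cu_q_seqlens[-1]), H, D), dtype=np.float16)   # q_ptr: packed [sum(S_q), H, D]
k = np.zeros((total_kv_len, H_kv, D), dtype=np.float16)         # k_ptr: packed [sum(S_kv), H_kv, D]
v = np.zeros((total_kv_len, H_kv, D_v), dtype=np.float16)       # v_ptr: packed [sum(S_kv), H_kv, D_v]

def q_token(b, s):
    # query token s of request b in the packed Q buffer
    return q[cu_q_seqlens[b] + s]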

// Construct the fmha params for running kernels.
MHARunnerParams fmhaParams{};
|
||||
= (mDenseContextFMHA || isCrossAttention()) ? max_kv_seq_len : params.cyclic_attention_window_size;
|
||||
fmhaParams.totalQSeqLen = params.num_tokens;
|
||||
// TODO: set it correctly for contiguous kv buffer (cross-attention).
|
||||
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
|
||||
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.total_kv_len;
|
||||
// Device buffer pointers.
|
||||
fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
|
||||
: reinterpret_cast<void const*>(attention_input);
|
||||
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
|
||||
if (mIsMLAEnabled)
|
||||
{
|
||||
// separate QKV input for context MLA
|
||||
if (mFP8ContextMLA)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mFmhaDispatcher->isSeparateQAndKvInput(), "Separate QKV input is required for fp8 context MLA");
|
||||
TLLM_CHECK_WITH_INFO(fp8_q_buf != nullptr, "FP8 q buffer is required for fp8 context MLA");
|
||||
TLLM_CHECK_WITH_INFO(fp8_k_buf != nullptr, "FP8 k buffer is required for fp8 context MLA");
|
||||
TLLM_CHECK_WITH_INFO(fp8_v_buf != nullptr, "FP8 v buffer is required for fp8 context MLA");
|
||||
fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
|
||||
fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
|
||||
fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
|
||||
}
|
||||
else
|
||||
{
|
||||
fmhaParams.qPtr = attention_input;
|
||||
fmhaParams.kPtr = params.k_ptr;
|
||||
fmhaParams.vPtr = params.v_ptr;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
|
||||
: reinterpret_cast<void const*>(attention_input);
|
||||
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
|
||||
}
|
||||
// TODO: add contiguous kv buffer (cross-attention).
|
||||
fmhaParams.kvPtr = nullptr;
|
||||
if (isCrossAttention() && !useKVCache())
|
||||
@ -1750,15 +1777,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
|
||||
fmhaParams.packedMaskPtr = params.attention_packed_mask;
|
||||
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
|
||||
{
|
||||
if (mIsMLAEnabled && mPagedContextFMHA && mPagedKVCache)
|
||||
{
|
||||
fmhaParams.pagedKvCache = mla_context_paged_kv_cache_buffer;
|
||||
fmhaParams.qPtr = reinterpret_cast<void const*>(attention_input);
|
||||
}
|
||||
else
|
||||
{
|
||||
fmhaParams.pagedKvCache = kv_cache_buffer;
|
||||
}
|
||||
fmhaParams.pagedKvCache = kv_cache_buffer;
|
||||
}
|
||||
fmhaParams.cuQSeqLenPtr = cu_q_seqlens;
|
||||
fmhaParams.kvSeqLenPtr = decoder_params.seqKVLengths;
|
||||
@ -2606,6 +2625,8 @@ int AttentionOp::initialize() noexcept
|
||||
// the wrong kernel, no matter mIsGenerationMLA is true or false
|
||||
if (mIsMLAEnabled)
|
||||
{
|
||||
// Context MLA always use separate_q_k_v layout
|
||||
fmhaParams.attentionInputLayout = AttentionInputLayout::SEPARATE_Q_K_V;
|
||||
// Context attention of MLA is different
|
||||
fmhaParams.numKvHeads = mNumHeads;
|
||||
fmhaParams.headSize = mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim;
|
||||
@ -2614,6 +2635,7 @@ int AttentionOp::initialize() noexcept
|
||||
// attention op and that could fail to create the FmhaDispatcher for context phase.
|
||||
// Luckily, for deepseek, qk_nope_head_dim is the same as v_head_dim in context phase.
|
||||
fmhaParams.headSizeV = mMLAParams.qk_nope_head_dim;
|
||||
fmhaParams.headSizeQkNope = mMLAParams.qk_nope_head_dim;
|
||||
}
|
||||
fmhaParams.qScaling = mQScaling;
|
||||
fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
|
||||
|
||||
@ -95,6 +95,7 @@ public:
|
||||
void* host_primary_pool_pointer = nullptr;
|
||||
void* host_secondary_pool_pointer = nullptr;
|
||||
int32_t num_tokens = 0;
|
||||
int32_t total_kv_len = 0;
|
||||
int32_t max_blocks_per_sequence = 0;
|
||||
int32_t const* sequence_lengths = nullptr;
|
||||
int32_t const* context_lengths = nullptr;
|
||||
@ -128,6 +129,9 @@ public:
|
||||
|
||||
// For MLA chunked prefill
|
||||
void* softmaxStatsPtr = nullptr;
|
||||
// optional for separate QKV input, currently only used for context MLA
|
||||
T const* k_ptr = nullptr;
|
||||
T const* v_ptr = nullptr;
|
||||
|
||||
std::string enqueueContextParamsToString() const
|
||||
{
|
||||
@ -169,6 +173,7 @@ public:
|
||||
ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl;
|
||||
ss << "batch_size: " << this->batch_size << std::endl;
|
||||
ss << "num_tokens: " << this->num_tokens << std::endl;
|
||||
ss << "total_kv_len: " << this->total_kv_len << std::endl;
|
||||
ss << "max_blocks_per_sequence: " << this->max_blocks_per_sequence << std::endl;
|
||||
ss << "workspace: " << this->workspace << std::endl;
|
||||
ss << "logn_scaling_ptr: " << this->logn_scaling_ptr << std::endl;
|
||||
@ -179,6 +184,8 @@ public:
|
||||
ss << "encoder_input_lengths: " << this->encoder_input_lengths << std::endl;
|
||||
ss << "num_encoder_tokens: " << this->num_encoder_tokens << std::endl;
|
||||
ss << "softmaxStatsPtr: " << this->softmaxStatsPtr << std::endl;
|
||||
ss << "k_ptr: " << this->k_ptr << std::endl;
|
||||
ss << "v_ptr: " << this->v_ptr << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
@ -193,12 +193,10 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_104_tma_ws_sm90(Fused_
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_160_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_192_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_64_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_40_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_48_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
@ -210,13 +208,12 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_104_tma_ws_sm90
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_160_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_192_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_64_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
@ -450,14 +447,8 @@ extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm8
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin[];
|
||||
@ -666,9 +657,6 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softca
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_256_softcapping_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm89_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_80
|
||||
@ -880,9 +868,6 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softca
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_256_softcapping_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm80_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_86
|
||||
@ -1092,14 +1077,10 @@ extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softca
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_256_softcapping_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm86_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_100
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
@ -1240,8 +1221,7 @@ extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_256_softcapping_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_softcapping_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
|
||||
extern void run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@@ -1271,13 +1251,11 @@ extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm120_nl(Fused
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
extern void run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_32_sm120_nl_tiled(Fused_multihead_attention_params_v2& params, const Launch_params& launch_params, cudaStream_t stream);
@@ -1426,14 +1404,8 @@ extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89_cu_
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin_len;
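The input layout of each generated kernel is encoded in the _S_<layout>_ infix of its symbol: qkv is the packed-QKV layout, q_kv is Q plus contiguous KV, q_paged_kv is Q plus paged KV, and q_k_v is the new separate-Q/K/V layout used for context MLA; the trailing 192x128 / 576x512 pairs give the QK and V head sizes. The helper below is illustrative only (it is not part of the generated code) and simply shows how the layout tag can be read off a symbol name:

#include <cstdio>
#include <string>

// Illustrative only: map the "_S_<layout>_" infix of a generated kernel symbol
// to a human-readable layout name. The four infixes are mutually distinct.
static char const* layoutOf(std::string const& symbol)
{
    if (symbol.find("_S_q_paged_kv_") != std::string::npos) return "Q + paged KV";
    if (symbol.find("_S_q_k_v_") != std::string::npos) return "separate Q/K/V";
    if (symbol.find("_S_q_kv_") != std::string::npos) return "Q + contiguous KV";
    if (symbol.find("_S_qkv_") != std::string::npos) return "packed QKV";
    return "unknown";
}

int main()
{
    // Prints: separate Q/K/V
    std::printf("%s\n", layoutOf("run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled"));
    return 0;
}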
@@ -1820,7 +1792,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 0, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 0, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 0, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_qkv_256_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 1, 0, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_qkv_256_softcapping_tma_ws_sm90},
|
||||
@@ -1833,7 +1804,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90_kernel", 73984, 384, 64, 0, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_causal_tma_ws_sm90_kernel", 73984, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_sliding_or_chunked_causal_tma_ws_sm90_kernel", 73984, 384, 64, 2, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_q_paged_kv_32_tma_ws_sm90},
|
||||
@@ -1873,11 +1843,11 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_custom_mask_tma_ws_sm90_kernel", 196864, 384, 64, 3, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 2, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 1, 2, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 164096, 384, 64, 2, 2, false, true, true, true, false, false, true, false, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 1, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 256, 256, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_tma_ws_sm90_kernel", 196864, 384, 64, 2, 2, false, true, true, true, false, false, true, false, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_256_softcapping_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, false, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 0, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_custom_mask_softmax_tma_ws_sm90_kernel", 73984, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_256_S_q_kv_32_softmax_tma_ws_sm90},
|
||||
@@ -1887,7 +1857,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 72, 72, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_72_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 0, 1, false, true, true, true, false, false, false, true, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 128, 128, 0, 0, 0, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_custom_mask_softmax_tma_ws_sm90_kernel", 164096, 384, 64, 3, 1, false, true, true, true, false, false, false, true, nullptr},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 1, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_kv_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 0, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_softmax_tma_ws_sm90_kernel", 213248, 384, 64, 1, 3, false, true, true, true, false, false, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_softmax_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 32, 32, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_32_causal_alibi_tma_ws_sm90_kernel", 73984, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_32_alibi_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 40, 40, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_40_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_40_alibi_tma_ws_sm90},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 256, 48, 48, 0, 0, 0, kSM_90, nullptr, 0, "fmha_v2_flash_attention_bf16_64_256_S_qkv_48_causal_alibi_tma_ws_sm90_kernel", 147712, 384, 64, 1, 0, false, true, true, true, true, false, false, false, run_fmha_v2_flash_attention_bf16_64_256_S_qkv_48_alibi_tma_ws_sm90},
|
||||
@@ -2522,18 +2493,12 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sliding_or_chunked_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 2, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_custom_mask_output_bf16_sm89_kernel_nl", 65536, 128, 64, 3, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm89_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm89_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm89_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 80, 80, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_FP16, 0, 64, 32, 80, 80, 64, 32, 32, kSM_89, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_cu_cubin_len, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89_kernel_nl", 32768, 128, 64, 0, 0, false, true, false, true, true, false, false, true, nullptr},
|
||||
@@ -3144,9 +3109,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm89_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm89_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm89_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm89_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm89_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm89_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_89, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm89_nl_tiled},
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_80
|
||||
@@ -3756,9 +3718,6 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm80_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm80_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm80_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm80_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm80_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm80_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_80, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm80_nl_tiled},
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_86
|
||||
@@ -4368,14 +4327,11 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_32_S_q_paged_kv_128_softcapping_sm86_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm86_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_fp16_fp32_64_16_S_q_paged_kv_256_softcapping_sm86_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm86_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm86_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm86_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_86, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm86_nl_tiled},
|
||||
#endif
|
||||
|
||||
#ifndef EXCLUDE_SM_100
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm100_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm100_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_100, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm100_nl_tiled},
|
||||
#endif
|
||||
|
||||
@@ -4784,8 +4740,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 32, 128, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 32768, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_128_softcapping_sm120_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 1, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 16, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_sliding_or_chunked_causal_softcapping_sm120_kernel_nl", 49152, 128, 64, 2, 2, false, true, false, true, true, false, true, true, run_fmha_v2_flash_attention_bf16_64_16_S_q_paged_kv_256_softcapping_sm120_nl},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_qkv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_kernel_nl_tiled", 81920, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 128, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 81920, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_128_S_q_k_v_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 49152, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_bf16_64_64_S_q_paged_kv_576x512_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 128, 128, 32, 32, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_kernel_nl", 12288, 128, 128, 0, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 128, 128, 32, 32, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_causal_sm120_kernel_nl", 12288, 128, 128, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm120_nl},
|
||||
@@ -4874,8 +4830,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_causal_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sliding_or_chunked_causal_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 32, 256, 256, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_custom_mask_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_E4M3, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sliding_or_chunked_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 2, 0, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm120_nl},
|
||||
@@ -4883,8 +4839,8 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 1, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sliding_or_chunked_causal_output_bf16_sm120_kernel_nl", 65536, 128, 64, 2, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 32, 192, 192, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_custom_mask_output_bf16_sm120_kernel_nl", 65536, 128, 64, 3, 2, false, true, false, true, true, false, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm120_nl},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 0, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 192, 128, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_causal_output_bf16_sm120_kernel_nl_tiled", 32768, 128, 64, 1, 3, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_k_v_192x128_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_E4M3, DATA_TYPE_BF16, 0, 64, 64, 576, 512, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_kernel_nl_tiled", 65536, 128, 64, 0, 2, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 16, 16, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_causal_sm120_kernel_nl_tiled", 16384, 128, 128, 1, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 0, 128, 128, 16, 16, 0, 0, 0, kSM_120, nullptr, 0, "fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sliding_or_chunked_causal_sm120_kernel_nl_tiled", 16384, 128, 128, 2, 0, false, true, false, true, true, true, false, true, run_fmha_v2_flash_attention_fp16_fp32_128_128_S_qkv_16_sm120_nl_tiled},
|
||||
[Git LFS pointer diffs for the regenerated FMHA cubin binaries: rebuilt cubins receive updated oid/size entries, and pointers for removed cubins are deleted. The file names are not preserved in this view.]
@ -179,6 +179,31 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams)
// Thus, v_stride_in_bytes always equals to k_stride_in_bytes so far.
mKernelParams.v_stride_in_bytes = mKernelParams.k_stride_in_bytes;
}
else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
{
// Separate QKV input layout, [total_kv_seqlen, H_KV, D] + [total_kv_seqlen, H_KV, DV]
TLLM_CHECK_WITH_INFO(runnerParams.kPtr != nullptr && runnerParams.vPtr != nullptr,
"SEPARATE_Q_K_V requires valid K and V pointers.");
mKernelParams.k_ptr = runnerParams.kPtr;
mKernelParams.v_ptr = runnerParams.vPtr;
// Tensor K is contiguous.
mKernelParams.k_stride_in_bytes
= get_size_in_bytes(mFixedParams.numKvHeads * mFixedParams.headSize, mFixedParams.dataType);
if (mFixedParams.headSizeQkNope > 0 && mFixedParams.dataType != DATA_TYPE_E4M3)
{
// Non-FP8 context MLA: tensor V is not contiguous. The token stride is numKvHeads * (headSizeQkNope +
// headSizeV).
mKernelParams.v_stride_in_bytes = get_size_in_bytes(
mFixedParams.numKvHeads * (mFixedParams.headSizeQkNope + mFixedParams.headSizeV),
mFixedParams.dataType);
}
else
{
// Tensor V is contiguous for other cases.
mKernelParams.v_stride_in_bytes
= get_size_in_bytes(mFixedParams.numKvHeads * mFixedParams.headSizeV, mFixedParams.dataType);
}
}
}

mKernelParams.o_ptr = runnerParams.outputPtr;
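For reference, a minimal standalone sketch of the stride arithmetic above (illustrative only: the concrete numKvHeads/headSize values are assumed example numbers, not taken from this diff, and bf16 is assumed for the element size):

#include <cstdio>

// Token strides in bytes for the SEPARATE_Q_K_V layout, mirroring setupKernelParams.
int main()
{
    int const numKvHeads = 128;     // assumed example value
    int const headSize = 192;       // D = d_nope + d_rope
    int const headSizeQkNope = 128; // d_nope
    int const headSizeV = 128;      // DV
    int const bytesPerElem = 2;     // bf16

    // K is contiguous: one token spans [H_KV, D].
    int const kStride = numKvHeads * headSize * bytesPerElem;
    // Non-FP8 context MLA: V rows are interleaved with the nope part, so the token
    // stride covers (d_nope + d_v) per head even though only d_v is consumed.
    int const vStrideMla = numKvHeads * (headSizeQkNope + headSizeV) * bytesPerElem;
    // Otherwise V is contiguous: one token spans [H_KV, DV].
    int const vStrideContiguous = numKvHeads * headSizeV * bytesPerElem;

    printf("k_stride=%d v_stride_mla=%d v_stride_contiguous=%d\n", kStride, vStrideMla, vStrideContiguous);
    return 0;
}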
@ -464,13 +489,12 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
{
bool isHopperBF16ContextMLA = (mFixedParams.headSize == mFixedParams.headSizeV + 64) && isSm90
&& mFixedParams.dataType == DATA_TYPE_BF16 && mFixedParams.headSizeV == 128;
// TODO: add support for separate QKV input layout
mLaunchParams.supportReturnSoftmaxStats = (runnerParams.softmaxStatsPtr != nullptr
&& mLaunchParams.flash_attention && mLaunchParams.warp_specialization
&& ((!isHopperBF16ContextMLA
&& mLaunchParams.attention_input_layout == AttentionInputLayout::Q_CONTIGUOUS_KV)
|| (isHopperBF16ContextMLA
&& (mLaunchParams.attention_input_layout == AttentionInputLayout::Q_PAGED_KV))));
&& (mLaunchParams.attention_input_layout == AttentionInputLayout::SEPARATE_Q_K_V))));
}
}

@ -623,6 +647,12 @@ void FusedMHARunnerV2::setTmaDescriptors(MHARunnerParams runnerParams)
k_ptr = reinterpret_cast<char const*>(mKernelParams.kv_ptr);
v_ptr = k_ptr + h_kv * d_in_bytes;
}
else if (layout == AttentionInputLayout::SEPARATE_Q_K_V)
{
// Layout: [total_kv_seqlen, H_KV, D] + [total_kv_seqlen, H_KV, DV]
k_ptr = reinterpret_cast<char const*>(mKernelParams.k_ptr);
v_ptr = reinterpret_cast<char const*>(mKernelParams.v_ptr);
}

Multiple_tma_descriptor<3> kv_tma_descriptor;
// K

@ -81,7 +81,10 @@ enum class AttentionInputLayout
// Q has contiguous [B, S, H, D] layout, while paged KV has [B, 2, Max_blocks_per_seq] layout
// that contains paged block indices. The indices indicate the block offset to the pool ptr in
// global memory
Q_PAGED_KV
Q_PAGED_KV,
// Q has contiguous [B, S, H, D] layout, while K has contiguous [B, S, H_kv, D] layout, and V has
// contiguous [B, S, H_kv, D_v] layout. Only used for context MLA now.
SEPARATE_Q_K_V,
};

////////////////////////////////////////////////////////////////////////////////////////////////////
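To make the SEPARATE_Q_K_V shapes concrete, a small sketch computing the number of elements in each of the three buffers (the dimension values are assumed example numbers, not values from this diff):

#include <cstdio>

// Element counts for the separate Q/K/V buffers described by the enum comment above.
int main()
{
    long const B = 8, S = 1024;         // assumed batch size and sequence length
    long const H = 128, H_kv = 128;     // assumed Q and KV head counts
    long const D = 192, D_v = 128;      // QK head dim (nope + rope) and V head dim

    long const qElems = B * S * H * D;      // Q: [B, S, H, D]
    long const kElems = B * S * H_kv * D;   // K: [B, S, H_kv, D]
    long const vElems = B * S * H_kv * D_v; // V: [B, S, H_kv, D_v]

    printf("q=%ld k=%ld v=%ld elements\n", qElems, kElems, vElems);
    return 0;
}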
@ -115,6 +118,8 @@ struct MHARunnerFixedParams
int headSize;
// The head size of V.
int headSizeV = 0;
// The head size of Q/K non-RoPE part, only used for MLA now.
int headSizeQkNope = 0;
// The scaling applied to bmm1_scale.
float qScaling;
// The attention logit softcapping scale.
@ -166,6 +171,7 @@ struct MHARunnerFixedParams
case AttentionInputLayout::PACKED_QKV: output += "packed_qkv"; break;
case AttentionInputLayout::Q_CONTIGUOUS_KV: output += "q_contiguous_kv"; break;
case AttentionInputLayout::Q_PAGED_KV: output += "q_paged_kv"; break;
case AttentionInputLayout::SEPARATE_Q_K_V: output += "separate_q_k_v"; break;
default: output += std::to_string(static_cast<int>(attentionInputLayout)) + " (unknown)"; break;
}

@ -255,6 +261,10 @@ struct MHARunnerParams
void const* qPtr;
// The contiguous Kv buffer ptr;
void const* kvPtr;
// The K buffer ptr (for separate K input).
void const* kPtr;
// The V buffer ptr (for separate V input).
void const* vPtr;
// The paged kv cache array.
KVBlockArray pagedKvCache;
// The output buffer ptr.

@ -36,6 +36,10 @@ QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)
{
return QkvLayout::PagedKv;
}
else if (layout == AttentionInputLayout::SEPARATE_Q_K_V)
{
return QkvLayout::SeparateQkv;
}
TLLM_CHECK_WITH_INFO(false, "Unexpected AttentionInputLayout");
return QkvLayout::SeparateQkv;
}
@ -148,6 +152,10 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
maxBlocksPerSeq = pagedKvCache.mMaxBlocksPerSeq;
numTokensPerBlock = pagedKvCache.mTokensPerBlock;
}
else if (mFixedParams.attentionInputLayout == AttentionInputLayout::SEPARATE_Q_K_V)
{
qkvLayout = kernels::QkvLayout::SeparateQkv;
}

TllmGenFmhaRunnerParams tllmRunnerParams;
memset(&tllmRunnerParams, 0, sizeof(tllmRunnerParams));
@ -161,8 +169,8 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
tllmRunnerParams.mMultiCtasKvMode = false;

tllmRunnerParams.qPtr = runnerParams.qPtr;
tllmRunnerParams.kPtr = nullptr;
tllmRunnerParams.vPtr = nullptr;
tllmRunnerParams.kPtr = runnerParams.kPtr;
tllmRunnerParams.vPtr = runnerParams.vPtr;
tllmRunnerParams.kvPtr = kvPoolPtr;
tllmRunnerParams.qkvPtr = runnerParams.qkvPtr;
tllmRunnerParams.attentionSinksPtr = runnerParams.attentionSinksPtr;
@ -181,6 +189,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
// Assume same headDim for Qk and V here.
tllmRunnerParams.mHeadDimQk = mFixedParams.headSize;
tllmRunnerParams.mHeadDimV = mFixedParams.headSizeV;
tllmRunnerParams.mHeadDimQkNope = mFixedParams.headSizeQkNope;
tllmRunnerParams.mNumHeadsQ = mFixedParams.numQHeads;
tllmRunnerParams.mNumHeadsKv = mFixedParams.numKvHeads;
tllmRunnerParams.mNumHeadsQPerKv = tllmRunnerParams.mNumHeadsQ / tllmRunnerParams.mNumHeadsKv;
@ -202,6 +211,7 @@ void FmhaDispatcher::run(MHARunnerParams runnerParams)
// For mla chunked prefill
tllmRunnerParams.softmaxStatsPtr = reinterpret_cast<float2*>(runnerParams.softmaxStatsPtr);
tllmRunnerParams.stream = runnerParams.stream;

mTllmGenFMHARunner->run(tllmRunnerParams);
}
else

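As a compact summary of the layout translation used by the dispatcher above, here is a self-contained sketch; the enum values are local stand-ins that mirror the names in the diff, and the PackedQkv/ContiguousKv cases are assumed for completeness rather than taken from this change:

#include <stdexcept>

enum class AttentionInputLayout { PACKED_QKV, Q_CONTIGUOUS_KV, Q_PAGED_KV, SEPARATE_Q_K_V };
enum class QkvLayout { PackedQkv, ContiguousKv, PagedKv, SeparateQkv };

// Mirrors the mapping performed before filling the TRTLLM-Gen runner parameters.
QkvLayout toQkvLayout(AttentionInputLayout layout)
{
    switch (layout)
    {
    case AttentionInputLayout::PACKED_QKV: return QkvLayout::PackedQkv;
    case AttentionInputLayout::Q_CONTIGUOUS_KV: return QkvLayout::ContiguousKv;
    case AttentionInputLayout::Q_PAGED_KV: return QkvLayout::PagedKv;
    case AttentionInputLayout::SEPARATE_Q_K_V: return QkvLayout::SeparateQkv;
    }
    throw std::runtime_error("Unexpected AttentionInputLayout");
}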
@ -71,26 +71,6 @@ struct loadChunkedKVKernelTraits
|
||||
static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct setChunkedKVKernelTraits
|
||||
{
|
||||
using VecT = uint4;
|
||||
static constexpr int kQKNopeSize = 128;
|
||||
static constexpr int kVHeadSize = 128;
|
||||
static_assert(kQKNopeSize == kVHeadSize);
|
||||
static constexpr int kRopeSize = 64;
|
||||
static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
|
||||
static constexpr int kBytesPerElem = sizeof(T);
|
||||
static constexpr int kBytesPerLoad = 16;
|
||||
static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
|
||||
static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
|
||||
"kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
|
||||
static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
|
||||
static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
|
||||
static constexpr int kCpTokenPerBlock = 16;
|
||||
static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
|
||||
};
|
||||
|
||||
template <typename SrcType, int NUM>
|
||||
inline __device__ void quantCopy(
|
||||
__nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
|
||||
@ -311,76 +291,6 @@ __global__ void loadChunkedKVCacheForMLAKernel(T* output_kv_ptr, T* output_k_pe_
|
||||
}
|
||||
}
|
||||
|
||||
// in the most of cases, chunk_size = max_seq_len
|
||||
// output_kv {B, 2, ceil(max_seq_len / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
|
||||
// zero
|
||||
// kv {token_size = B*chunked_unit_size, 2, H=128, uncompressed_h=128}, k_pe {token_size = B*chunked_unit_size, h=1,
|
||||
// rope_h}
|
||||
// cu_seq_lens {batch + 1}, fake cu_seq_len, for chunked prefill is {0, chunk_size, chunk_size * 2 ....}
|
||||
template <typename T>
|
||||
__global__ void setChunkedKVCacheForMLAKernel(T* output_kv, T const* kv, T const* k_pe, int const max_seq_len,
|
||||
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
|
||||
int kv_cache_tokens_per_block)
|
||||
{
|
||||
using KT = setChunkedKVKernelTraits<T>;
|
||||
int const batch_idx = static_cast<int>(blockIdx.y);
|
||||
int const head_idx = static_cast<int>(blockIdx.z);
|
||||
int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead);
|
||||
int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
|
||||
bool const is_valid_kv = head_dim_idx < KT::kQKNopeSize;
|
||||
|
||||
int64_t const global_token_offset = cu_seq_lens[batch_idx];
|
||||
int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx];
|
||||
int const kv_cache_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
|
||||
int const kv_cache_block_size = num_heads * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size);
|
||||
int64_t const offset_for_kv_in_mem_pool = kv_cache_block_num * kv_cache_block_size;
|
||||
int64_t const kv_offset = num_heads * uncompressed_head_size;
|
||||
size_t const seq_len_loop_end = cache_kv_len;
|
||||
for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock;
|
||||
local_token_idx < seq_len_loop_end; local_token_idx += gridDim.x * KT::kCpTokenPerBlock)
|
||||
{
|
||||
if (local_token_idx >= cache_kv_len)
|
||||
{
|
||||
break;
|
||||
}
|
||||
if (is_valid_kv)
|
||||
{
|
||||
|
||||
int64_t ld_kv_global_offset
|
||||
= int64_t(global_token_offset + local_token_idx) * 2 * num_heads * uncompressed_head_size
|
||||
+ head_idx * uncompressed_head_size;
|
||||
int64_t ld_kv_local_offset = head_dim_vec_idx;
|
||||
auto k_data = (reinterpret_cast<typename KT::VecT const*>(kv + ld_kv_global_offset))[ld_kv_local_offset];
|
||||
auto v_data = (reinterpret_cast<typename KT::VecT const*>(
|
||||
kv + kv_offset + ld_kv_global_offset))[ld_kv_local_offset];
|
||||
|
||||
int64_t st_k_global_offset = int64_t(batch_idx) * 2 * offset_for_kv_in_mem_pool
|
||||
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
|
||||
+ head_idx * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
|
||||
+ (local_token_idx % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size);
|
||||
int64_t st_v_global_offset = st_k_global_offset + offset_for_kv_in_mem_pool;
|
||||
int64_t st_k_local_offset = head_dim_vec_idx;
|
||||
int64_t st_v_local_offset = head_dim_vec_idx;
|
||||
(reinterpret_cast<typename KT::VecT*>(output_kv + st_k_global_offset))[st_k_local_offset] = k_data;
|
||||
(reinterpret_cast<typename KT::VecT*>(output_kv + st_v_global_offset))[st_v_local_offset] = v_data;
|
||||
}
|
||||
else
|
||||
{
|
||||
// rope h = 1
|
||||
int64_t ld_rope_global_offset = int64_t(global_token_offset + local_token_idx) * rope_size;
|
||||
int64_t ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead;
|
||||
auto rope_data
|
||||
= (reinterpret_cast<typename KT::VecT const*>(k_pe + ld_rope_global_offset))[ld_rope_local_offset];
|
||||
int64_t st_rope_global_offset = int64_t(batch_idx) * 2 * offset_for_kv_in_mem_pool
|
||||
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
|
||||
+ head_idx * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
|
||||
+ (local_token_idx % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size);
|
||||
int64_t st_rope_local_offset = head_dim_vec_idx;
|
||||
(reinterpret_cast<typename KT::VecT*>(output_kv + st_rope_global_offset))[st_rope_local_offset] = rope_data;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace tensorrt_llm
|
||||
@ -427,26 +337,6 @@ void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray c
|
||||
kv_cache, cu_ctx_chunked_len, chunked_size, chunked_idx, kv_scale_quant_orig_ptr);
|
||||
}
|
||||
|
||||
// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
|
||||
// zero
|
||||
// kv {total_token, 2, H, uncompressed_h=128} 0 for k and 1 for v, k_pe {total_token, h=1, rope_h}
|
||||
// input kv and k_pe can be cached tokens or uncached tokens
|
||||
template <typename T>
|
||||
void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len,
|
||||
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
|
||||
int const kv_cache_tokens_per_block, cudaStream_t stream)
|
||||
{
|
||||
using KT = setChunkedKVKernelTraits<T>;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
uncompressed_head_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
|
||||
TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % KT::kCpTokenPerBlock == 0,
|
||||
"kv_cache_tokens_per_block should be multiple of %d", KT::kCpTokenPerBlock);
|
||||
|
||||
dim3 grid(tensorrt_llm::common::divUp(max_seq_len, KT::kCpTokenPerBlock), batch_size, num_heads);
|
||||
setChunkedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>(output_kv, kv, k_pe, max_seq_len, num_heads,
|
||||
uncompressed_head_size, rope_size, cu_seq_lens, kv_cache_tokens_per_block);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(T) \
|
||||
template void invokeMergeAttnWithSoftmax<T>(T * merged_attn, float* merged_softmax_stats, T const* pre_attn, \
|
||||
float const* pre_softmax_stats, T const* curr_attn, float const* curr_softmax_stats, int const batch_size, \
|
||||
@ -457,10 +347,7 @@ void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const b
|
||||
int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
|
||||
template void invokeMLALoadChunkedKV<T, __nv_fp8_e4m3>(T * output_kv_ptr, T * output_k_pe_ptr, \
|
||||
KVBlockArray const& kv_cache, int const num_contexts, int64_t const* cu_ctx_chunked_len, int lora_size, \
|
||||
int rope_size, int chunked_size, int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
|
||||
template void invokeMLASetChunkedKV<T>(T * output_kv, T const* kv, T const* k_pe, int const batch_size, \
|
||||
int const max_seq_len, int const num_heads, int uncompressed_head_size, int rope_size, \
|
||||
int64_t const* cu_seq_lens, int const kv_cache_tokens_per_block, cudaStream_t stream);
|
||||
int rope_size, int chunked_size, int chunked_idx, float const* kv_scale_quant_orig_ptr, cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(half);
|
||||
INSTANTIATE_MLA_CHUNKED_PREFILL_KERNEL(float);
|
||||
|
||||
@ -37,14 +37,5 @@ template <typename T, typename TCache>
|
||||
void invokeMLALoadChunkedKV(T* output_kv_ptr, T* output_k_pe_ptr, KVBlockArray const& kv_cache, int const num_contexts,
|
||||
int64_t const* cu_ctx_chunked_len, int lora_size, int rope_size, int chunked_size, int chunked_idx,
|
||||
float const* kv_scale_quant_orig_ptr, cudaStream_t stream);
|
||||
|
||||
// output_kv {B, 2, ceil(chunked_size / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}, padding with
|
||||
// zero
|
||||
// kv {total_token, 2, H, uncompressed_h=128} 0 for k and 1 for v, k_pe {total_token, h=1, rope_h}
|
||||
// input kv and k_pe can be cached tokens or uncached tokens
|
||||
template <typename T>
|
||||
void invokeMLASetChunkedKV(T* output_kv, T const* kv, T const* k_pe, int const batch_size, int const max_seq_len,
|
||||
int const num_heads, int uncompressed_head_size, int rope_size, int64_t const* cu_seq_lens,
|
||||
int const kv_cache_tokens_per_block, cudaStream_t stream);
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -121,27 +121,6 @@ struct loadPagedKVKernelTraits
|
||||
static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct setPagedKVKernelTraits
|
||||
{
|
||||
static constexpr int kQKNopeSize = 128;
|
||||
static constexpr int kVHeadSize = 128;
|
||||
static_assert(kQKNopeSize == kVHeadSize);
|
||||
static constexpr int kRopeSize = 64;
|
||||
static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
|
||||
using VecT = typename VecType<T>::Type;
|
||||
static constexpr int kBytesPerElem = sizeof(T);
|
||||
static constexpr int kBytesPerLoad = 16;
|
||||
static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
|
||||
static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
|
||||
"kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");
|
||||
static constexpr int kNumHeads = 128;
|
||||
static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
|
||||
static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
|
||||
static constexpr int kCpTokenPerBlock = 16;
|
||||
static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
|
||||
};
|
||||
|
||||
template <typename SrcType, int NUM>
|
||||
inline __device__ void quantCopy(
|
||||
__nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
|
||||
@ -205,11 +184,10 @@ inline __device__ void dequantCopy(
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
|
||||
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const* fuse_buf, KVCacheBuffer kv_cache,
|
||||
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* q_ptr, T* k_ptr, T const* fuse_buf, KVCacheBuffer kv_cache,
|
||||
float2 const* cos_sin_cache, size_t head_num, int head_size, int c_k, int* cu_q_seqlens,
|
||||
int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type, float* bmm1_scale,
|
||||
float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_kv, float const* dequant_scale_q,
|
||||
float const* dequant_scale_kv, float host_bmm1_scale)
|
||||
int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type,
|
||||
float const* quant_scale_kv)
|
||||
{
|
||||
|
||||
// Constants.
|
||||
@ -232,32 +210,6 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const*
|
||||
size_t const batch_idx = blockIdx.y;
|
||||
size_t const head_idx = blockIdx.z;
|
||||
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
|
||||
{
|
||||
|
||||
// Calculate bmm scale for FP8 MLA
|
||||
if (cache_type == KvCacheDataType::FP8)
|
||||
{
|
||||
float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
|
||||
float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
|
||||
float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
|
||||
if (bmm1_scale)
|
||||
{
|
||||
// The scale prepared for log2 optimization.
|
||||
constexpr float kLog2e = 1.4426950408889634074f;
|
||||
// The scale after fmha bmm1.
|
||||
float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
|
||||
bmm1_scale[0] = bmm1_scale_val;
|
||||
bmm1_scale[1] = bmm1_scale_val * kLog2e;
|
||||
}
|
||||
if (bmm2_scale)
|
||||
{
|
||||
// The scale after fmha bmm2.
|
||||
bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (head_idx < head_num)
|
||||
{
|
||||
size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
|
||||
@ -287,11 +239,10 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const*
|
||||
|
||||
VecT q, k;
|
||||
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
|
||||
auto const src_q_global_offset
|
||||
= static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
|
||||
auto const src_q_global_offset = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
|
||||
+ (head_size + ROPE_DIM) * head_idx + head_size;
|
||||
|
||||
q = *reinterpret_cast<VecT const*>(&qkv_output[src_q_global_offset + head_dim_idx]);
|
||||
q = *reinterpret_cast<VecT const*>(&q_ptr[src_q_global_offset + head_dim_idx]);
|
||||
k = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
|
||||
|
||||
// Pack two elements into one for gptj rotary embedding.
|
||||
@ -322,14 +273,12 @@ __global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const*
|
||||
else
|
||||
reinterpret_cast<VecT*>(kDst)[inBlockIdx] = k;
|
||||
}
|
||||
auto const dst_q_idx
|
||||
= static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
|
||||
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
|
||||
+ head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
|
||||
auto const dst_k_idx
|
||||
= static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
|
||||
+ head_num * (head_size + ROPE_DIM) + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
|
||||
reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = q;
|
||||
reinterpret_cast<VecT*>(qkv_output)[dst_k_idx / ELTS_PER_VEC] = k;
|
||||
auto const dst_k_idx = static_cast<size_t>(global_token_idx) * head_num * (head_size + ROPE_DIM)
|
||||
+ head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
|
||||
reinterpret_cast<VecT*>(q_ptr)[dst_q_idx / ELTS_PER_VEC] = q;
|
||||
reinterpret_cast<VecT*>(k_ptr)[dst_k_idx / ELTS_PER_VEC] = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -712,79 +661,6 @@ __global__ void loadPagedKVCacheForMLAKernel(T* compressed_kv_ptr, T* k_pe_ptr,
|
||||
}
|
||||
}
|
||||
|
||||
// k {total_token, h, d}, v {total_token, h, d}, k_pe {total_token, h=1, d_rope}
|
||||
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}
|
||||
template <typename T>
|
||||
__global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr,
|
||||
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
|
||||
int kv_cache_tokens_per_block, int64_t kv_token_stride)
|
||||
{
|
||||
using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
|
||||
int const batch_idx = static_cast<int>(blockIdx.y);
|
||||
int const head_idx = static_cast<int>(blockIdx.z);
|
||||
int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead);
|
||||
int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
|
||||
bool const is_valid_v = head_dim_idx < KT::kVHeadSize;
|
||||
|
||||
size_t const seq_len_loop_end
|
||||
= (max_input_seq_len + KT::kCpTokenPerBlock - 1) / KT::kCpTokenPerBlock * KT::kCpTokenPerBlock;
|
||||
size_t const kv_cache_block_size = num_heads * kv_cache_tokens_per_block * (kv_dim + rope_dim);
|
||||
size_t const kv_cache_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
|
||||
int64_t const global_token_offset = cu_seq_lens[batch_idx];
|
||||
int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx];
|
||||
|
||||
for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock;
|
||||
local_token_idx < seq_len_loop_end; local_token_idx += KT::kCpTokenPerBlock * gridDim.x)
|
||||
{
|
||||
int token_idx_in_kv_cache = local_token_idx;
|
||||
bool const valid_token = token_idx_in_kv_cache < cache_kv_len;
|
||||
if (valid_token)
|
||||
{
|
||||
// copy k and v
|
||||
if (is_valid_v)
|
||||
{
|
||||
int ld_kv_global_offset = (global_token_offset + local_token_idx) * kv_token_stride + head_idx * kv_dim;
|
||||
int ld_kv_local_offset = head_dim_vec_idx;
|
||||
auto k_data
|
||||
= (reinterpret_cast<typename KT::VecT const*>(k_ptr + ld_kv_global_offset))[ld_kv_local_offset];
|
||||
auto v_data
|
||||
= (reinterpret_cast<typename KT::VecT const*>(v_ptr + ld_kv_global_offset))[ld_kv_local_offset];
|
||||
// {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
|
||||
int st_k_global_offset = batch_idx * 2 * kv_cache_block_num * kv_cache_block_size
|
||||
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
|
||||
+ head_idx * kv_cache_tokens_per_block * (kv_dim + rope_dim)
|
||||
+ (local_token_idx % kv_cache_tokens_per_block) * (kv_dim + rope_dim);
|
||||
// {b, 1, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
|
||||
int st_v_global_offset = st_k_global_offset + kv_cache_block_num * kv_cache_block_size;
|
||||
int st_k_local_offset = head_dim_vec_idx;
|
||||
int st_v_local_offset = head_dim_vec_idx;
|
||||
(reinterpret_cast<typename KT::VecT*>(output + st_k_global_offset))[st_k_local_offset] = k_data;
|
||||
(reinterpret_cast<typename KT::VecT*>(output + st_v_global_offset))[st_v_local_offset] = v_data;
|
||||
}
|
||||
// copy k_pe, only 1 head
|
||||
else
|
||||
{
|
||||
int ld_rope_global_offset = (global_token_offset + local_token_idx) * rope_dim;
|
||||
int ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead;
|
||||
auto rope_data = (reinterpret_cast<typename KT::VecT const*>(
|
||||
k_pe_ptr + ld_rope_global_offset))[ld_rope_local_offset];
|
||||
// {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
|
||||
int st_rope_global_offset = batch_idx * 2 * kv_cache_block_num * kv_cache_block_size
|
||||
+ local_token_idx / kv_cache_tokens_per_block * kv_cache_block_size
|
||||
+ head_idx * kv_cache_tokens_per_block * (kv_dim + rope_dim)
|
||||
+ (local_token_idx % kv_cache_tokens_per_block) * (kv_dim + rope_dim);
|
||||
int st_rope_local_offset = head_dim_vec_idx;
|
||||
(reinterpret_cast<typename KT::VecT*>(output + st_rope_global_offset))[st_rope_local_offset]
|
||||
= rope_data;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// q {total_uncached_tokens, h, d_nope + d_rope}
|
||||
// latent_cache {total_uncached_tokens, d_k + d_rope}
|
||||
template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
|
||||
@ -941,58 +817,152 @@ __global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T*
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int QK_NOPE_HEAD_DIM, int QK_ROPE_HEAD_DIM, int V_HEAD_DIM>
|
||||
__global__ void quantizeCopyInputToFp8Kernel(T const* q_buf, __nv_fp8_e4m3* quant_q_buf, T const* k_buf,
|
||||
__nv_fp8_e4m3* quant_k_buf, T const* v_buf, __nv_fp8_e4m3* quant_v_buf, int total_q_len, int total_kv_len,
|
||||
float const* quant_scale_qkv_ptr, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
|
||||
float const* dequant_scale_q, float const* dequant_scale_kv, float host_bmm1_scale)
|
||||
{
|
||||
// Constants.
|
||||
using VecT = typename VecType<T>::Type;
|
||||
constexpr auto BYTES_PER_ELT = sizeof(T);
|
||||
constexpr auto BYTES_PER_LOAD = 16;
|
||||
constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
|
||||
constexpr auto QK_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
|
||||
static_assert(
|
||||
(QK_HEAD_DIM * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "QK head size needs to be multiple of 16 bytes.");
|
||||
static_assert((V_HEAD_DIM * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "V head size needs to be multiple of 16 bytes.");
|
||||
constexpr auto QK_VECS_PER_HEAD = QK_HEAD_DIM * BYTES_PER_ELT / BYTES_PER_LOAD;
|
||||
constexpr auto V_VECS_PER_HEAD = V_HEAD_DIM * BYTES_PER_ELT / BYTES_PER_LOAD;
|
||||
static_assert(BLOCK_SIZE % QK_VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
|
||||
static_assert(BLOCK_SIZE % V_VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
|
||||
constexpr auto QK_TOKENS_PER_BLOCK = BLOCK_SIZE / QK_VECS_PER_HEAD;
|
||||
constexpr auto V_TOKENS_PER_BLOCK = BLOCK_SIZE / V_VECS_PER_HEAD;
|
||||
|
||||
size_t const head_idx = blockIdx.z;
|
||||
size_t const head_num = gridDim.z;
|
||||
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
|
||||
{
|
||||
// Calculate bmm scale for FP8 MLA
|
||||
float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
|
||||
float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
|
||||
float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
|
||||
if (bmm1_scale)
|
||||
{
|
||||
// The scale prepared for log2 optimization.
|
||||
constexpr float kLog2e = 1.4426950408889634074f;
|
||||
// The scale after fmha bmm1.
|
||||
float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
|
||||
bmm1_scale[0] = bmm1_scale_val;
|
||||
bmm1_scale[1] = bmm1_scale_val * kLog2e;
|
||||
}
|
||||
if (bmm2_scale)
|
||||
{
|
||||
// The scale after fmha bmm2.
|
||||
bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
|
||||
}
|
||||
}
|
||||
|
||||
size_t const qk_head_dim_vec_idx = (threadIdx.x % QK_VECS_PER_HEAD);
|
||||
size_t const v_head_dim_vec_idx = (threadIdx.x % V_VECS_PER_HEAD);
|
||||
size_t const qk_head_dim_idx = qk_head_dim_vec_idx * ELTS_PER_VEC;
|
||||
size_t const v_head_dim_idx = v_head_dim_vec_idx * ELTS_PER_VEC;
|
||||
|
||||
size_t const q_len_loop_end
|
||||
= size_t((total_q_len + QK_TOKENS_PER_BLOCK - 1) / QK_TOKENS_PER_BLOCK) * QK_TOKENS_PER_BLOCK;
|
||||
size_t const k_len_loop_end
|
||||
= size_t((total_kv_len + QK_TOKENS_PER_BLOCK - 1) / QK_TOKENS_PER_BLOCK) * QK_TOKENS_PER_BLOCK;
|
||||
size_t const v_len_loop_end
|
||||
= size_t((total_kv_len + V_TOKENS_PER_BLOCK - 1) / V_TOKENS_PER_BLOCK) * V_TOKENS_PER_BLOCK;
|
||||
float quant_scale_qkv_val = quant_scale_qkv_ptr ? quant_scale_qkv_ptr[0] : 1.f;
|
||||
|
||||
// Quantize Q, both src and dst are contiguous
|
||||
for (int q_token_idx = (threadIdx.x / QK_VECS_PER_HEAD) + blockIdx.x * QK_TOKENS_PER_BLOCK;
|
||||
q_token_idx < q_len_loop_end; q_token_idx += QK_TOKENS_PER_BLOCK * gridDim.x)
|
||||
{
|
||||
if (q_token_idx < total_q_len)
|
||||
{
|
||||
auto const src_q_idx
|
||||
= static_cast<size_t>(q_token_idx) * QK_HEAD_DIM * head_num + head_idx * QK_HEAD_DIM + qk_head_dim_idx;
|
||||
auto const dst_q_idx = src_q_idx;
|
||||
quantCopy<T, ELTS_PER_VEC>(quant_q_buf + dst_q_idx, &q_buf[src_q_idx], quant_scale_qkv_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize K, both src and dst are contiguous
|
||||
for (int k_token_idx = (threadIdx.x / QK_VECS_PER_HEAD) + blockIdx.x * QK_TOKENS_PER_BLOCK;
|
||||
k_token_idx < k_len_loop_end; k_token_idx += QK_TOKENS_PER_BLOCK * gridDim.x)
|
||||
{
|
||||
if (k_token_idx < total_kv_len)
|
||||
{
|
||||
auto const src_k_idx
|
||||
= static_cast<size_t>(k_token_idx) * QK_HEAD_DIM * head_num + head_idx * QK_HEAD_DIM + qk_head_dim_idx;
|
||||
auto const dst_k_idx = src_k_idx;
|
||||
quantCopy<T, ELTS_PER_VEC>(quant_k_buf + dst_k_idx, &k_buf[src_k_idx], quant_scale_qkv_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize V, dst V is contiguous, but src V is not contiguous, so we need to calculate the stride
|
||||
size_t const src_v_token_stride = (QK_NOPE_HEAD_DIM + V_HEAD_DIM) * head_num;
|
||||
for (int v_token_idx = (threadIdx.x / V_VECS_PER_HEAD) + blockIdx.x * V_TOKENS_PER_BLOCK;
|
||||
v_token_idx < v_len_loop_end; v_token_idx += V_TOKENS_PER_BLOCK * gridDim.x)
|
||||
{
|
||||
if (v_token_idx < total_kv_len)
|
||||
{
|
||||
auto const src_v_idx
|
||||
= static_cast<size_t>(v_token_idx) * src_v_token_stride + head_idx * V_HEAD_DIM + v_head_dim_idx;
|
||||
auto const dst_v_idx
|
||||
= static_cast<size_t>(v_token_idx) * V_HEAD_DIM * head_num + head_idx * V_HEAD_DIM + v_head_dim_idx;
|
||||
quantCopy<T, ELTS_PER_VEC>(quant_v_buf + dst_v_idx, &v_buf[src_v_idx], quant_scale_qkv_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(int(tensorrt_llm::common::divUp(params.max_input_seq_len, 32)), params.batch_size, params.head_num + 8);
|
||||
auto head_size = params.meta.qk_nope_head_dim;
|
||||
applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer><<<grid, 256, 0, stream>>>(
|
||||
params.attention_input_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num,
|
||||
head_size, params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len,
|
||||
params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_kv,
|
||||
params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
|
||||
if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr
|
||||
&& params.cache_type == KvCacheDataType::FP8)
|
||||
applyMLARopeAndAssignQKVKernelOptContext<T, 256, 512, 64, KVCacheBuffer><<<grid, 256, 0, stream>>>(params.q_buf,
|
||||
params.k_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num, head_size,
|
||||
params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len,
|
||||
params.cache_type, params.quant_scale_kv);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeMLAContextFp8Quantize(MlaParams<T>& params, int total_kv_len, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(params.cache_type == KvCacheDataType::FP8, "MLA Context: cache_type must be FP8");
|
||||
TLLM_CHECK_WITH_INFO(params.q_buf != nullptr, "MLA Context: q_buf must be non-null");
|
||||
TLLM_CHECK_WITH_INFO(params.k_buf != nullptr, "MLA Context: k_buf must be non-null");
|
||||
TLLM_CHECK_WITH_INFO(params.v_buf != nullptr, "MLA Context: v_buf must be non-null");
|
||||
TLLM_CHECK_WITH_INFO(params.quant_q_buf != nullptr, "MLA Context: quant_q_buf must be non-null");
|
||||
TLLM_CHECK_WITH_INFO(params.quant_k_buf != nullptr, "MLA Context: quant_k_buf must be non-null");
|
||||
TLLM_CHECK_WITH_INFO(params.quant_v_buf != nullptr, "MLA Context: quant_v_buf must be non-null");
|
||||
|
||||
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing separate qkv to FP8");
|
||||
|
||||
if (params.acc_q_len > 0)
|
||||
{
|
||||
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8");
|
||||
constexpr int threads_per_block = 384;
|
||||
dim3 grid(int(tensorrt_llm::common::divUp(total_kv_len, 48)), 1, params.head_num);
|
||||
|
||||
int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
|
||||
int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
|
||||
int const dim_v_per_head = (params.meta.v_head_dim);
|
||||
TLLM_LOG_DEBUG(
|
||||
"Launching quantizeCopyInputToFp8Kernel with grid_size: (%d, %d, %d), threads_per_block: %d, "
|
||||
"total_kv_len: %d, acc_q_len: %d",
|
||||
grid.x, grid.y, grid.z, threads_per_block, total_kv_len, params.acc_q_len);
|
||||
|
||||
// Total dimension per token across all heads for Q, K, and V components respectively
|
||||
int const total_q_dim_all_heads = params.head_num * dim_q_per_head;
|
||||
int const total_k_dim_all_heads
|
||||
= params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
int const total_v_dim_all_heads
|
||||
= params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
|
||||
|
||||
int const num_total_qkv_elements
|
||||
= params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
|
||||
size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim;
|
||||
float const* device_qkv_scale_ptr = params.quant_scale_qkv;
|
||||
|
||||
if (num_total_qkv_elements > 0)
|
||||
{
|
||||
int const threads_per_block = 256;
|
||||
int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block;
|
||||
|
||||
TLLM_LOG_DEBUG(
|
||||
"Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d",
|
||||
num_blocks, threads_per_block, num_total_qkv_elements);
|
||||
|
||||
tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
static_cast<T const*>(params.attention_input_buf), // Source
|
||||
static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination
|
||||
num_total_qkv_elements, device_qkv_scale_ptr);
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
cudaStreamSynchronize(stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization.");
|
||||
}
|
||||
quantizeCopyInputToFp8Kernel<T, threads_per_block, 128, 64, 128>
|
||||
<<<grid, threads_per_block, 0, stream>>>(params.q_buf, static_cast<__nv_fp8_e4m3*>(params.quant_q_buf),
|
||||
params.k_buf, static_cast<__nv_fp8_e4m3*>(params.quant_k_buf), params.v_buf,
|
||||
static_cast<__nv_fp8_e4m3*>(params.quant_v_buf), params.acc_q_len, total_kv_len, params.quant_scale_qkv,
|
||||
params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.dequant_scale_q,
|
||||
params.dequant_scale_kv, params.host_bmm1_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("MLA RoPE Context: acc_q_len is 0, skipping quantization.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1017,12 +987,12 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, params.attention_input_buf, params.q_pe, params.latent_cache,
|
||||
params.quant_attention_input_buf, kv_cache_buffer, params.cos_sin_cache, params.head_num,
|
||||
params.meta.kv_lora_rank, params.acc_q_len, seq_len, params.seqQOffset, params.fmha_tile_counter,
|
||||
params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld, params.q_pe_stride, params.cache_type,
|
||||
params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_q, params.quant_scale_kv,
|
||||
params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale);
|
||||
cudaLaunchKernelEx(&config, kernel_instance, params.q_buf, params.q_pe, params.latent_cache, params.quant_q_buf,
|
||||
kv_cache_buffer, params.cos_sin_cache, params.head_num, params.meta.kv_lora_rank, params.acc_q_len, seq_len,
|
||||
params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
|
||||
params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
|
||||
params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
|
||||
params.host_bmm1_scale);
|
||||
}
|
||||
|
||||
template <typename T, typename TCache>
|
||||
@ -1040,20 +1010,6 @@ void invokeMLALoadPagedKV(T* compressed_kv_ptr, T* k_pe_ptr, KVBlockArray& kv_ca
|
||||
compressed_kv_ptr, k_pe_ptr, kv_cache, cu_ctx_cached_kv_lens, max_input_seq_len, kv_scale_quant_orig_ptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, int const num_requests,
|
||||
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
|
||||
int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream)
|
||||
{
|
||||
using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
|
||||
TLLM_CHECK_WITH_INFO(kv_dim + rope_dim == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);
|
||||
TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % KT::kCpTokenPerBlock == 0,
|
||||
"kv_cache_tokens_per_block should be multiple of %d", KT::kCpTokenPerBlock);
|
||||
dim3 grid(tensorrt_llm::common::divUp(max_input_seq_len, KT::kCpTokenPerBlock), num_requests, num_heads);
|
||||
setPagedKVCacheForMLAKernel<T><<<grid, KT::kBlockSize, 0, stream>>>(output, k_ptr, v_ptr, k_pe_ptr, cu_seq_lens,
|
||||
max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride);
|
||||
}
|
||||
|
||||
template <typename T, typename TCache>
|
||||
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
|
||||
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
|
||||
@ -1076,11 +1032,15 @@ INSTANTIATE_MLA_ROPE(float, KVBlockArray);
|
||||
INSTANTIATE_MLA_ROPE(half, KVBlockArray);
|
||||
INSTANTIATE_MLA_ROPE(float, KVLinearBuffer);
|
||||
INSTANTIATE_MLA_ROPE(half, KVLinearBuffer);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVBlockArray);
|
||||
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer);
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_MLA_QUANTIZE(T) \
|
||||
template void invokeMLAContextFp8Quantize<T>(MlaParams<T> & params, int total_kv_len, cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_MLA_QUANTIZE(float);
|
||||
INSTANTIATE_MLA_QUANTIZE(half);
|
||||
INSTANTIATE_MLA_QUANTIZE(__nv_bfloat16);
|
||||
|
||||
#define INSTANTIATE_RW_KVCACHE_MLA(T, TCache) \
|
||||
template void invokeMLALoadPagedKV<T, TCache>(T * compressed_kv_ptr, T * k_pe_ptr, KVBlockArray & kv_cache, \
|
||||
@ -1099,26 +1059,6 @@ INSTANTIATE_RW_KVCACHE_MLA(half, __nv_fp8_e4m3);
|
||||
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_bfloat16);
|
||||
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_fp8_e4m3);
|
||||
|
||||
#define INSTANTIATE_SET_KVCACHE_MLA(T) \
|
||||
template void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, \
|
||||
int const num_requests, int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, \
|
||||
int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);
|
||||
|
||||
INSTANTIATE_SET_KVCACHE_MLA(float);
|
||||
INSTANTIATE_SET_KVCACHE_MLA(half);
|
||||
INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);
|
||||
|
||||
template <typename T_IN>
|
||||
__global__ void QuantizeCopyInputToFp8Kernel(
|
||||
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
|
||||
{
|
||||
uint element_idx = threadIdx.x + blockDim.x * blockIdx.x;
|
||||
if (element_idx < num_total_elements)
|
||||
{
|
||||
float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
|
||||
output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor);
|
||||
}
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -51,9 +51,25 @@ struct MlaMetaParams
template <typename T>
struct MlaParams
{
T const* latent_cache; // cKV + k_pe
T* attention_input_buf; // [b, s, 3, h, d_h + r]
void* quant_attention_input_buf;
T const* latent_cache; // cKV + k_pe
// Tensor Q for both context and generation MLA, contiguous. Pre-process kernel will apply RoPE and modify it
// in-place. For context MLA, shape: [total_q_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
T* q_buf;
// Separate tensor K for context MLA, contiguous. Pre-process kernel will apply RoPE and modify it in-place.
// shape: [total_kv_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
T* k_buf = nullptr;
// Separate tensor V for context MLA, NOT contiguous,
// shape: [total_kv_len, h * d_v], stride: [h * (d_nope + d_v), 1]
T const* v_buf = nullptr;
// Tensor quantized Q for both context and generation MLA.
// For context MLA, shape: [total_q_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
void* quant_q_buf = nullptr;
// Tensor quantized K for context MLA, contiguous
// shape: [total_kv_len, h * (d_nope + d_rope)], stride: [h * (d_nope + d_rope), 1]
void* quant_k_buf = nullptr;
// Tensor quantized V for context MLA, contiguous
// shape: [total_kv_len, h * d_v], stride: [h * d_v, 1]
void* quant_v_buf = nullptr;
T* context_buf;
T* q_pe; // [b, h, d_r], strided

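Because v_buf is the only non-contiguous tensor in the struct above, its addressing is worth spelling out. A minimal sketch of the index arithmetic (the helper name is illustrative, not part of the code base; it follows the shape/stride comment on v_buf):

#include <cstddef>

// Offset of element (token, head, channel) in the strided context-MLA V buffer.
// The token stride is h * (d_nope + d_v); within a token, each head contributes d_v values.
inline std::size_t vBufOffset(std::size_t token, std::size_t head, std::size_t channel,
    std::size_t numHeads, std::size_t dNope, std::size_t dV)
{
    return token * numHeads * (dNope + dV) + head * dV + channel;
}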
@ -83,17 +99,16 @@ struct MlaParams
float const* dequant_scale_kv;
float host_bmm1_scale;

// for kv cache reuse/chunked context
void* context_paged_kv_ptr = nullptr;
void* context_kv_cache_block_offsets_ptr = nullptr;
int32_t context_paged_kv_max_blocks_per_seq = 0;
// for FP8 context qkv quantization
// For FP8 context qkv quantization
float const* quant_scale_qkv = nullptr;
};

template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);

template <typename T>
void invokeMLAContextFp8Quantize(MlaParams<T>& params, int total_kv_len, cudaStream_t stream);

template <typename T, typename KVCacheBuffer>
void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);

@ -102,20 +117,10 @@ void invokeMLALoadPagedKV(T* compressed_kv_ptr, T* k_pe_ptr, KVBlockArray& kv_ca
int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, int const lora_size, int const rope_size,
float const* kv_scale_quant_orig_ptr, cudaStream_t stream);

template <typename T>
void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, int const num_requests,
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);

template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* latent_cache_ptr, int const num_requests,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
float const* kv_scale_orig_quant_ptr, cudaStream_t stream);

template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr);

} // namespace kernels
} // namespace tensorrt_llm

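The QuantizeCopyInputToFp8Kernel declared above is an element-wise scale-and-cast. A host-side sketch of the same math (function name and the host-side clamp emulation are illustrative assumptions, not the device implementation):

#include <algorithm>

// out[i] = fp8_e4m3(in[i] * scale); the fp8 cast is emulated here by clamping to the
// finite e4m3 range, which is enough to show the intent of the kernel.
void quantizeCopyHostSketch(float const* in, float* outEmulated, int n, float const* scalePtr)
{
    float const scale = (scalePtr != nullptr) ? *scalePtr : 1.0f;
    for (int i = 0; i < n; ++i)
    {
        float const v = in[i] * scale;
        outEmulated[i] = std::min(448.0f, std::max(-448.0f, v));
    }
}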
@ -238,6 +238,8 @@ struct TllmGenFmhaRunnerParams
int mHeadDimQk;
// Head dimension for V.
int mHeadDimV;
// Head dimension for Q/K non-RoPE part, only used for MLA now.
int mHeadDimQkNope;
// Number of heads for Q and K/V.
int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
// The batch size.

@ -329,7 +329,7 @@ struct KernelParams

// Compute the strides for K and V.
template <class FmhaOptions>
static auto makeStrideKv(FmhaOptions const& options, bool isK)
static auto makeStrideKv(FmhaOptions const& options, Data_type dtypeKv, bool isK)
{

// The maximum headDim of K and V.
@ -357,6 +357,12 @@ struct KernelParams
{
strideKeysVals = maxHeadDimKv;
}
else if (isSeparateQkv(options.mQkvLayout) && !isK && options.mHeadDimQkNope > 0 && dtypeKv != DATA_TYPE_E4M3)
{
// Non-FP8 context MLA: tensor V is not contiguous. The token stride is mNumHeadsKv * (mHeadDimQkNope +
// mHeadDimV).
strideKeysVals = options.mNumHeadsKv * (options.mHeadDimQkNope + options.mHeadDimV);
}

// The stride between heads.
int32_t strideHeads{isK ? options.mHeadDimQk : options.mHeadDimV};
@ -397,7 +403,7 @@ struct KernelParams
// The shape elements.
auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params);
// The stride elements.
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, dtypeKv, isK);

// The headDim.
// Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
@ -430,12 +436,13 @@ struct KernelParams

// Create the TMA shape/stride for KV scaling factors.
template <class FmhaOptions>
static auto makeTmaShapeStrideKvSf(FmhaOptions const& options, KernelParams const& params, bool isK)
static auto makeTmaShapeStrideKvSf(
FmhaOptions const& options, KernelParams const& params, Data_type dtypeKv, bool isK)
{
// The shape elements.
auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params);
// The stride elements.
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, dtypeKv, isK);

// The headDim.
// Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
@ -685,7 +692,8 @@ struct KernelParams
int32_t NumEltsPerSf = 16;
// Compute the shape and stride for SF tensor.
// FIXME: assume K and V uses the same shape.
auto [shapeKvSf, strideKvSf] = makeTmaShapeStrideKvSf(options, params, /*isK*/ true);
auto [shapeKvSf, strideKvSf]
= makeTmaShapeStrideKvSf(options, params, kernelMeta.mDataTypeKv, /*isK*/ true);

// The tileShapes for K/V.
std::vector<uint32_t> tileShapeKvSf(shapeKvSf.size(), 1);

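Condensing the new branch above: for the SeparateQkv layout, the per-token V stride differs between FP8 and non-FP8 context MLA. A tiny sketch (the function name is illustrative, and the contiguous fallback is an assumption for the non-MLA case):

// Elements between consecutive V tokens in the SeparateQkv layout.
int separateQkvVTokenStride(bool isFp8Kv, int numHeadsKv, int headDimV, int headDimQkNope)
{
    if (headDimQkNope > 0 && !isFp8Kv)
    {
        // Non-FP8 context MLA: V is interleaved with the nope part, so skip it per head.
        return numHeadsKv * (headDimQkNope + headDimV);
    }
    return numHeadsKv * headDimV; // assumed contiguous V otherwise
}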
@ -32,8 +32,8 @@ void initBindings(nb::module_& m)
// Parameters with default values using std::nullopt for optional arguments
nb::arg("q"), nb::arg("k") = std::nullopt, nb::arg("v") = std::nullopt, nb::arg("output"),
nb::arg("output_sf") = std::nullopt, nb::arg("out_dtype") = std::nullopt, nb::arg("workspace_") = std::nullopt,
nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("context_lengths"),
nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"),
nb::arg("context_lengths"), nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
@ -52,7 +52,6 @@ void initBindings(nb::module_& m)
nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt,
nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt,
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("mla_context_paged_kv") = std::nullopt, nb::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt,
nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}

@ -1046,6 +1046,7 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = batch_size;
enqueue_params.mrope_rotary_cos_sin = mrope_rotary_cos_sin;
enqueue_params.total_kv_len = enqueue_params.num_tokens;

if (isCrossAttention())
{

@ -32,8 +32,8 @@ void initBindings(pybind11::module_& m)
|
||||
// Parameters with default values using std::nullopt for optional arguments
|
||||
py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
|
||||
py::arg("output_sf") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("workspace_") = std::nullopt,
|
||||
py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("context_lengths"),
|
||||
py::arg("host_context_lengths"), py::arg("host_request_types"),
|
||||
py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"),
|
||||
py::arg("context_lengths"), py::arg("host_context_lengths"), py::arg("host_request_types"),
|
||||
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
|
||||
py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
|
||||
py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
|
||||
@ -52,7 +52,6 @@ void initBindings(pybind11::module_& m)
|
||||
py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
|
||||
py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
|
||||
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
|
||||
py::arg("mla_context_paged_kv") = std::nullopt, py::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
|
||||
py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt,
|
||||
py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
|
||||
}
|
||||
|
||||
@ -64,10 +64,12 @@ public:
|
||||
virtual int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
|
||||
int const num_gen_tokens) const
|
||||
= 0;
|
||||
// typically, we use single qkv input, but for context MLA, we use separate qkv inputs
|
||||
virtual void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
|
||||
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
|
||||
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv,
|
||||
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths,
|
||||
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv_or_q,
|
||||
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
|
||||
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
|
||||
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
|
||||
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
|
||||
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
|
||||
@ -77,8 +79,6 @@ public:
|
||||
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
|
||||
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
|
||||
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
|
||||
torch::optional<torch::Tensor> mla_context_paged_kv,
|
||||
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
|
||||
torch::optional<torch::Tensor> softmax_stats_tensor,
|
||||
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
|
||||
torch::optional<torch::Tensor> attention_sinks) const
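A rough sketch of the two call conventions from the caller's side; the shapes below are assumptions inferred from the checks in the context branch (2-D K/V tensors with a unit inner stride) and the head dimensions used by the context MLA FMHA tests (192 = 128 nope + 64 rope, dv = 128), not a definitive spec:

import torch

num_tokens, num_heads = 16, 8
qk_head_dim, v_head_dim = 192, 128  # assumed: 128 nope + 64 rope, V head dim 128

# Fused layout (previous context MLA path, assuming num_kv_heads == num_heads):
# a single packed tensor is passed as qkv_or_q and k/v stay unset.
fused_qkv = torch.empty(num_tokens, num_heads * (2 * qk_head_dim + v_head_dim))

# Separate layout (new context MLA path): 2-D q/k/v with unit stride on the last dim.
q = torch.empty(num_tokens, num_heads * qk_head_dim)
k = torch.empty(num_tokens, num_heads * qk_head_dim)
v = torch.empty(num_tokens, num_heads * v_head_dim)
assert k.stride(-1) == 1 and v.stride(-1) == 1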
@ -121,8 +121,9 @@ public:
|
||||
|
||||
void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
|
||||
int32_t const token_offset, int32_t const num_tokens, int32_t const predicted_tokens_per_seq,
|
||||
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv,
|
||||
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths,
|
||||
torch::Tensor workspace, torch::Tensor output, torch::optional<torch::Tensor> output_sf, torch::Tensor qkv_or_q,
|
||||
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
|
||||
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
|
||||
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
|
||||
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
|
||||
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
|
||||
@ -132,14 +133,14 @@ public:
|
||||
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
|
||||
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
|
||||
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
|
||||
torch::optional<torch::Tensor> mla_context_paged_kv,
|
||||
torch::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
|
||||
torch::optional<torch::Tensor> softmax_stats_tensor,
|
||||
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
|
||||
torch::optional<torch::Tensor> attention_sinks) const override
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
|
||||
T* attention_input = static_cast<T*>(qkv.slice(0, token_offset).data_ptr());
|
||||
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
|
||||
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
|
||||
T* k_ptr = nullptr;
|
||||
T* v_ptr = nullptr;
|
||||
AttentionOutT* context_buf = static_cast<AttentionOutT*>(output.slice(0, token_offset).data_ptr());
|
||||
TORCH_CHECK(!op.mFuseFp4Quant || output_sf.has_value());
|
||||
void* context_buf_sf = op.mFuseFp4Quant ? output_sf->data_ptr() : nullptr;
|
||||
@ -164,23 +165,33 @@ public:
|
||||
[[maybe_unused]] MlaParams<T> mla_params;
|
||||
if (op.isMLAEnabled())
|
||||
{
|
||||
if (is_context && op.mPagedContextFMHA && op.mPagedKVCache)
|
||||
if (is_context)
|
||||
{
|
||||
TORCH_CHECK(mla_context_paged_kv.has_value());
|
||||
TORCH_CHECK(mla_context_kv_cache_block_offsets.has_value());
|
||||
if (latent_cache.has_value())
|
||||
{
|
||||
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
|
||||
}
|
||||
else
|
||||
{
|
||||
// For KV cache reuse / chunked context, latent_cache is not used
|
||||
mla_params.latent_cache = nullptr;
|
||||
}
|
||||
TORCH_CHECK(k.has_value());
|
||||
TORCH_CHECK(v.has_value());
|
||||
TORCH_CHECK(k->dim() == 2);
|
||||
TORCH_CHECK(v->dim() == 2);
|
||||
TORCH_CHECK(k->strides()[1] == 1);
|
||||
TORCH_CHECK(v->strides()[1] == 1);
|
||||
|
||||
mla_params.context_paged_kv_ptr = mla_context_paged_kv->data_ptr();
|
||||
mla_params.context_kv_cache_block_offsets_ptr = mla_context_kv_cache_block_offsets->data_ptr();
|
||||
mla_params.context_paged_kv_max_blocks_per_seq = mla_context_kv_cache_block_offsets->size(-1);
|
||||
k_ptr = static_cast<T*>(k->slice(0, token_offset).data_ptr());
|
||||
v_ptr = static_cast<T*>(v->slice(0, token_offset).data_ptr());
|
||||
mla_params.k_buf = k_ptr;
|
||||
mla_params.v_buf = v_ptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Assume latent_cache has already been written to the paged KV cache by the PyTorch backend
|
||||
TORCH_CHECK(latent_cache.has_value());
|
||||
mla_params.latent_cache = static_cast<T const*>(latent_cache->data_ptr());
|
||||
}
|
||||
if (!is_context)
|
||||
{
|
||||
TORCH_CHECK(q_pe.has_value());
|
||||
TORCH_CHECK(q_pe->dim() == 3);
|
||||
TORCH_CHECK(q_pe->strides()[2] == 1);
|
||||
@ -189,7 +200,7 @@ public:
|
||||
mla_params.q_pe_ld = q_pe->strides()[1];
|
||||
mla_params.q_pe_stride = q_pe->strides()[0];
|
||||
}
|
||||
mla_params.attention_input_buf = attention_input;
|
||||
mla_params.q_buf = attention_input;
|
||||
mla_params.context_buf = reinterpret_cast<T*>(context_buf);
|
||||
|
||||
mla_params.cos_sin_cache = rotary_cos_sin_ptr;
|
||||
@ -300,6 +311,7 @@ public:
|
||||
common_enqueue_params.host_primary_pool_pointer = host_primary_pool_pointer;
|
||||
common_enqueue_params.host_secondary_pool_pointer = host_secondary_pool_pointer;
|
||||
common_enqueue_params.num_tokens = num_tokens;
|
||||
common_enqueue_params.total_kv_len = total_kv_len;
|
||||
common_enqueue_params.max_blocks_per_sequence = max_blocks_per_sequence;
|
||||
common_enqueue_params.sequence_lengths = sequence_lengths_ptr;
|
||||
common_enqueue_params.context_lengths = context_lengths_ptr;
|
||||
@ -316,6 +328,8 @@ public:
|
||||
{
|
||||
enqueue_params.softmaxStatsPtr = static_cast<float2*>(softmax_stats_tensor.value().data_ptr());
|
||||
}
|
||||
enqueue_params.k_ptr = k_ptr;
|
||||
enqueue_params.v_ptr = v_ptr;
|
||||
|
||||
if (op.isMLAEnabled())
|
||||
{
|
||||
@ -424,26 +438,26 @@ using torch_ext::trtllm::attention::AttentionInputType;
|
||||
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
|
||||
std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
|
||||
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
|
||||
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
|
||||
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
|
||||
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
|
||||
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
|
||||
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
|
||||
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
|
||||
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
|
||||
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
|
||||
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
|
||||
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
|
||||
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
|
||||
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
|
||||
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
|
||||
torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
|
||||
torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
|
||||
std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
|
||||
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
|
||||
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
|
||||
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
|
||||
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
|
||||
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
|
||||
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
|
||||
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
|
||||
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
|
||||
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
|
||||
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
|
||||
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
|
||||
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
|
||||
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
|
||||
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
|
||||
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
|
||||
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
|
||||
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
|
||||
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
|
||||
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
|
||||
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
|
||||
{
|
||||
@ -452,9 +466,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
|
||||
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(is_fused_qkv, "Only fused QKV is supported now");
|
||||
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
|
||||
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
|
||||
auto qkv = q;
|
||||
auto qkv_or_q = q;
|
||||
if (is_fused_qkv)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!k.has_value(), "The k tensor should be null if using fused QKV");
|
||||
@ -466,7 +480,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
TLLM_CHECK_WITH_INFO(v.has_value(), "The v tensor should be provided if updating KV cache with unfused K/V");
|
||||
}
|
||||
|
||||
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv.scalar_type());
|
||||
auto const dtype = tensorrt_llm::runtime::TorchUtils::dataType(qkv_or_q.scalar_type());
|
||||
bool const is_fp8_out = out_dtype.has_value() && out_dtype.value() == torch::kFloat8_e4m3fn;
|
||||
bool const is_fp4_out = out_dtype.has_value() && out_dtype.value() == torch::kUInt8;
|
||||
|
||||
@ -624,9 +638,11 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
++num_contexts;
|
||||
}
|
||||
int32_t const num_generations = num_seqs - num_contexts;
|
||||
int32_t const num_tokens = qkv.size(0);
|
||||
int32_t const num_tokens = qkv_or_q.size(0);
|
||||
int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
|
||||
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens;
|
||||
auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item<int32_t>();
|
||||
auto const gen_total_kv_len = host_total_kv_lens.index({1}).item<int32_t>();
|
||||
|
||||
for (int32_t idx = num_contexts; idx < num_seqs; idx++)
|
||||
{
|
||||
@ -661,7 +677,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
}
|
||||
else
|
||||
{
|
||||
workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv.device()));
|
||||
workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv_or_q.device()));
|
||||
}
|
||||
|
||||
if ((num_contexts > 0) && (attn_input_type != AttentionInputType::GenerationOnly))
|
||||
@ -671,12 +687,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
runner->run(*op,
|
||||
/*is_context=*/true, seq_offset,
|
||||
/*num_seqs=*/num_contexts, token_offset,
|
||||
/*num_tokens=*/num_ctx_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv, sequence_length,
|
||||
host_past_key_value_lengths, context_lengths, host_context_lengths, kv_cache_block_offsets,
|
||||
host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
|
||||
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
|
||||
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
|
||||
mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
|
||||
/*num_tokens=*/num_ctx_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
|
||||
sequence_length, host_past_key_value_lengths, ctx_total_kv_len, context_lengths, host_context_lengths,
|
||||
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
|
||||
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
|
||||
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
|
||||
mrope_position_deltas, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
|
||||
}
|
||||
|
||||
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
|
||||
@ -687,12 +703,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
runner->run(*op,
|
||||
/*is_context=*/false, seq_offset,
|
||||
/*num_seqs=*/num_generations, token_offset,
|
||||
/*num_tokens=*/num_gen_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv, sequence_length,
|
||||
host_past_key_value_lengths, context_lengths, host_context_lengths, kv_cache_block_offsets,
|
||||
host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
|
||||
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
|
||||
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_context_paged_kv,
|
||||
mla_context_kv_cache_block_offsets, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
|
||||
/*num_tokens=*/num_gen_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
|
||||
sequence_length, host_past_key_value_lengths, gen_total_kv_len, context_lengths, host_context_lengths,
|
||||
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
|
||||
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
|
||||
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
|
||||
mrope_position_deltas, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
|
||||
|
||||
@ -37,26 +37,26 @@ namespace torch_ext
|
||||
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
|
||||
std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
|
||||
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
|
||||
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
|
||||
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
|
||||
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
|
||||
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
|
||||
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
|
||||
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
|
||||
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
|
||||
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
|
||||
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
|
||||
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
|
||||
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
|
||||
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
|
||||
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
|
||||
torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths,
|
||||
torch::Tensor host_request_types, std::optional<torch::Tensor> kv_cache_block_offsets,
|
||||
std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
|
||||
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
|
||||
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
|
||||
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
|
||||
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
|
||||
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
|
||||
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
|
||||
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
|
||||
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
|
||||
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
|
||||
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
|
||||
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
|
||||
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
|
||||
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
|
||||
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
|
||||
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
|
||||
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
|
||||
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
|
||||
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
|
||||
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
|
||||
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);
|
||||
|
||||
|
||||
@ -62,39 +62,6 @@ void loadChunkedKVCacheForMLAHelper(torch::Tensor& output_kv, torch::Tensor& out
|
||||
kv_scale_quant_orig_ptr, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void setPagedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v,
|
||||
torch::Tensor const& k_pe, int const num_requests, torch::Tensor const& cu_seq_lens, int const max_input_seq_len,
|
||||
int num_heads, int kv_dim, int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride)
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream(output.get_device());
|
||||
auto* output_ptr = static_cast<T*>(output.data_ptr());
|
||||
auto const* k_ptr = static_cast<T const*>(k.data_ptr());
|
||||
auto const* v_ptr = static_cast<T const*>(v.data_ptr());
|
||||
auto const* k_pe_ptr = static_cast<T const*>(k_pe.data_ptr());
|
||||
auto const* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
|
||||
|
||||
// cudaMemset is faster than torch::zeros
|
||||
TLLM_CUDA_CHECK(cudaMemsetAsync(output_ptr, 0, output.numel() * torch::elementSize(output.scalar_type()), stream));
|
||||
tensorrt_llm::kernels::invokeMLASetPagedKV<T>(output_ptr, k_ptr, v_ptr, k_pe_ptr, num_requests, cu_seq_lens_ptr,
|
||||
max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride, stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void setChunkedKVCacheForMLAHelper(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe,
|
||||
int const num_requests, torch::Tensor const& cu_seq_lens, int num_heads, int kv_dim, int rope_dim,
|
||||
int kv_cache_tokens_per_block, int max_seq_len)
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream(output.get_device());
|
||||
T* output_ptr = static_cast<T*>(output.data_ptr());
|
||||
T* kv_ptr = static_cast<T*>(kv.data_ptr());
|
||||
T* k_pe_ptr = static_cast<T*>(k_pe.data_ptr());
|
||||
auto* cu_seq_lens_ptr = cu_seq_lens.data_ptr<int64_t>();
|
||||
|
||||
tensorrt_llm::kernels::invokeMLASetChunkedKV<T>(output_ptr, kv_ptr, k_pe_ptr, num_requests, max_seq_len, num_heads,
|
||||
kv_dim, rope_dim, cu_seq_lens_ptr, kv_cache_tokens_per_block, stream);
|
||||
}
|
||||
|
||||
template <typename T, typename TCache>
|
||||
void invokeMLARopeAppendPagedKVAssignQHelper(KVBlockArray& kv_cache, torch::Tensor& q, torch::Tensor& latent_cache,
|
||||
int const num_requests, torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
|
||||
@ -371,111 +338,6 @@ std::vector<torch::Tensor> loadChunkedKVCacheForMLA(torch::ScalarType out_dtype,
|
||||
return outputs;
|
||||
}
|
||||
|
||||
torch::Tensor setPagedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& k, torch::Tensor const& v,
|
||||
torch::Tensor const& k_pe, int64_t const num_requests, torch::Tensor const& cu_seq_lens,
|
||||
int64_t const max_input_seq_len, int64_t const num_heads, int64_t const kv_dim, int64_t const rope_dim,
|
||||
int64_t const kv_cache_tokens_per_block)
|
||||
{
|
||||
TORCH_CHECK(output.numel() > 0);
|
||||
auto output_dtype = output.scalar_type();
|
||||
TORCH_CHECK(output_dtype == torch::kFloat16 || output_dtype == torch::kFloat32 || output_dtype == torch::kBFloat16);
|
||||
CHECK_TH_CUDA(output);
|
||||
CHECK_CONTIGUOUS(output);
|
||||
|
||||
// k and v can be non-contiguous
|
||||
CHECK_TH_CUDA(k);
|
||||
CHECK_TYPE(k, output_dtype);
|
||||
CHECK_TH_CUDA(v);
|
||||
CHECK_TYPE(v, output_dtype);
|
||||
TORCH_CHECK(k.dim() == 3);
|
||||
TORCH_CHECK(v.dim() == 3);
|
||||
TORCH_CHECK(k.size(0) == v.size(0));
|
||||
TORCH_CHECK(k.size(1) == v.size(1));
|
||||
TORCH_CHECK(k.size(2) == v.size(2));
|
||||
TORCH_CHECK(k.stride(1) == k.size(2));
|
||||
TORCH_CHECK(v.stride(1) == v.size(2));
|
||||
TORCH_CHECK(k.stride(2) == 1);
|
||||
TORCH_CHECK(v.stride(2) == 1);
|
||||
// k and v should have the same token stride
|
||||
int64_t k_token_stride = k.stride(0);
|
||||
int64_t v_token_stride = v.stride(0);
|
||||
TORCH_CHECK(k_token_stride == v_token_stride);
|
||||
|
||||
// k_pe should be contiguous
|
||||
CHECK_INPUT(k_pe, output_dtype);
|
||||
|
||||
CHECK_INPUT(cu_seq_lens, torch::kInt64);
|
||||
TORCH_CHECK(cu_seq_lens.dim() == 1);
|
||||
TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1);
|
||||
|
||||
if (output_dtype == torch::kFloat16)
|
||||
{
|
||||
setPagedKVCacheForMLAHelper<half>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len, num_heads,
|
||||
kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
|
||||
}
|
||||
else if (output_dtype == torch::kFloat32)
|
||||
{
|
||||
setPagedKVCacheForMLAHelper<float>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len, num_heads,
|
||||
kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
|
||||
}
|
||||
else if (output_dtype == torch::kBFloat16)
|
||||
{
|
||||
setPagedKVCacheForMLAHelper<__nv_bfloat16>(output, k, v, k_pe, num_requests, cu_seq_lens, max_input_seq_len,
|
||||
num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, k_token_stride);
|
||||
}
|
||||
|
||||
int64_t max_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
|
||||
|
||||
torch::Tensor faked_kv_cache_block_offsets = torch::arange(
|
||||
0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device()));
|
||||
|
||||
faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num});
|
||||
|
||||
return faked_kv_cache_block_offsets;
|
||||
}
|
||||
|
||||
torch::Tensor setChunkedKVCacheForMLA(torch::Tensor& output, torch::Tensor const& kv, torch::Tensor const& k_pe,
|
||||
int64_t const num_requests, torch::Tensor const& cu_seq_lens, int64_t const num_heads, int64_t const kv_dim,
|
||||
int64_t const rope_dim, int64_t const kv_cache_tokens_per_block, int64_t const max_seq_len)
|
||||
{
|
||||
TORCH_CHECK(output.numel() > 0);
|
||||
TORCH_CHECK(output.scalar_type() == torch::kFloat16 || output.scalar_type() == torch::kFloat32
|
||||
|| output.scalar_type() == torch::kBFloat16);
|
||||
CHECK_TH_CUDA(output);
|
||||
CHECK_CONTIGUOUS(output);
|
||||
CHECK_INPUT(kv, output.scalar_type());
|
||||
CHECK_INPUT(k_pe, output.scalar_type());
|
||||
CHECK_INPUT(cu_seq_lens, torch::kInt64);
|
||||
TORCH_CHECK(cu_seq_lens.dim() == 1);
|
||||
TORCH_CHECK(cu_seq_lens.size(0) >= num_requests + 1);
|
||||
|
||||
if (output.scalar_type() == torch::kFloat16)
|
||||
{
|
||||
setChunkedKVCacheForMLAHelper<half>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim,
|
||||
kv_cache_tokens_per_block, max_seq_len);
|
||||
}
|
||||
else if (output.scalar_type() == torch::kFloat32)
|
||||
{
|
||||
setChunkedKVCacheForMLAHelper<float>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim, rope_dim,
|
||||
kv_cache_tokens_per_block, max_seq_len);
|
||||
}
|
||||
else if (output.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
setChunkedKVCacheForMLAHelper<__nv_bfloat16>(output, kv, k_pe, num_requests, cu_seq_lens, num_heads, kv_dim,
|
||||
rope_dim, kv_cache_tokens_per_block, max_seq_len);
|
||||
}
|
||||
|
||||
int64_t max_block_num = (max_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
|
||||
|
||||
// TODO: actually this offset is always the same for all requests and all layers.
|
||||
torch::Tensor faked_kv_cache_block_offsets = torch::arange(
|
||||
0, num_requests * 2 * max_block_num, torch::TensorOptions().dtype(torch::kInt32).device(output.device()));
|
||||
|
||||
faked_kv_cache_block_offsets = faked_kv_cache_block_offsets.view({num_requests, 2, max_block_num});
|
||||
|
||||
return faked_kv_cache_block_offsets;
|
||||
}
|
||||
|
||||
void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache, int64_t const num_contexts,
|
||||
torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
|
||||
int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num,
|
||||
@ -664,51 +526,6 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
m.impl("load_chunked_kv_cache_for_mla", &torch_ext::loadChunkedKVCacheForMLA);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def(
|
||||
"set_paged_kv_cache_for_mla("
|
||||
"Tensor output"
|
||||
", Tensor k"
|
||||
", Tensor v"
|
||||
", Tensor k_pe"
|
||||
", int num_requests"
|
||||
", Tensor cu_seq_lens"
|
||||
", int max_input_seq_len"
|
||||
", int num_heads"
|
||||
", int kv_dim"
|
||||
", int rope_dim"
|
||||
", int kv_cache_tokens_per_block"
|
||||
") -> Tensor");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("set_paged_kv_cache_for_mla", &torch_ext::setPagedKVCacheForMLA);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def(
|
||||
"set_chunked_kv_cache_for_mla("
|
||||
"Tensor output"
|
||||
", Tensor kv"
|
||||
", Tensor k_pe"
|
||||
", int num_requests"
|
||||
", Tensor cu_seq_lens"
|
||||
", int num_heads"
|
||||
", int kv_dim"
|
||||
", int rope_dim"
|
||||
", int kv_cache_tokens_per_block"
|
||||
", int max_seq_len"
|
||||
") -> Tensor");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("set_chunked_kv_cache_for_mla", &torch_ext::setChunkedKVCacheForMLA);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def(
|
||||
|
||||
@ -63,59 +63,6 @@ void loadChunkedKVKernelRef(T* kv_output, T* k_pe_output, tensorrt_llm::kernels:
|
||||
}
|
||||
}
|
||||
|
||||
// kv {total_tokens, 2, h, nope_size}
|
||||
// k_pe {total_tokens, h=1, rope_size}
|
||||
// output {b, 2, ceil(max_seq / cache_tokens_per_block), h, cache_tokens_per_block, (nope_size + rope_size)}
|
||||
// max_seq <= chunk_size
|
||||
template <typename T>
|
||||
void setChunkedKVCacheForMLAKernelRef(T* output, T* kv_ptr, T* k_pe_ptr, int num_contexts, int64_t const* cu_seq_len,
|
||||
int const max_input_seq_len, int num_heads, int nope_size, int rope_size, int cache_tokens_per_block)
|
||||
{
|
||||
int head_size = nope_size + rope_size;
|
||||
int const kv_cache_size_per_block = num_heads * cache_tokens_per_block * head_size;
|
||||
int const kv_cache_block_num_per_seq = (max_input_seq_len + cache_tokens_per_block - 1) / cache_tokens_per_block;
|
||||
for (int b = 0; b < num_contexts; b++)
|
||||
{
|
||||
int const global_token_offset = cu_seq_len[b];
|
||||
int const current_seq_len = cu_seq_len[b + 1] - cu_seq_len[b];
|
||||
for (int s = 0; s < current_seq_len; s++)
|
||||
{
|
||||
int const global_token_idx = global_token_offset + s;
|
||||
int const kv_cache_block_offset_for_k
|
||||
= (b * 2 * kv_cache_block_num_per_seq + s / cache_tokens_per_block) * kv_cache_size_per_block;
|
||||
int const kv_cache_block_offset_for_v
|
||||
= kv_cache_block_offset_for_k + (kv_cache_block_num_per_seq * kv_cache_size_per_block);
|
||||
for (int h = 0; h < num_heads; h++)
|
||||
{
|
||||
int const ld_k_head_offset = (global_token_idx * 2 * num_heads * nope_size) + h * nope_size;
|
||||
int const ld_v_head_offset = ld_k_head_offset + num_heads * nope_size;
|
||||
int const ld_k_pe_head_offset = global_token_idx * rope_size;
|
||||
// copy kv
|
||||
for (int d = 0; d < nope_size; d++)
|
||||
{
|
||||
int const ld_k_idx = ld_k_head_offset + d;
|
||||
int const ld_v_idx = ld_v_head_offset + d;
|
||||
int const st_k_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size
|
||||
+ (s % cache_tokens_per_block) * head_size + d;
|
||||
int const st_v_idx = kv_cache_block_offset_for_v + h * cache_tokens_per_block * head_size
|
||||
+ (s % cache_tokens_per_block) * head_size + d;
|
||||
output[st_k_idx] = kv_ptr[ld_k_idx];
|
||||
output[st_v_idx] = kv_ptr[ld_v_idx];
|
||||
}
|
||||
|
||||
// copy k_pe
|
||||
for (int d = 0; d < rope_size; d++)
|
||||
{
|
||||
int const ld_k_pe_idx = ld_k_pe_head_offset + d;
|
||||
int const st_k_pe_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size
|
||||
+ (s % cache_tokens_per_block) * head_size + (nope_size + d);
|
||||
output[st_k_pe_idx] = k_pe_ptr[ld_k_pe_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
||||
// KV {total_kv, 2, H, D}
|
||||
// softmax_sum {total_q, H, 2} // {max/sum}
|
||||
@ -322,9 +269,6 @@ protected:
|
||||
d_k_pe_output{nullptr}, h_compressed_kv_output_ref{nullptr}, h_k_pe_output_ref{nullptr},
|
||||
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
|
||||
|
||||
// for kernel 2
|
||||
h_kv_tensor{nullptr}, d_kv_tensor{nullptr}, h_k_pe_tensor{nullptr}, d_k_pe_tensor{nullptr},
|
||||
|
||||
// for merge attn {kv_full_tensor = kv + k_pe}
|
||||
m_h_q_tensor{nullptr}, m_h_kv_full_tensor{nullptr}, m_h_chunked_kv_tensor{nullptr}, m_h_output_tensor{nullptr},
|
||||
m_h_softmax_sum_tensor{nullptr}, m_h_softmax_sum_accum_tensor{nullptr}, m_h_output_tensor_ref{nullptr},
|
||||
@ -621,27 +565,6 @@ protected:
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
// kv, k_pe for invokeMLASetChunkedKV (kernel 2)
|
||||
this->h_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
||||
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype);
|
||||
this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
||||
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype);
|
||||
this->d_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_tensor->getShape(), dtype);
|
||||
this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_tensor->getShape(), dtype);
|
||||
{
|
||||
auto* kv_ptr = bufferCast<DataType>(*(this->h_kv_tensor));
|
||||
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
|
||||
|
||||
fillArrayDataWithMod<DataType>(kv_ptr, h_kv_tensor->getSize());
|
||||
fillArrayDataWithMod<DataType>(k_pe_ptr, h_k_pe_tensor->getSize());
|
||||
|
||||
cudaMemcpyAsync(d_kv_tensor->data(), h_kv_tensor->data(), h_kv_tensor->getSizeInBytes(),
|
||||
cudaMemcpyHostToDevice, mStream->get());
|
||||
cudaMemcpyAsync(d_k_pe_tensor->data(), h_k_pe_tensor->data(), h_k_pe_tensor->getSizeInBytes(),
|
||||
cudaMemcpyHostToDevice, mStream->get());
|
||||
cudaStreamSynchronize(mStream->get());
|
||||
}
|
||||
|
||||
// invokeMergeAttnWithSoftmax, we just ignore rope_size here for simplicity
|
||||
|
||||
this->m_h_q_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
||||
@ -915,39 +838,6 @@ protected:
|
||||
cudaMemcpyDeviceToHost);
|
||||
sync_check_cuda_error(this->mStream->get());
|
||||
}
|
||||
|
||||
void PerformSetChunkedKVRef()
|
||||
{
|
||||
using tensorrt_llm::runtime::bufferCast;
|
||||
auto* kv_ptr = bufferCast<DataType>(*(this->h_kv_tensor));
|
||||
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
|
||||
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
|
||||
auto* cu_chunked_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_chunk_lens));
|
||||
this->PrepareChunkedLen(0);
|
||||
setChunkedKVCacheForMLAKernelRef(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, cu_chunked_seq_lens_ptr,
|
||||
this->mChunkSize, this->mNumHeads, this->mNopeSize, this->mRopeSize, this->mTokensPerBlock);
|
||||
}
|
||||
|
||||
void PerformSetChunkedKV()
|
||||
{
|
||||
using tensorrt_llm::runtime::bufferCast;
|
||||
auto* kv_ptr = bufferCast<DataType>(*(this->d_kv_tensor));
|
||||
auto* k_pe_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor));
|
||||
auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
|
||||
auto* cu_chunked_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_chunk_lens));
|
||||
this->PrepareChunkedLen(0);
|
||||
// copy cu chunk lens to device
|
||||
cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(),
|
||||
this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
|
||||
tensorrt_llm::kernels::invokeMLASetChunkedKV(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, this->mChunkSize,
|
||||
this->mNumHeads, this->mNopeSize, this->mRopeSize, cu_chunked_seq_lens_ptr, this->mTokensPerBlock,
|
||||
mStream->get());
|
||||
cudaStreamSynchronize(this->mStream->get());
|
||||
// copy result back to host
|
||||
cudaMemcpy(this->h_kv_cache_tensor->data(), kv_cache_ptr, this->h_kv_cache_tensor->getSizeInBytes(),
|
||||
cudaMemcpyDeviceToHost);
|
||||
sync_check_cuda_error(this->mStream->get());
|
||||
}
|
||||
};
|
||||
|
||||
using MLATypes
|
||||
@ -1086,41 +976,3 @@ TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedLoad)
|
||||
}
|
||||
ASSERT_TRUE(allEqual);
|
||||
}
|
||||
|
||||
TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedSet)
|
||||
{
|
||||
using tensorrt_llm::runtime::bufferCast;
|
||||
using DataType = typename TestFixture::DataType;
|
||||
using TCache = typename TestFixture::TCache;
|
||||
if constexpr (std::is_same_v<DataType, TCache>)
|
||||
{
|
||||
this->setDefaultParams();
|
||||
this->allocateBuffers();
|
||||
|
||||
sync_check_cuda_error(this->mStream->get());
|
||||
bool allEqual{true};
|
||||
|
||||
this->PerformSetChunkedKVRef();
|
||||
sync_check_cuda_error(this->mStream->get());
|
||||
this->PerformSetChunkedKV();
|
||||
sync_check_cuda_error(this->mStream->get());
|
||||
|
||||
// check result
|
||||
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
|
||||
auto* kv_cache_ptr_ref = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
|
||||
|
||||
for (int i = 0; i < this->h_kv_cache_tensor->getSize(); i++)
|
||||
{
|
||||
if (std::abs(static_cast<float>(kv_cache_ptr[i]) - static_cast<float>(kv_cache_ptr_ref[i]))
|
||||
> getTolerance<DataType>(kv_cache_ptr[i]))
|
||||
{
|
||||
std::cout << "KV cache mismatch at index " << i << ": "
|
||||
<< "expected " << static_cast<float>(kv_cache_ptr_ref[i]) << ", got "
|
||||
<< static_cast<float>(kv_cache_ptr[i]) << std::endl;
|
||||
allEqual = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT_TRUE(allEqual);
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,60 +79,6 @@ void loadPagedKvKernelRef(T* compressed_kv_output, T* k_pe_output,
|
||||
}
|
||||
}
|
||||
|
||||
// k {total_token, h, uncompressed_h=128}, v {total_token, h, uncompressed_h}, k_pe {total_token, h=1, rope_h}
|
||||
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, (uncompressed_h + rope_h)}
|
||||
// copy k, v, k_pe to a contiguous memory space (it will then be packed into kv_cache)
|
||||
template <typename T>
|
||||
void setPagedKvCacheForMLAKernelRef(T* output, T* const k_ptr, T* const v_ptr, T* const k_pe_ptr, int num_requests,
|
||||
int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int uncompressed_head_size, int rope_size,
|
||||
int kv_cache_tokens_per_block, int64_t kv_token_stride)
|
||||
{
|
||||
int const kv_cache_size_per_block = num_heads * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size);
|
||||
int const kv_cache_block_num_per_seq
|
||||
= (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
|
||||
for (int b = 0; b < num_requests; b++)
|
||||
{
|
||||
int const global_token_offset = cu_seq_lens[b];
|
||||
int const current_token_len = cu_seq_lens[b + 1] - cu_seq_lens[b];
|
||||
for (int s = 0; s < current_token_len; s++)
|
||||
{
|
||||
int const global_token_idx = global_token_offset + s;
|
||||
int const kv_cache_block_offset_for_k
|
||||
= ((b * 2 * kv_cache_block_num_per_seq) + (s / kv_cache_tokens_per_block)) * kv_cache_size_per_block;
|
||||
int const kv_cache_block_offset_for_v
|
||||
= kv_cache_block_offset_for_k + (kv_cache_block_num_per_seq * kv_cache_size_per_block);
|
||||
for (int h = 0; h < num_heads; h++)
|
||||
{
|
||||
// copy k, v
|
||||
int const ld_kv_head_offset = (global_token_idx * kv_token_stride) + (h * uncompressed_head_size);
|
||||
int const ld_k_pe_head_offset = (global_token_idx * rope_size);
|
||||
for (int d = 0; d < uncompressed_head_size; d++)
|
||||
{
|
||||
int const ld_kv_idx = ld_kv_head_offset + d;
|
||||
int const st_k_idx = kv_cache_block_offset_for_k
|
||||
+ h * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
|
||||
+ (s % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size) + d;
|
||||
int const st_v_idx = kv_cache_block_offset_for_v
|
||||
+ h * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
|
||||
+ (s % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size) + d;
|
||||
output[st_k_idx] = k_ptr[ld_kv_idx];
|
||||
output[st_v_idx] = v_ptr[ld_kv_idx];
|
||||
}
|
||||
// copy k_pe, head_num = 1
|
||||
for (int d = 0; d < rope_size; d++)
|
||||
{
|
||||
int const ld_k_pe_idx = ld_k_pe_head_offset + d;
|
||||
int const st_k_pe_idx = kv_cache_block_offset_for_k
|
||||
+ h * kv_cache_tokens_per_block * (uncompressed_head_size + rope_size)
|
||||
+ (s % kv_cache_tokens_per_block) * (uncompressed_head_size + rope_size) + d
|
||||
+ uncompressed_head_size;
|
||||
output[st_k_pe_idx] = k_pe_ptr[ld_k_pe_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline bool almostEqual(float a, float b, float atol = 1e-2, float rtol = 1e-3)
|
||||
{
|
||||
if (isnan(a) || isnan(b))
|
||||
@ -168,10 +114,7 @@ protected:
|
||||
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
|
||||
// for kernel 1
|
||||
d_compressed_kv_output{nullptr}, h_compressed_kv_output{nullptr}, h_compressed_kv_output_ref{nullptr},
|
||||
d_k_pe_output{nullptr}, h_k_pe_output{nullptr}, h_k_pe_output_ref{nullptr},
|
||||
// for kernel 2
|
||||
d_k_tensor{nullptr}, d_v_tensor{nullptr}, d_k_pe_tensor{nullptr}, h_k_tensor{nullptr}, h_v_tensor{nullptr},
|
||||
h_k_pe_tensor{nullptr};
|
||||
d_k_pe_output{nullptr}, h_k_pe_output{nullptr}, h_k_pe_output_ref{nullptr};
|
||||
|
||||
int mNumRequests{};
|
||||
int mMaxSeqLen{};
|
||||
@ -463,33 +406,6 @@ protected:
|
||||
cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(),
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
// k, v, k_pe for setPagedKvCacheForMLAKernel (kernel 2)
|
||||
this->h_k_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
||||
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
||||
this->h_v_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
||||
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
||||
this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
||||
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
||||
this->d_k_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
||||
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
||||
this->d_v_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
||||
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
||||
this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
||||
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
||||
{
|
||||
auto* k_ptr = bufferCast<DataType>(*(this->h_k_tensor));
|
||||
auto* v_ptr = bufferCast<DataType>(*(this->h_v_tensor));
|
||||
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
|
||||
fillArrayDataWithMod(k_ptr, this->h_k_tensor->getSize());
|
||||
fillArrayDataWithMod(v_ptr, this->h_v_tensor->getSize());
|
||||
fillArrayDataWithMod(k_pe_ptr, this->h_k_pe_tensor->getSize());
|
||||
cudaMemcpy(this->d_k_tensor->data(), this->h_k_tensor->data(), this->h_k_tensor->getSizeInBytes(),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(this->d_v_tensor->data(), this->h_v_tensor->data(), this->h_v_tensor->getSizeInBytes(),
|
||||
cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(this->d_k_pe_tensor->data(), this->h_k_pe_tensor->data(), this->h_k_pe_tensor->getSizeInBytes(),
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -539,35 +455,6 @@ protected:
|
||||
cu_ctx_cached_kv_lens_ptr, this->mLoraSize, this->mRopeSize, kv_scale_quant_orig_ptr);
|
||||
}
|
||||
|
||||
void PerformSetPagedKV()
|
||||
{
|
||||
using tensorrt_llm::runtime::bufferCast;
|
||||
auto* k_ptr = bufferCast<DataType>(*(this->d_k_tensor));
|
||||
auto* v_ptr = bufferCast<DataType>(*(this->d_v_tensor));
|
||||
auto* k_pe_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor));
|
||||
auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
|
||||
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
|
||||
tensorrt_llm::kernels::invokeMLASetPagedKV<DataType>(kv_cache_ptr, k_ptr, v_ptr, k_pe_ptr, this->mNumRequests,
|
||||
cu_seq_lens_ptr, this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize,
|
||||
this->mRopeSize, this->mTokensPerBlock, this->mKvTokenStride, this->mStream->get());
|
||||
cudaStreamSynchronize(this->mStream->get());
|
||||
cudaMemcpy(this->h_kv_cache_tensor->data(), this->d_kv_cache_tensor->data(),
|
||||
this->d_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost);
|
||||
}
|
||||
|
||||
void PerformSetPagedKVRef()
|
||||
{
|
||||
using tensorrt_llm::runtime::bufferCast;
|
||||
auto* k_ptr = bufferCast<DataType>(*(this->h_k_tensor));
|
||||
auto* v_ptr = bufferCast<DataType>(*(this->h_v_tensor));
|
||||
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
|
||||
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
|
||||
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_seq_lens));
|
||||
setPagedKvCacheForMLAKernelRef(kv_cache_ptr, k_ptr, v_ptr, k_pe_ptr, this->mNumRequests, cu_seq_lens_ptr,
|
||||
this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize, this->mRopeSize,
|
||||
this->mTokensPerBlock, this->mKvTokenStride);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CheckEqual(T const* expected, T const* output, size_t size)
|
||||
{
|
||||
@ -617,14 +504,4 @@ TYPED_TEST(MlaPreprocessTest, MLAPreprocessDefault)
|
||||
allEqual = this->CheckEqual(k_pe_output_ref_ptr, k_pe_output_ptr, this->h_k_pe_output->getSize());
|
||||
EXPECT_TRUE(allEqual);
|
||||
}
|
||||
|
||||
{
|
||||
this->PerformSetPagedKV();
|
||||
sync_check_cuda_error(this->mStream->get());
|
||||
this->PerformSetPagedKVRef();
|
||||
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
|
||||
auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
|
||||
allEqual = this->CheckEqual(kv_cache_ref_ptr, kv_cache_ptr, this->h_kv_cache_tensor->getSize());
|
||||
EXPECT_TRUE(allEqual);
|
||||
}
|
||||
}
|
||||
|
||||
@ -783,7 +783,7 @@ echo "All processes completed!"
The converted checkpoint can be used as `<YOUR_MODEL_DIR>` and consumed by other commands.
### KV Cache Reuse
KV cache reuse is supported for MLA on SM90 and SM100. It is enabled by default. Due to extra operations such as memcpy and GEMMs, GPU memory consumption may be higher and end-to-end performance may regress in some cases. Users can pass `KvCacheConfig(enable_block_reuse=False)` to the LLM API to disable it.
KV cache reuse is supported for MLA on SM90, SM100 and SM120. It is enabled by default. Due to extra operations such as memcpy and GEMMs, GPU memory consumption may be higher and end-to-end performance may regress in some cases. Users can pass `KvCacheConfig(enable_block_reuse=False)` to the LLM API to disable it.
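For example, block reuse can be turned off through the LLM API. A minimal sketch, assuming the `KvCacheConfig` import path shown here; the model path is a placeholder:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(model="<YOUR_MODEL_DIR>",
          kv_cache_config=KvCacheConfig(enable_block_reuse=False))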
### Chunked Prefill
Chunked prefill is currently supported for MLA only on SM90 and SM100. Add `--enable_chunked_prefill` to enable it. GPU memory consumption is highly correlated with `max_num_tokens` and `max_batch_size`; if you encounter out-of-memory errors, reduce these values. (`max_num_tokens` must be divisible by the KV cache's `tokens_per_block`.)
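A minimal sketch of the equivalent LLM API usage, assuming the `enable_chunked_prefill` and `max_num_tokens` constructor arguments are exposed as the flag above suggests; the model path and token budget are placeholders:

from tensorrt_llm import LLM

llm = LLM(model="<YOUR_MODEL_DIR>",
          enable_chunked_prefill=True,
          max_num_tokens=8192)  # keep this divisible by the KV cache's tokens_per_block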
@ -52,7 +52,7 @@ class AttentionMetadata:
|
||||
mapping: Optional[Mapping] = None
|
||||
|
||||
enable_flash_mla: bool = False
|
||||
enable_paged_context_mla: bool = False
|
||||
enable_context_mla_with_cached_kv: bool = False
|
||||
# Whether CUDA graph is enabled.
|
||||
is_cuda_graph: bool = field(default=False, repr=False)
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
|
||||
class TrtllmAttentionWrapper:
|
||||
sequence_length: torch.Tensor
|
||||
host_past_key_value_lengths: torch.Tensor
|
||||
host_total_kv_lens: torch.Tensor
|
||||
context_lengths: torch.Tensor
|
||||
host_context_lengths: torch.Tensor
|
||||
host_request_types: torch.Tensor
|
||||
@ -67,6 +68,7 @@ class TrtllmAttentionWrapper:
|
||||
qk_nope_head_dim: Optional[int]
|
||||
v_head_dim: Optional[int]
|
||||
attention_chunk_size: Optional[int]
|
||||
softmax_stats_tensor: Optional[torch.Tensor]
|
||||
use_spec_decoding: bool
|
||||
is_spec_dec_tree: bool
|
||||
spec_decoding_position_offsets: Optional[torch.Tensor]
|
||||
@ -154,6 +156,7 @@ class TrtllmAttentionWrapper:
|
||||
beam_width: int = 1,
|
||||
sequence_length: torch.Tensor = ...,
|
||||
host_past_key_value_lengths: torch.Tensor = ...,
|
||||
host_total_kv_lens: torch.Tensor = ...,
|
||||
context_lengths: torch.Tensor = ...,
|
||||
host_context_lengths: torch.Tensor = ...,
|
||||
host_request_types: torch.Tensor = ...,
|
||||
@ -174,8 +177,6 @@ class TrtllmAttentionWrapper:
|
||||
latent_cache: Optional[torch.Tensor] = None,
|
||||
q_pe: Optional[torch.Tensor] = None,
|
||||
mrope_config: Optional[dict] = None,
|
||||
mla_context_paged_kv: Optional[torch.Tensor] = None,
|
||||
mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
|
||||
softmax_stats_tensor: Optional[torch.Tensor] = None,
|
||||
is_spec_decoding_enabled: bool = False,
|
||||
use_spec_decoding: bool = False,
|
||||
@ -201,6 +202,7 @@ class TrtllmAttentionWrapper:
|
||||
beam_width (int): Beam width in beam search.
|
||||
sequence_length (torch.Tensor): The length of each sequence with shape (batch_size) on GPU.
|
||||
host_past_key_value_lengths (torch.Tensor): Same as sequence_length, but on CPU.
|
||||
host_total_kv_lens (torch.Tensor): The tensor to store the total KV lens for context requests and generation requests, with shape (2) on CPU.
|
||||
context_lengths (torch.Tensor): The context-phase sequence length of each request with shape (batch_size) on GPU.
|
||||
host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU.
|
||||
host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU.
|
||||
@ -216,8 +218,6 @@ class TrtllmAttentionWrapper:
|
||||
out_scale_sf (torch.Tensor): The tensor to store the global scale for NVFP4 scaling factors, with shape (1) on GPU.
|
||||
use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner.
|
||||
mrope_config (dict): The dictionary containing the mRope configuration.
|
||||
mla_context_paged_kv (torch.Tensor): The paged KV cache for MLA context, for kv cache reuse/chunked context.
|
||||
mla_context_kv_cache_block_offsets (torch.Tensor): The block offsets for the paged KV cache for MLA context, for kv cache reuse/chunked context.
|
||||
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum)
|
||||
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
|
||||
"""
|
||||
@ -230,6 +230,7 @@ class TrtllmAttentionWrapper:
|
||||
self.beam_width = beam_width
|
||||
self.sequence_length = sequence_length
|
||||
self.host_past_key_value_lengths = host_past_key_value_lengths
|
||||
self.host_total_kv_lens = host_total_kv_lens
|
||||
self.context_lengths = context_lengths
|
||||
self.host_context_lengths = host_context_lengths
|
||||
self.host_request_types = host_request_types
|
||||
@ -253,8 +254,6 @@ class TrtllmAttentionWrapper:
|
||||
self.mrope_position_deltas = mrope_config.get(
|
||||
'mrope_position_deltas') if mrope_config is not None else None
|
||||
self.block_ids_per_seq = block_ids_per_seq
|
||||
self.mla_context_paged_kv = mla_context_paged_kv
|
||||
self.mla_context_kv_cache_block_offsets = mla_context_kv_cache_block_offsets
|
||||
self.softmax_stats_tensor = softmax_stats_tensor
|
||||
self.attention_sinks = attention_sinks
|
||||
|
||||
@ -361,18 +360,12 @@ class TrtllmAttentionWrapper:
|
||||
else:
|
||||
raise ValueError("Unexpected attention mask type")
|
||||
else:
|
||||
assert is_fused_qkv
|
||||
if self.attention_input_type == AttentionInputType.context_only:
|
||||
if self.use_paged_context_fmha:
|
||||
assert self.mla_context_paged_kv is not None
|
||||
assert self.mla_context_kv_cache_block_offsets is not None
|
||||
qkv_hidden_size = self.num_heads * (self.qk_nope_head_dim +
|
||||
self.qk_rope_head_dim)
|
||||
else:
|
||||
qkv_hidden_size = self.num_heads * (
|
||||
2 * (self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
) + self.num_kv_heads * self.v_head_dim
|
||||
assert not is_fused_qkv
|
||||
qkv_hidden_size = self.num_heads * (self.qk_nope_head_dim +
|
||||
self.qk_rope_head_dim)
|
||||
elif self.attention_input_type == AttentionInputType.generation_only:
|
||||
assert is_fused_qkv
|
||||
qkv_hidden_size = self.num_heads * (self.kv_lora_rank +
|
||||
self.qk_rope_head_dim)
|
||||
else:
|
||||
@ -430,6 +423,7 @@ class TrtllmAttentionWrapper:
|
||||
self.workspace,
|
||||
self.sequence_length,
|
||||
self.host_past_key_value_lengths,
|
||||
self.host_total_kv_lens,
|
||||
self.context_lengths,
|
||||
self.host_context_lengths,
|
||||
self.host_request_types,
|
||||
@ -479,8 +473,6 @@ class TrtllmAttentionWrapper:
|
||||
self.v_head_dim,
|
||||
self.mrope_rotary_cos_sin,
|
||||
self.mrope_position_deltas,
|
||||
self.mla_context_paged_kv,
|
||||
self.mla_context_kv_cache_block_offsets,
|
||||
self.attention_chunk_size,
|
||||
self.softmax_stats_tensor,
|
||||
spec_decoding_bool_params,
|
||||
@ -623,6 +615,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
|
||||
self.kv_lens = torch.empty_like(self.kv_lens_cuda,
|
||||
device='cpu',
|
||||
pin_memory=True)
|
||||
self.host_total_kv_lens = torch.empty(2, device='cpu', dtype=torch.int)
|
||||
self.host_request_types = torch.empty_like(self.prompt_lens_cpu)
|
||||
|
||||
# For debugging, can use it to call the wrapper's plan function
|
||||
@ -665,7 +658,7 @@ class TrtllmAttentionMetadata(AttentionMetadata):
|
||||
dtype=torch.int32,
|
||||
device='cuda',
|
||||
)
|
||||
if self.enable_paged_context_mla:
|
||||
if self.enable_context_mla_with_cached_kv:
|
||||
# for kv cache reuse/chunked context in MLA
|
||||
self.ctx_cached_token_indptr = torch.zeros(
|
||||
(self.max_num_requests + 1, ),
|
||||
@ -746,12 +739,16 @@ class TrtllmAttentionMetadata(AttentionMetadata):
|
||||
kv_lens + self.kv_cache_params.num_extra_kv_tokens)
|
||||
self.kv_lens_cuda[:self.num_seqs].copy_(
|
||||
kv_lens[:self.num_seqs].pin_memory(), non_blocking=True)
|
||||
# total kv lens for context requests and generation requests, without extra tokens
|
||||
self.host_total_kv_lens[0] = kv_lens[:self.num_contexts].sum().item()
|
||||
self.host_total_kv_lens[1] = kv_lens[self.num_contexts:self.
|
||||
num_seqs].sum().item()
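# Illustration (assumed toy values, not part of the change): with
# kv_lens = [384, 256, 1, 1], num_contexts = 2 and num_seqs = 4,
# host_total_kv_lens ends up as [640, 2] -- index 0 is the summed KV length
# of the context requests, index 1 that of the generation requests, matching
# the shape-(2) CPU tensor described in TrtllmAttentionWrapper.plan().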
self.host_request_types[:self.num_contexts].fill_(0)
|
||||
self.host_request_types[self.num_contexts:self.num_seqs].fill_(1)
|
||||
|
||||
# prepare for kv cache reuse/chunked context in MLA
|
||||
if self.enable_paged_context_mla:
|
||||
self.prepare_paged_context_mla(cached_token_lens, kv_lens)
|
||||
if self.enable_context_mla_with_cached_kv:
|
||||
self.prepare_context_mla_with_cached_kv(cached_token_lens, kv_lens)
|
||||
|
||||
# kv block offsets
|
||||
assert self.request_ids is not None
|
||||
@ -838,8 +835,9 @@ class TrtllmAttentionMetadata(AttentionMetadata):
|
||||
else:
|
||||
merge_op_tensor[chunked_loop_num, s] = 1 # merge
|
||||
|
||||
def prepare_paged_context_mla(self, cached_token_lens: torch.Tensor,
|
||||
kv_lens: torch.Tensor) -> None:
|
||||
def prepare_context_mla_with_cached_kv(self,
|
||||
cached_token_lens: torch.Tensor,
|
||||
kv_lens: torch.Tensor) -> None:
|
||||
if self.num_contexts > 0:
|
||||
self.num_ctx_cached_tokens = cached_token_lens[:self.
|
||||
num_contexts].sum(
|
||||
@ -1101,8 +1099,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
q_pe: Optional[torch.Tensor] = None,
mrope_config: Optional[dict] = None,
attention_window_size: Optional[int] = None,
mla_context_paged_kv: Optional[torch.Tensor] = None,
mla_context_kv_cache_block_offsets: Optional[torch.Tensor] = None,
softmax_stats_tensor: Optional[torch.Tensor] = None,
enable_attn_nvfp4_output: bool = True,
output: Optional[torch.Tensor] = None,
@ -1123,9 +1119,8 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
) if metadata.runtime_features else False

if self.is_mla_enable:
# for MLA, we only use paged_context_fmha when there is cached kv
use_paged_context_fmha = use_paged_context_fmha and self.has_cached_kv_for_mla_context(
metadata)
# Context MLA uses separate qkv instead of paged_context_fmha
use_paged_context_fmha = False

use_nvfp4_output = False
if enable_attn_nvfp4_output and self.has_nvfp4 and self.support_nvfp4_output(
@ -1149,6 +1144,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
beam_width=metadata.beam_width,
sequence_length=metadata.kv_lens_cuda_runtime,
host_past_key_value_lengths=metadata.kv_lens_runtime,
host_total_kv_lens=metadata.host_total_kv_lens,
context_lengths=metadata.prompt_lens_cuda_runtime,
host_context_lengths=metadata.prompt_lens_cpu_runtime,
host_request_types=metadata.host_request_types_runtime,
@ -1169,9 +1165,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
latent_cache=latent_cache,
q_pe=q_pe,
mrope_config=mrope_config,
mla_context_paged_kv=mla_context_paged_kv,
mla_context_kv_cache_block_offsets=
mla_context_kv_cache_block_offsets,
softmax_stats_tensor=softmax_stats_tensor,
is_spec_decoding_enabled=metadata.is_spec_decoding_enabled,
use_spec_decoding=metadata.use_spec_decoding,
@ -1232,7 +1225,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata: TrtllmAttentionMetadata,
) -> bool:
return (self.is_mla_enable and metadata.kv_cache_manager is not None
and metadata.enable_paged_context_mla
and metadata.enable_context_mla_with_cached_kv
and metadata.num_ctx_cached_tokens > 0)

def is_chunked_prefill_for_mla_context(
@ -1240,7 +1233,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata: TrtllmAttentionMetadata,
) -> bool:
return (self.is_mla_enable and metadata.kv_cache_manager is not None
and metadata.enable_paged_context_mla
and metadata.enable_context_mla_with_cached_kv
and metadata.num_ctx_cached_tokens > 0
and metadata.runtime_features.chunked_prefill)

@ -1248,7 +1241,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self,
metadata: TrtllmAttentionMetadata,
out_dtype: torch.dtype,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert out_dtype in [torch.float16, torch.bfloat16, torch.float32]
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None
@ -1290,15 +1283,19 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
num_ctx_cached_tokens: int,
cu_chunked_seq_len: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert out_dtype in [torch.float16, torch.bfloat16, torch.float32]
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None

if metadata.max_ctx_cached_token_len == 0:
return torch.empty((0, metadata.kv_cache_manager.head_dim),
dtype=out_dtype,
device=cu_chunked_seq_len.device)
empty_kv = torch.empty((0, self.mla_params.kv_lora_rank),
dtype=out_dtype,
device=cu_chunked_seq_len.device)
empty_k_pe = torch.empty((0, self.mla_params.qk_rope_head_dim),
dtype=out_dtype,
device=cu_chunked_seq_len.device)
return empty_kv, empty_k_pe

sink_token_length = 0
beam_width = 1
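
With the return type changed to a tuple, callers of this loader receive the compressed KV and the rotary key part as two tensors; when nothing is cached they come back empty with the MLA widths used in the early return above. A small illustrative sketch of unpacking the pair (names and widths are assumptions, mirroring the code above):

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64   # assumed DeepSeek-V3-like widths

def unpack_cached_kv(output_kv: torch.Tensor, output_k_pe: torch.Tensor):
    # output_kv:   [num_cached_tokens, kv_lora_rank]       compressed latent KV
    # output_k_pe: [num_cached_tokens, qk_rope_head_dim]   rotary key part
    assert output_kv.shape[-1] == kv_lora_rank
    assert output_k_pe.shape[-1] == qk_rope_head_dim
    return output_kv, output_k_pe

# Empty case, mirroring the early return above.
unpack_cached_kv(torch.empty((0, kv_lora_rank)), torch.empty((0, qk_rope_head_dim)))
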
@ -1326,86 +1323,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
)
return output_kv, output_k_pe

def set_paged_kv_cache_for_mla(
self,
paged_kv: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k_pe: torch.Tensor,
metadata: TrtllmAttentionMetadata,
) -> torch.Tensor:
assert self.is_mla_enable and self.mla_params is not None
assert self.mla_params.qk_nope_head_dim == self.mla_params.v_head_dim
assert metadata.kv_cache_manager is not None
assert paged_kv.shape[0] == metadata.num_contexts
assert paged_kv.is_contiguous()

num_contexts = metadata.num_contexts
max_seq_len = metadata.max_ctx_kv_len
tokens_per_block = metadata.kv_cache_manager.tokens_per_block

paged_kv_offsets = torch.ops.trtllm.set_paged_kv_cache_for_mla(
paged_kv,
k,
v,
k_pe,
num_contexts,
metadata.ctx_kv_indptr,
max_seq_len,
self.num_heads,
self.mla_params.qk_nope_head_dim,
self.mla_params.qk_rope_head_dim,
tokens_per_block,
)

max_block_num = (max_seq_len + tokens_per_block - 1) // tokens_per_block
assert paged_kv_offsets.shape == (num_contexts, 2, max_block_num)
return paged_kv_offsets

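
The offsets tensor returned above is sized from the number of KV cache blocks implied by the longest context KV length. A minimal sketch of that arithmetic (toy values, no kernel call):

# Assumed toy values.
num_contexts = 4
max_seq_len = 300        # metadata.max_ctx_kv_len in the code above
tokens_per_block = 64

# Ceiling division: blocks needed to hold max_seq_len tokens.
max_block_num = (max_seq_len + tokens_per_block - 1) // tokens_per_block   # 5

# One offset per request, per K/V component, per block.
expected_offsets_shape = (num_contexts, 2, max_block_num)                  # (4, 2, 5)
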
def set_chunked_kv_cache_for_mla(
self,
paged_kv: torch.Tensor,
kv: torch.Tensor,
k_pe: torch.Tensor,
cu_chunked_seq_len: torch.Tensor,
cached: bool,
metadata: TrtllmAttentionMetadata,
) -> torch.Tensor:
assert self.is_mla_enable and self.mla_params is not None
assert self.mla_params.qk_nope_head_dim == self.mla_params.v_head_dim
assert metadata.kv_cache_manager is not None
assert paged_kv.shape[0] == metadata.num_contexts
assert paged_kv.is_contiguous()

kv = kv.contiguous()
k_pe = k_pe.contiguous()

num_contexts = metadata.num_contexts
tokens_per_block = metadata.kv_cache_manager.tokens_per_block
if cached:
# this indptr is a fake one.
cu_seq_len = cu_chunked_seq_len
max_seq_len = metadata.runtime_features.chunk_size
else:
cu_seq_len = metadata.ctx_uncached_token_indptr
max_seq_len = metadata.max_ctx_seq_len
paged_kv_offsets = torch.ops.trtllm.set_chunked_kv_cache_for_mla(
paged_kv,
kv,
k_pe,
num_contexts,
cu_seq_len,
self.num_heads,
self.mla_params.qk_nope_head_dim,
self.mla_params.qk_rope_head_dim,
metadata.kv_cache_manager.tokens_per_block,
max_seq_len,
)

max_block_num = (max_seq_len + tokens_per_block - 1) // tokens_per_block
assert paged_kv_offsets.shape == (num_contexts, 2, max_block_num)
return paged_kv_offsets

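
The cached and uncached invocations above differ only in which prefix-sum tensor and maximum length bound the copy. A condensed sketch of that selection (attribute names follow the metadata fields used above; this is an illustration, not the full implementation):

def select_chunk_bounds(cached, cu_chunked_seq_len, metadata):
    """Pick the cumulative-length tensor and the max length for the chunked KV copy."""
    if cached:
        # KV produced by earlier chunks: per-chunk lengths, capped by the chunk size.
        return cu_chunked_seq_len, metadata.runtime_features.chunk_size
    # Current (uncached) tokens: lengths follow the uncached-token prefix sums.
    return metadata.ctx_uncached_token_indptr, metadata.max_ctx_seq_len
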
def mla_rope_append_paged_kv_assign_q(
self,
q: torch.Tensor,

@ -28,7 +28,7 @@ class KVCacheParams:
# The number of sink tokens for each layer.
host_sink_token_length: Optional[torch.Tensor] = None
# The number of extra kv for draft tokens
num_extra_kv_tokens: Optional[List[int]] = 0
num_extra_kv_tokens: Optional[int] = 0


class CacheType(Enum):

@ -764,7 +764,6 @@ class DeepseekV3DecoderLayer(DecoderLayer):
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)

if isinstance(self.mlp, Deepseekv3MoE):
return self.forward_MoE(
hidden_states=hidden_states,

@ -835,7 +835,6 @@ class MLA(nn.Module):
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

self.rope_fusion = self.mha.support_fused_rope()
self.support_fused_qkv = self.mha.support_fused_qkv()
self.rotary_emb = None
self.apply_rotary_emb = not self.rope_fusion
if self.apply_rotary_emb:
@ -1054,12 +1053,6 @@ class MLA(nn.Module):
attn_output_gen = None
return output

def _maybe_concat_qkv(self, q, k, v):
if k is not None and v is not None and self.support_fused_qkv:
qkv = torch.concat([q, k, v], dim=-1)
q, k, v = qkv, None, None
return q, k, v

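
A quick shape-level sketch of what the fused-QKV concatenation above does when the backend supports it (toy dimensions, assumed for illustration):

import torch

num_tokens, num_heads = 4, 8
qk_head_dim, v_head_dim = 192, 128          # assumed: (nope + rope) and value head dims

q = torch.randn(num_tokens, num_heads * qk_head_dim)
k = torch.randn(num_tokens, num_heads * qk_head_dim)
v = torch.randn(num_tokens, num_heads * v_head_dim)

support_fused_qkv = True
if k is not None and v is not None and support_fused_qkv:
    # Pack [q | k | v] along the last dim; the backend then reads one fused buffer.
    qkv = torch.concat([q, k, v], dim=-1)   # [num_tokens, num_heads * (2 * qk_head_dim + v_head_dim)]
    q, k, v = qkv, None, None
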
def forward_context_default(
self,
q: torch.Tensor,
@ -1085,9 +1078,6 @@ class MLA(nn.Module):
self.qk_rope_head_dim)
k = k.view(-1, self.num_heads * self.qk_head_dim)

# May concat q(including q_pe), k + k_pe, v together
q, k, v = self._maybe_concat_qkv(q, k, v)

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase

@ -1139,50 +1129,33 @@ class MLA(nn.Module):
],
-1,
)

full_k_nope = full_k_nope.view(-1, self.num_heads,
self.qk_nope_head_dim)
full_v = full_v.view(-1, self.num_heads, self.v_head_dim)

# build paged_full_kv
tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
# paged_full_kv will be initialized to 0 in the kernel to avoid NaN
paged_full_kv = torch.empty([
attn_metadata.num_contexts, 2,
(attn_metadata.max_ctx_kv_len + tokens_per_block - 1) //
tokens_per_block, self.num_heads, tokens_per_block,
max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim)
],
dtype=q.dtype,
device=q.device)
mla_context_kv_cache_block_offsets = trtllm_attention.set_paged_kv_cache_for_mla(
paged_full_kv,
full_k_nope,
full_v,
full_k_pe,
attn_metadata,
)
full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
full_k = torch.cat(
(full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)

# release pytorch activation memory
full_compressed_kv = None
full_k_pe = None
full_kv = None
full_k_nope = None
full_v = None

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase

# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
attn_output = self.mha.forward(
q,
None,
None,
full_k,
full_v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
mla_context_paged_kv=paged_full_kv,
mla_context_kv_cache_block_offsets=
mla_context_kv_cache_block_offsets,
output=output,
)

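
In the separate-QKV context path above, the key handed to the attention op interleaves the per-head no-RoPE part with a single RoPE part shared across heads. A shape sketch of that concatenation, with assumed dimensions:

import torch

num_tokens, num_heads = 6, 8
qk_nope_head_dim, qk_rope_head_dim = 128, 64
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim    # 192

full_k_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
full_k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)   # one RoPE key shared by all heads

# Broadcast the RoPE part across heads and pack [nope | rope] per head.
full_k = torch.cat((full_k_nope, full_k_pe.expand(-1, num_heads, -1)), dim=-1)
full_k = full_k.view(-1, num_heads * qk_head_dim)          # [num_tokens, num_heads * 192]
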
@ -1204,7 +1177,6 @@ class MLA(nn.Module):

# determine the number of loops
# currently we assume that the chunk size is the same as the max_num_tokens
chunk_size = attn_metadata.runtime_features.chunk_size
chunked_loop_num = attn_metadata.chunked_loop_num

# [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2
@ -1229,6 +1201,7 @@ class MLA(nn.Module):
# use fake cached_cu_seq_len for chunked loop
origin_kv_lens_cuda_runtime = attn_metadata.kv_lens_cuda_runtime
origin_kv_lens_runtime = attn_metadata.kv_lens_runtime
origin_ctx_total_kv_len = attn_metadata.host_total_kv_lens[0]

for loop_idx in range(chunked_loop_num):
# {b, chunked_unit_size, h, kv_lora_rank + qk_rope_head_dim} zero padded
@ -1246,45 +1219,48 @@ class MLA(nn.Module):
# up proj to uncompressed kv
# [tokens, 2, h, kv_dim], without rope_dim
chunked_kv = self.kv_b_proj(chunked_compressed_kv)
chunked_k_nope, chunked_v = chunked_kv.split(
[
self.num_heads * self.qk_nope_head_dim,
self.num_heads * self.v_head_dim
],
-1,
)

# build full_kv
# full_kv {B, 2, chunk_size / tokens_per_block, h, tokens_per_block, kv_dim + rope_dim}
tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
full_kv = torch.zeros([
attn_metadata.num_contexts, 2,
(chunk_size + tokens_per_block - 1) // tokens_per_block,
self.num_heads, tokens_per_block,
max(self.qk_nope_head_dim + self.qk_rope_head_dim,
self.v_head_dim)
],
dtype=q.dtype,
device=q.device)
mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla(
full_kv,
chunked_kv,
chunked_k_pe,
cu_chunked_seq_len=temp_cu_chunked_seq_len,
cached=True,
metadata=attn_metadata)
chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
self.qk_nope_head_dim)
chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
chunked_k = torch.cat(
(chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
dim=-1)
chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)

# release pytorch activation memory
chunked_compressed_kv = None
chunked_k_pe = None
chunked_kv = None
chunked_k_nope = None

# copy chunked_seq_len to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.host_chunked_seq_len[
loop_idx]
attn_metadata.kv_lens_cuda_runtime = attn_metadata.chunked_seq_len[
loop_idx]
attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

out_scale = None
# do not apply mask for attention within loop
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
q,
None,
None,
chunked_k,
chunked_v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
attention_mask=PredefinedAttentionMask.FULL,
mla_context_paged_kv=full_kv,
mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
output=temp_attn_output,
)
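
Each loop iteration produces a partial attention output plus per-(token, head) softmax statistics, and the partial results are later renormalized and summed into the final output. The merge itself runs in a kernel, but the underlying (row max, sum of exp) combination can be sketched generically as follows; this is an illustration of the technique, not the exact TensorRT-LLM merge kernel:

import torch

def merge_two_partials(o1, stats1, o2, stats2):
    """Merge two partial attention results computed over disjoint KV chunks.

    o1, o2:         [tokens, heads, head_dim]  partial outputs, each normalized by its own softmax sum
    stats1, stats2: [tokens, heads, 2]         per-(token, head) (row max, sum of exp) pairs
    """
    m1, s1 = stats1[..., 0], stats1[..., 1]
    m2, s2 = stats2[..., 0], stats2[..., 1]
    m = torch.maximum(m1, m2)                # new global row max
    w1 = torch.exp(m1 - m) * s1              # rescaled softmax mass of partial 1
    w2 = torch.exp(m2 - m) * s2
    s = w1 + w2
    merged = (w1 / s).unsqueeze(-1) * o1 + (w2 / s).unsqueeze(-1) * o2
    return merged, torch.stack((m, s), dim=-1)
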
@ -1299,41 +1275,42 @@ class MLA(nn.Module):
_, k_pe = latent_cache.view([
-1, self.kv_lora_rank + self.qk_rope_head_dim
]).split([self.kv_lora_rank, self.qk_rope_head_dim], -1)
k_pe = k_pe.contiguous()
# final round of attention

k_nope, v = kv.split(
[
self.num_heads * self.qk_nope_head_dim,
self.num_heads * self.v_head_dim
],
-1,
)

k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
k = torch.cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
k = k.view(-1, self.num_heads * self.qk_head_dim)

# copy q_lens to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime
attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime
attn_metadata.host_total_kv_lens[
0] = attn_metadata.prompt_lens_cpu_runtime[:attn_metadata.
num_contexts].sum().item(
)

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase

tokens_per_block = attn_metadata.kv_cache_manager.tokens_per_block
full_kv = torch.zeros([
attn_metadata.num_contexts, 2,
(attn_metadata.max_ctx_seq_len + tokens_per_block - 1) //
tokens_per_block, self.num_heads, tokens_per_block,
max(self.qk_nope_head_dim + self.qk_rope_head_dim, self.v_head_dim)
],
dtype=q.dtype,
device=q.device)
mla_kv_cache_block_offsets = trtllm_attention.set_chunked_kv_cache_for_mla(
full_kv,
kv,
k_pe,
cu_chunked_seq_len=None,
cached=False,
metadata=attn_metadata)
# copy q_lens to replace kv_lens_runtime
attn_metadata.kv_lens_runtime = attn_metadata.prompt_lens_cpu_runtime
attn_metadata.kv_lens_cuda_runtime = attn_metadata.prompt_lens_cuda_runtime
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
q,
None,
None,
k,
v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
mla_context_paged_kv=full_kv,
mla_context_kv_cache_block_offsets=mla_kv_cache_block_offsets,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
output=temp_attn_output,
)
@ -1345,6 +1322,7 @@ class MLA(nn.Module):
# copy back kv_lens_runtime and kv_lens_cuda_runtime
attn_metadata.kv_lens_runtime = origin_kv_lens_runtime
attn_metadata.kv_lens_cuda_runtime = origin_kv_lens_cuda_runtime
attn_metadata.host_total_kv_lens[0] = origin_ctx_total_kv_len

return attn_output


@ -823,7 +823,7 @@ class PyTorchModelEngine(ModelEngine):
self.enable_spec_decode = self.is_spec_decode

def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
enable_paged_context_mla = is_mla(
enable_context_mla_with_cached_kv = is_mla(
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse
or self.attn_runtime_features.chunked_prefill)
@ -837,7 +837,8 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_paged_context_mla=enable_paged_context_mla,
enable_context_mla_with_cached_kv=
enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)

if self.attn_metadata is not None:
@ -854,7 +855,7 @@ class PyTorchModelEngine(ModelEngine):
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_paged_context_mla=enable_paged_context_mla,
enable_context_mla_with_cached_kv=enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)

return self.attn_metadata

@ -290,11 +290,13 @@ def create_py_executor(
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
)

if executor_config.kv_cache_config.enable_block_reuse and not (
get_sm_version() >= 90 and get_sm_version() <= 100):
sm_version = get_sm_version()
if executor_config.kv_cache_config.enable_block_reuse and sm_version not in [
90, 100, 120
]:
logger.warning(
f"KV cache reuse for MLA can only be enabled on SM90/SM100, "
f"disable enable_block_reuse for SM{get_sm_version()}")
f"KV cache reuse for MLA can only be enabled on SM90/SM100/SM120, "
f"disable enable_block_reuse for SM{sm_version}")
executor_config.kv_cache_config.enable_block_reuse = False

kv_cache_quant_algo = model_engine.model.model_config.quant_config.kv_cache_quant_algo
@ -306,11 +308,12 @@ def create_py_executor(
f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}"
)
executor_config.kv_cache_config.enable_block_reuse = False
if executor_config.enable_chunked_context and not (
get_sm_version() == 100 or get_sm_version() == 90):
if executor_config.enable_chunked_context and sm_version not in [
90, 100
]:
logger.warning(
"Chunked Prefill for MLA can only be enabled on SM90/100, "
f"disable enable_block_reuse for SM{get_sm_version()}")
"Chunked Prefill for MLA can only be enabled on SM90/SM100, "
f"disable enable_chunked_context for SM{sm_version}")
executor_config.enable_chunked_context = False
model_engine.attn_runtime_features.chunked_prefill = False
if draft_model_engine is not None:
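
The net effect of the two guards above is a small SM-version capability table: MLA KV cache block reuse is allowed on SM90/SM100/SM120, while chunked prefill stays limited to SM90/SM100. A condensed sketch of that gating, reusing the same config fields as above:

def apply_mla_sm_gating(executor_config, attn_runtime_features, sm_version, logger):
    """Disable MLA features the current SM version cannot serve."""
    if executor_config.kv_cache_config.enable_block_reuse and sm_version not in (90, 100, 120):
        logger.warning("KV cache reuse for MLA can only be enabled on SM90/SM100/SM120, "
                       f"disable enable_block_reuse for SM{sm_version}")
        executor_config.kv_cache_config.enable_block_reuse = False
    if executor_config.enable_chunked_context and sm_version not in (90, 100):
        logger.warning("Chunked Prefill for MLA can only be enabled on SM90/SM100, "
                       f"disable enable_chunked_context for SM{sm_version}")
        executor_config.enable_chunked_context = False
        attn_runtime_features.chunked_prefill = False
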