TensorRT-LLMs/cpp/tensorrt_llm/kernels/mlaKernels.cu
zhhuang-nv a891013e3c
[feat] Optimize KV Cache Reuse for MLA (#4869)
Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
2025-06-13 11:03:05 +08:00

1039 lines
49 KiB
Plaintext

/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/mathUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
{
// Stateful prefix functor for cub::BlockScan. It carries a running total across consecutive
// block-wide scans, so a sequence longer than one thread block can be scanned tile by tile.
struct BlockPrefixCallbackOp
{
    // Sum of all tile aggregates observed so far, i.e. the prefix handed to the next scan.
    int mRunningTotal;

    // Seeds the running prefix with `initialTotal`.
    __device__ BlockPrefixCallbackOp(int initialTotal)
        : mRunningTotal{initialTotal}
    {
    }

    // Called by the block scan. Thread 0's return value seeds the block-wide scan.
    // Returns the prefix accumulated before this tile, then advances the total by the tile aggregate.
    __device__ int operator()(int tileAggregate)
    {
        int const prefixBeforeTile = mRunningTotal;
        mRunningTotal = prefixBeforeTile + tileAggregate;
        return prefixBeforeTile;
    }
};
// Maps an element type T to:
//   Type        - the 16-byte vector used for vectorized global loads/stores in this file;
//   GPTJEltType - the element pair rotated together by the GPT-J style rotary helpers.
// The primary template falls back to T itself for both.
template <typename T>
struct VecType
{
    typedef T Type;
    typedef T GPTJEltType;
};

// fp32: float4 vector; rotary pairs are float2.
template <>
struct VecType<float>
{
    typedef float4 Type;
    typedef float2 GPTJEltType;
};

// fp16: the vector is carried in a uint4; a rotary pair (two halves) is packed in a uint32_t.
template <>
struct VecType<half>
{
    typedef uint4 Type;
    typedef uint32_t GPTJEltType;
};

// bf16: mmha::bf16_8_t vector; rotary pairs are __nv_bfloat162.
template <>
struct VecType<__nv_bfloat16>
{
    typedef mmha::bf16_8_t Type;
    typedef __nv_bfloat162 GPTJEltType;
};

// Sixteen FP8 (e4m3) values in one 16-byte aligned vector, held as four 4-value chunks.
struct __align__(16) fp8_16_t
{
    __nv_fp8x4_e4m3 x, y, z, w;
};

// fp8 e4m3: fp8_16_t vector; rotary pairs are __nv_fp8x2_e4m3.
template <>
struct VecType<__nv_fp8_e4m3>
{
    typedef fp8_16_t Type;
    typedef __nv_fp8x2_e4m3 GPTJEltType;
};
// Compile-time layout and launch constants for loadPagedKVCacheForMLAKernel.
// A cached token row holds kLoraSize compressed-KV elements followed by kRopeSize rope elements.
// One thread moves one 16-byte vector, so kThreadPerHead threads cover a row. A block of
// kBlockSize threads handles kTokenPerBlock tokens.
template <typename T>
struct loadPagedKVKernelTraits
{
    // Row layout.
    static constexpr int kLoraSize = 512;
    static constexpr int kRopeSize = 64;
    static constexpr int kHeadSize = kLoraSize + kRopeSize;

    // Vectorized access.
    using VecT = typename VecType<T>::Type;
    static constexpr int kBytesPerElem = sizeof(T);
    static constexpr int kBytesPerLoad = 16;
    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");

    // Thread mapping: every thread reads exactly one vector of a row.
    static constexpr int kVecPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
    static constexpr int kThreadPerHead = kVecPerHead;
    // Threads [0, kKVThreadPerHead) of a row cover the compressed-KV part; the rest cover the rope part.
    static constexpr int kKVThreadPerHead = (kLoraSize * kBytesPerElem) / kBytesPerLoad;
    // Tokens per block: 4 for fp32, 8 for every other element type.
    static constexpr int kTokenPerBlock = std::is_same_v<T, float> ? 4 : 8;
    static constexpr int kBlockSize = kThreadPerHead * kTokenPerBlock;
};
// Compile-time layout and launch constants for setPagedKVCacheForMLAKernel.
// A head row is kQKNopeSize no-rope elements (K or V) followed by kRopeSize rope elements.
// One thread moves one 16-byte vector, so kThreadPerHead threads cover a row. A block of
// kBlockSize threads handles kCpTokenPerBlock tokens.
template <typename T>
struct setPagedKVKernelTraits
{
    // Per-head dimensions.
    static constexpr int kQKNopeSize = 128;
    static constexpr int kVHeadSize = 128;
    static_assert(kQKNopeSize == kVHeadSize);
    static constexpr int kRopeSize = 64;
    static constexpr int kHeadSize = kQKNopeSize + kRopeSize;
    static constexpr int kNumHeads = 128;

    // Vectorized access.
    using VecT = typename VecType<T>::Type;
    static constexpr int kBytesPerElem = sizeof(T);
    static constexpr int kBytesPerLoad = 16;
    static constexpr int kElemPerLoad = kBytesPerLoad / kBytesPerElem;
    static_assert((kHeadSize * kBytesPerElem) % kBytesPerLoad == 0,
        "kHeadSize * kBytesPerElem must be multiple of kBytesPerLoad (16Bytes)");

    // Thread mapping.
    static constexpr int kThreadPerHead = (kHeadSize * kBytesPerElem) / kBytesPerLoad;
    // Threads [0, kKVThreadPerHead) of a row cover the no-rope part.
    static constexpr int kKVThreadPerHead = (kQKNopeSize * kBytesPerElem) / kBytesPerLoad;
    static constexpr int kCpTokenPerBlock = 16;
    static constexpr int kBlockSize = kThreadPerHead * kCpTokenPerBlock;
};
// Scales NUM source elements by scale_val, converts them to FP8 e4m3 and stores them at
// dst_global_ptr.
// SrcType is expected to be a 2-byte type with a TypeConverter pair type (half/bf16), or float.
// Values are converted two at a time. The packed FP8 bytes are written in chunks of
// sizeof(DstVecType): a float2 (8 FP8 values) for 2-byte sources, a float (4 FP8 values) otherwise.
template <typename SrcType, int NUM>
inline __device__ void quantCopy(
    __nv_fp8_e4m3* dst_global_ptr, SrcType const* src_fragment_ptr, float const scale_val = 1.f)
{
    using DstVecType = typename std::conditional<sizeof(SrcType) == 2, float2, float>::type;
    // Source pair type read per conversion step.
    using SrcType2 =
        typename std::conditional<sizeof(SrcType) == 2, typename TypeConverter<SrcType>::Type, float2>::type;

    constexpr int kStoreBytes = sizeof(DstVecType);
    constexpr int kTotalBytes = NUM * sizeof(__nv_fp8_e4m3);
    static_assert(kTotalBytes % kStoreBytes == 0);
    static_assert(kStoreBytes % (sizeof(__nv_fp8_e4m3) * 2) == 0);
    constexpr int kNumStores = kTotalBytes / kStoreBytes;
    constexpr int kPairsPerStore = kStoreBytes / sizeof(__nv_fp8_e4m3) / 2;

    auto const* src_pairs = reinterpret_cast<SrcType2 const*>(src_fragment_ptr);
    auto* dst_vecs = reinterpret_cast<DstVecType*>(dst_global_ptr);

#pragma unroll
    for (int store = 0; store < kNumStores; ++store)
    {
        // Every pair slot of `packed` is overwritten below before it is stored.
        DstVecType packed;
        auto* packed_pairs = reinterpret_cast<__nv_fp8x2_e4m3*>(&packed);
#pragma unroll
        for (int pair = 0; pair < kPairsPerStore; ++pair)
        {
            float2 scaled = cuda_cast<float2>(src_pairs[store * kPairsPerStore + pair]);
            scaled.x *= scale_val;
            scaled.y *= scale_val;
            packed_pairs[pair] = __nv_fp8x2_e4m3(scaled);
        }
        dst_vecs[store] = packed;
    }
}
// Converts NUM FP8 e4m3 values at src_fragment_ptr to DstType, scales them by scale_val and
// stores them at dst_global_ptr.
// DstType is expected to be a 2-byte type with a TypeConverter pair type (half/bf16), or float.
// Values are converted two at a time. Results are written in chunks of VecType<DstType>::Type.
template <typename DstType, int NUM>
inline __device__ void dequantCopy(
    DstType* dst_global_ptr, __nv_fp8_e4m3 const* src_fragment_ptr, float const scale_val = 1.f)
{
    using DstVecType = typename VecType<DstType>::Type;
    // Destination pair type written per conversion step.
    using DstType2 =
        typename std::conditional<sizeof(DstType) == 2, typename TypeConverter<DstType>::Type, float2>::type;

    constexpr int kStoreBytes = sizeof(DstVecType);
    constexpr int kTotalBytes = NUM * sizeof(DstType);
    static_assert(kTotalBytes % kStoreBytes == 0);
    static_assert(kStoreBytes % (sizeof(DstType) * 2) == 0);
    constexpr int kNumStores = kTotalBytes / kStoreBytes;
    constexpr int kPairsPerStore = kStoreBytes / sizeof(DstType) / 2;

    auto const* src_pairs = reinterpret_cast<__nv_fp8x2_e4m3 const*>(src_fragment_ptr);
    auto* dst_vecs = reinterpret_cast<DstVecType*>(dst_global_ptr);

#pragma unroll
    for (int store = 0; store < kNumStores; ++store)
    {
        // Every pair slot of `unpacked` is overwritten below before it is stored.
        DstVecType unpacked;
        auto* unpacked_pairs = reinterpret_cast<DstType2*>(&unpacked);
#pragma unroll
        for (int pair = 0; pair < kPairsPerStore; ++pair)
        {
            float2 scaled = cuda_cast<float2>(src_pairs[store * kPairsPerStore + pair]);
            scaled.x *= scale_val;
            scaled.y *= scale_val;
            unpacked_pairs[pair] = cuda_cast<DstType2>(scaled);
        }
        dst_vecs[store] = unpacked;
    }
}
// Context-phase MLA kernel. Applies GPT-J style rotary embedding to the rope slices of Q and K,
// then scatters the results into qkv_output and the paged KV cache.
//
// Layouts, as implied by the offsets below:
//   fuse_buf   : per token, c_k compressed-KV elements followed by ROPE_DIM K-rope elements.
//   qkv_output : per token, head_num x (head_size + ROPE_DIM) Q elements, then
//                head_num x (head_size + ROPE_DIM) K elements, then head_num x head_size more
//                (presumably V). The rope slice of each Q/K head row starts at column head_size.
//   KV cache   : per token row, K_DIM compressed-KV vectors' worth followed by the rope slice.
//
// Grid, as read by this kernel:
//   blockIdx.y : batch index. The sequence starts at row cu_q_seqlens[b] and has kv_cache_lengths[b] tokens.
//   blockIdx.z : < head_num  -> rotate Q head `blockIdx.z` and the K rope, and write both into
//                               qkv_output. Head 0 also stores the rotated K rope in the KV cache.
//                >= head_num -> copy the compressed-KV slice of fuse_buf into the KV cache.
//   blockIdx.x : strides over the tokens of the sequence.
// If cache_type is FP8, KV-cache writes are quantized with quant_scale_kv[0] (1.0 if null).
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelOptContext(T* qkv_output, T const* fuse_buf, KVCacheBuffer kv_cache,
    float2 const* cos_sin_cache, size_t head_num, int head_size, int c_k, int* cu_q_seqlens,
    int32_t const* kv_cache_lengths, uint32_t max_input_seq_len, KvCacheDataType cache_type,
    float const* quant_scale_kv)
{
    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VECS_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;

    // Block/Head idx.
    size_t const batch_idx = blockIdx.y;
    size_t const head_idx = blockIdx.z;

    // Per-sequence values; they do not change across the token loops below.
    int const global_token_offset = cu_q_seqlens[batch_idx];
    int const cache_seq_len = kv_cache_lengths[batch_idx];
    if (cache_seq_len <= 0)
    {
        // Nothing to write. All threads of the block share batch_idx, so this exit is block-uniform
        // and no lane is left waiting at the __syncwarp() below. It also keeps the clamped token
        // index from becoming -1.
        return;
    }
    float const quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.f;

    if (head_idx < head_num)
    {
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        size_t const seq_len_loop_end
            = size_t((max_input_seq_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        // Mainloop. seq_len_loop_end and the stride are multiples of TOKENS_PER_BLOCK, so as long as
        // the loop variable is only advanced by the stride, all threads of the block run the same
        // number of iterations and meet at __syncwarp() together.
        for (int local_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {
            bool const valid_token = local_token_idx < cache_seq_len;
            // Out-of-range threads still load from the last valid token, so every thread takes part in
            // the iteration. Clamp a copy, not the loop variable: a clamped loop variable can drop back
            // below seq_len_loop_end and the loop would never terminate.
            int const token_idx_in_kv_cache = std::min(local_token_idx, cache_seq_len - 1);
            int const global_token_idx = token_idx_in_kv_cache + global_token_offset;
            auto const position_id = token_idx_in_kv_cache;
            float2 const* rotary_coef_cache_buffer
                = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);
            VecT q, k;
            auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
            auto const src_q_global_offset
                = static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
                + (head_size + ROPE_DIM) * head_idx + head_size;
            q = *reinterpret_cast<VecT const*>(&qkv_output[src_q_global_offset + head_dim_idx]);
            k = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
            // Pack two elements into one for gptj rotary embedding.
#pragma unroll
            for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
            {
                GPTJEltT& q_ = reinterpret_cast<GPTJEltT*>(&q)[elt_id];
                GPTJEltT& k_ = reinterpret_cast<GPTJEltT*>(&k)[elt_id];
                float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                mmha::apply_rotary_embedding_gptj(q_, k_, rotary_coef_cache);
            }
            // do sync
            __syncwarp();
            if (valid_token)
            {
                if (head_idx == 0)
                {
                    // The K rope is shared by all heads; only head 0 stores it in the KV cache, after the
                    // K_DIM part of the row.
                    auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
                    auto inBlockIdx = kv_cache.getKVLocalIdx(
                        token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                            reinterpret_cast<T const*>(&k), quant_scale_kv_val);
                    }
                    else
                        reinterpret_cast<VecT*>(kDst)[inBlockIdx] = k;
                }
                auto const dst_q_idx
                    = static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
                    + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
                auto const dst_k_idx
                    = static_cast<size_t>(global_token_idx) * head_num * ((head_size + ROPE_DIM) * 2 + head_size)
                    + head_num * (head_size + ROPE_DIM) + head_idx * (head_size + ROPE_DIM) + head_size + head_dim_idx;
                reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = q;
                reinterpret_cast<VecT*>(qkv_output)[dst_k_idx / ELTS_PER_VEC] = k;
            }
        }
    }
    else
    {
        // Copy blocks: the slots [head_num, gridDim.z) jointly stride over the tokens.
        int block_dim = gridDim.z - head_num;
        int block_id = head_idx - head_num;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        size_t const seq_len_loop_end
            = size_t((max_input_seq_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        // Mainloop.
        for (int local_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             local_token_idx < seq_len_loop_end; local_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {
            if (local_token_idx >= cache_seq_len)
            {
                // Token indices only grow, and this branch has no warp-level sync, so stop here.
                break;
            }
            int const token_idx_in_kv_cache = local_token_idx;
            int const global_token_idx = local_token_idx + global_token_offset;
            auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);
            auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
            auto inBlockIdx = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, head_dim_vec_idx);
            if (cache_type == KvCacheDataType::FP8)
            {
                quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                    fuse_buf + src_k_global_offset + head_dim_idx, quant_scale_kv_val);
            }
            else
                reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                    = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
        }
    }
}
// Generation-phase MLA kernel. For every new token it:
//   - applies GPT-J style rotary embedding to the rope slices of Q and K;
//   - appends K (compressed KV + rotated rope) to the paged KV cache;
//   - prepares Q for attention (FP8-quantized into quant_q when cache_type == FP8);
//   - fills FMHA metadata: tile counter, cumulative Q/KV offsets and FP8 BMM scales.
//
// Tokens are flattened: global token t belongs to batch t / seq_len, and there are total_s_len tokens.
// Work split over blockIdx.y (head_idx):
//   [0, head_num)                   : rotate q_pe of head `head_idx`. The result goes to column c_k of
//                                     that head's (c_k + ROPE_DIM) row in qkv_output, or is
//                                     quantized into quant_q for FP8.
//   head_num                        : rotate the K rope slice of fuse_buf and store it in the KV cache
//                                     after the K_DIM part.
//   (head_num, head_num + 8]        : copy the first K_DIM elements of each fuse_buf row into the KV
//                                     cache, and write seqQOffset[b + 1].
//   > head_num + 8 (FP8 only)       : quantize the first K_DIM elements of every Q head row of
//                                     qkv_output into quant_q.
// blockIdx.x strides over tokens. Block (0, 0) also writes seqKVOffsets[0..batch] as the exclusive
// prefix sum of kv_cache_lengths. cub::BlockScan<int, BLOCK_SIZE> assumes blockDim.x == BLOCK_SIZE
// (TODO confirm against the launcher).
template <typename T, int BLOCK_SIZE, int K_DIM, int ROPE_DIM, typename KVCacheBuffer>
__global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe, T const* fuse_buf, void* quant_q,
    KVCacheBuffer kv_cache, float2 const* cos_sin_cache, size_t head_num, int c_k, int total_s_len, int seq_len,
    int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
    int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
    float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
    float const* dequant_scale_kv, float host_bmm1_scale)
{
    // Constants.
    using VecT = typename VecType<T>::Type;
    using GPTJEltT = typename VecType<T>::GPTJEltType;
    constexpr auto HEAD_SIZE = ROPE_DIM;
    constexpr auto K_HEAD_SIZE = K_DIM;
    constexpr auto BYTES_PER_ELT = sizeof(T);
    constexpr auto BYTES_PER_LOAD = 16;
    constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
    static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
    constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
    static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
    constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
    constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
    constexpr auto TOTAL_VEC_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;
    // Number of blockIdx.y slots, following the K-rope slot, that copy compressed KV into the cache.
    constexpr int KV_COPY_BLOCKS = 8;
    // Block/Head idx.
    size_t const head_idx = blockIdx.y;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif
    if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
    {
        fmha_tile_counter[0] = 0;
        seqQOffset[0] = 0;
        // Calculate bmm scale for FP8 MLA
        if (cache_type == KvCacheDataType::FP8)
        {
            float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
            float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
            float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
            if (bmm1_scale)
            {
                // The scale prepared for log2 optimization.
                constexpr float kLog2e = 1.4426950408889634074f;
                // The scale after fmha bmm1.
                float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
                bmm1_scale[0] = bmm1_scale_val;
                bmm1_scale[1] = bmm1_scale_val * kLog2e;
            }
            if (bmm2_scale)
            {
                // The scale after fmha bmm2.
                bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
            }
        }
    }
    if (head_idx <= head_num)
    {
        // Rope blocks: Q heads [0, head_num) and the K rope (head_idx == head_num).
        size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        int const seq_len_loop_end = size_t((total_s_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
        float const quant_scale_q_val = quant_scale_q ? quant_scale_q[0] : 1.0f;
        float const quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.0f;
        // Mainloop.
        for (int global_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
             global_token_idx < seq_len_loop_end; global_token_idx += TOKENS_PER_BLOCK * gridDim.x)
        {
            auto batch_idx = global_token_idx / seq_len;
            auto local_token_idx = global_token_idx % seq_len;
            bool const valid_token = global_token_idx < total_s_len;
            VecT data;
            if (valid_token)
            {
                auto const position_id = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
                float2 const* rotary_coef_cache_buffer
                    = cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);
                if (head_idx == head_num)
                {
                    auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM) + c_k;
                    data = *reinterpret_cast<VecT const*>(&fuse_buf[src_k_global_offset + head_dim_idx]);
                }
                else
                {
                    auto const src_q_global_offset
                        = static_cast<size_t>(global_token_idx) * q_pe_stride + q_pe_ld * head_idx;
                    data = *reinterpret_cast<VecT const*>(&q_pe[src_q_global_offset + head_dim_idx]);
                }
                // Pack two elements into one for gptj rotary embedding.
#pragma unroll
                for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
                {
                    GPTJEltT& data_ = reinterpret_cast<GPTJEltT*>(&data)[elt_id];
                    float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
                    data_ = mmha::rotary_embedding_transform(data_, rotary_coef_cache);
                }
            }
            __syncwarp();
            if (valid_token)
            {
                if (head_idx == head_num)
                {
                    auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
                    {
                        auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
                        auto inBlockIdx = kv_cache.getKVLocalIdx(
                            token_kv_idx, 0, TOTAL_VEC_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
                        if (cache_type == KvCacheDataType::FP8)
                        {
                            quantCopy<T, ELTS_PER_VEC>(
                                reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                                reinterpret_cast<T const*>(&data), quant_scale_kv_val);
                        }
                        else
                            reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
                    }
                }
                else
                {
                    auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (c_k + ROPE_DIM)
                        + head_idx * (c_k + ROPE_DIM) + c_k + head_dim_idx;
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(quant_q) + dst_q_idx,
                            reinterpret_cast<T const*>(&data), quant_scale_q_val);
                    }
                    else
                        reinterpret_cast<VecT*>(qkv_output)[dst_q_idx / ELTS_PER_VEC] = data;
                }
            }
        }
    }
    else if (head_idx <= head_num + KV_COPY_BLOCKS)
    {
        // KV copy blocks. Only KV_COPY_BLOCKS slots take this branch, even when extra FP8
        // Q-quantization blocks follow them in gridDim.y. The token stride must therefore count only
        // these blocks; counting the whole remainder of gridDim.y would leave token groups unvisited.
        int const launched_copy_blocks = static_cast<int>(gridDim.y - head_num - 1);
        int block_dim = launched_copy_blocks < KV_COPY_BLOCKS ? launched_copy_blocks : KV_COPY_BLOCKS;
        int block_id = head_idx - head_num - 1;
        size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
        size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
        size_t const seq_len_loop_end
            = size_t((total_s_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
        float quant_scale_kv_val = quant_scale_kv ? quant_scale_kv[0] : 1.0f;
        // Mainloop.
        for (int global_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
                 + blockIdx.x * K_TOKENS_PER_BLOCK;
             global_token_idx < seq_len_loop_end; global_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
        {
            auto batch_idx = global_token_idx / seq_len;
            auto local_token_idx = global_token_idx % seq_len;
            bool valid_token = global_token_idx < total_s_len;
            if (valid_token)
            {
                if (head_dim_vec_idx == 0)
                {
                    // Every token of a batch writes the same value, so repeated writes are harmless.
                    seqQOffset[batch_idx + 1] = head_num * seq_len * (batch_idx + 1);
                }
                auto const token_kv_idx = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
                auto const src_kv_global_offset = static_cast<size_t>(global_token_idx) * (c_k + ROPE_DIM);
                {
                    auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_kv_idx));
                    auto inBlockIdx = kv_cache.getKVLocalIdx(token_kv_idx, 0, TOTAL_VEC_PER_HEAD, head_dim_vec_idx);
                    if (cache_type == KvCacheDataType::FP8)
                    {
                        quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
                            fuse_buf + src_kv_global_offset + head_dim_idx, quant_scale_kv_val);
                    }
                    else
                        reinterpret_cast<VecT*>(kDst)[inBlockIdx]
                            = *reinterpret_cast<VecT const*>(&fuse_buf[src_kv_global_offset + head_dim_idx]);
                }
            }
        }
    }
    else
    {
        if (cache_type == KvCacheDataType::FP8)
        {
            // Q-quantization blocks: quantize the first K_DIM elements of each Q head row.
            int block_dim = gridDim.y - head_num - 1 - KV_COPY_BLOCKS;
            int block_id = head_idx - head_num - 1 - KV_COPY_BLOCKS;
            size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
            size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
            size_t const head_num_idx = (block_id % head_num) * (K_HEAD_SIZE + HEAD_SIZE);
            size_t const seq_len_loop_end
                = size_t((total_s_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
            float quant_scale_q_val = quant_scale_q ? quant_scale_q[0] : 1.0f;
            // Blocks form groups of head_num (one block per head), and each group takes its own slice
            // of tokens. If fewer than head_num blocks were launched, fall back to a single group;
            // otherwise the stride would be zero and the loop would never end.
            size_t const token_groups
                = static_cast<size_t>(block_dim) >= head_num ? static_cast<size_t>(block_dim) / head_num : 1;
            // Mainloop.
            for (int global_token_idx = (threadIdx.x / K_VECS_PER_HEAD)
                     + (block_id / head_num) * gridDim.x * K_TOKENS_PER_BLOCK + blockIdx.x * K_TOKENS_PER_BLOCK;
                 global_token_idx < seq_len_loop_end;
                 global_token_idx += token_groups * gridDim.x * K_TOKENS_PER_BLOCK)
            {
                if (global_token_idx < total_s_len)
                {
                    size_t const load_idx
                        = global_token_idx * head_num * (K_HEAD_SIZE + HEAD_SIZE) + head_num_idx + head_dim_idx;
                    quantCopy<T, ELTS_PER_VEC>(
                        reinterpret_cast<__nv_fp8_e4m3*>(quant_q) + load_idx, qkv_output + load_idx, quant_scale_q_val);
                }
            }
        }
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
    // The implementation of the parallel scan in the thread block (see CUB for details).
    using BlockScan = cub::BlockScan<int, BLOCK_SIZE>;
    // Allocate storage in shared memory to do the scan.
    __shared__ typename BlockScan::TempStorage tempKVStorage;
    BlockPrefixCallbackOp prefixKVOp(0);
    if (blockIdx.x == 0 && blockIdx.y == 0)
    {
        // seqKVOffsets[b] = sum of kv_cache_lengths[0..b), for b in [0, batch]; written BLOCK_SIZE
        // entries per iteration, with prefixKVOp carrying the total across iterations.
        int const batchSizeBound = total_s_len / seq_len;
        for (int batchOffset = 0; batchOffset <= batchSizeBound; batchOffset += BLOCK_SIZE)
        {
            // The index of the batch.
            int batchIdx = batchOffset + threadIdx.x;
            int seqKVLength = 0;
            if (batchIdx < batchSizeBound)
            {
                seqKVLength = kv_cache_lengths[batchIdx];
            }
            int seqKVOffset;
            BlockScan(tempKVStorage).ExclusiveSum(seqKVLength, seqKVOffset, prefixKVOp);
            if (batchIdx <= batchSizeBound)
            {
                seqKVOffsets[batchIdx] = seqKVOffset;
            }
            // tempKVStorage is reused by the next iteration's scan, and CUB requires a block barrier
            // before temp storage is reused. The enclosing branch and the trip count are block-uniform,
            // so every thread of the block reaches this barrier.
            __syncthreads();
        }
    }
}
// Copies the cached MLA KV of each sequence out of the paged KV cache into contiguous buffers:
//   compressed_kv {total_cached_tokens, kLoraSize}
//   k_pe          {total_cached_tokens, kRopeSize}
// When TCache is FP8 the values are dequantized (multiplied by kv_scale_quant_orig_ptr[0], or 1.0 if
// null). Grid, as read by this kernel:
//   blockIdx.y : batch index. Its tokens occupy rows [cu_ctx_cached_kv_lens[b], cu_ctx_cached_kv_lens[b + 1]).
//   blockIdx.x : strides over tokens.
// Each token uses loadPagedKVKernelTraits::kThreadPerHead threads, one 16-byte vector each.
template <typename T, typename TCache>
__global__ void loadPagedKVCacheForMLAKernel(T* compressed_kv_ptr, T* k_pe_ptr,
    tensorrt_llm::kernels::KVBlockArray const kv_cache, int64_t const* cu_ctx_cached_kv_lens, int max_input_seq_len,
    float const* kv_scale_quant_orig_ptr)
{
    static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
        "TCache must be either the same type as T or __nv_fp8_e4m3");
    using KT = typename tensorrt_llm::kernels::loadPagedKVKernelTraits<TCache>;
    int const batch_idx = static_cast<int>(blockIdx.y);
    float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
    size_t const head_dim_vec_idx = (threadIdx.x % KT::kVecPerHead);
    size_t const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
    // Threads below kKVThreadPerHead move compressed KV; the remaining threads move the rope part.
    bool const is_valid_kv = head_dim_vec_idx < KT::kKVThreadPerHead;
    size_t const seq_len_loop_end
        = (max_input_seq_len + KT::kTokenPerBlock - 1) / KT::kTokenPerBlock * KT::kTokenPerBlock;
    int64_t const global_token_offset = cu_ctx_cached_kv_lens[batch_idx];
    int64_t const cache_kv_len = cu_ctx_cached_kv_lens[batch_idx + 1] - cu_ctx_cached_kv_lens[batch_idx];
    for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kTokenPerBlock;
         local_token_idx < seq_len_loop_end; local_token_idx += KT::kTokenPerBlock * gridDim.x)
    {
        int token_idx_in_kv_cache = local_token_idx;
        bool const valid_token = token_idx_in_kv_cache < cache_kv_len;
        if (valid_token)
        {
            auto* kvSrc = reinterpret_cast<TCache*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
            // head_idx === 0
            auto kvBlockIdx
                = kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, KT::kVecPerHead, static_cast<int>(head_dim_vec_idx));
            auto src_data = reinterpret_cast<typename KT::VecT*>(kvSrc)[kvBlockIdx];
            // 64-bit on purpose: the flattened token index times the row width can exceed INT_MAX
            // once the batch holds millions of cached tokens.
            int64_t const global_token_idx = local_token_idx + global_token_offset;
            if (is_valid_kv)
            {
                // compressed_kv {total_token, lora_size}
                int64_t const dstIdx = global_token_idx * KT::kLoraSize + static_cast<int64_t>(head_dim_idx);
                // copy back to compressed_kv
                if constexpr (std::is_same_v<TCache, T>)
                {
                    *reinterpret_cast<typename KT::VecT*>(compressed_kv_ptr + dstIdx) = src_data;
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    dequantCopy<T, KT::kElemPerLoad>(compressed_kv_ptr + dstIdx,
                        reinterpret_cast<__nv_fp8_e4m3 const*>(&src_data), kv_scale_quant_orig);
                }
            }
            else
            {
                // k_pe {total_token, rope_size}
                int64_t const dstIdx
                    = global_token_idx * KT::kRopeSize + (static_cast<int64_t>(head_dim_idx) - KT::kLoraSize);
                // copy back to k_pe
                if constexpr (std::is_same_v<TCache, T>)
                {
                    *reinterpret_cast<typename KT::VecT*>(k_pe_ptr + dstIdx) = src_data;
                }
                else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    dequantCopy<T, KT::kElemPerLoad>(
                        k_pe_ptr + dstIdx, reinterpret_cast<__nv_fp8_e4m3 const*>(&src_data), kv_scale_quant_orig);
                }
            }
        }
    }
}
// k {total_token, h, d}, v {total_token, h, d}, k_pe {total_token, h=1, d_rope}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, d}
//
// Scatters contiguous per-token K/V (row stride kv_token_stride, kv_dim elements per head) and the
// single-head k_pe into the blocked `output` layout above, where the last dimension is
// kv_dim + rope_dim. K rows get the rope part appended; V rows only receive their first kv_dim
// elements. Grid, as read by this kernel: blockIdx.y = batch (tokens
// [cu_seq_lens[b], cu_seq_lens[b + 1])), blockIdx.z = head, blockIdx.x strides over tokens.
// All offsets are 64-bit: the flattened output offset grows with
// batch * max_seq * num_heads * (kv_dim + rope_dim), which exceeds INT_MAX for realistic sizes.
template <typename T>
__global__ void setPagedKVCacheForMLAKernel(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
    int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
    using KT = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;
    using VecT = typename KT::VecT;
    int const batch_idx = static_cast<int>(blockIdx.y);
    int const head_idx = static_cast<int>(blockIdx.z);
    int const head_dim_vec_idx = (threadIdx.x % KT::kThreadPerHead);
    int const head_dim_idx = head_dim_vec_idx * KT::kElemPerLoad;
    // Threads covering the first kVHeadSize elements copy K and V; the rest copy the rope part.
    bool const is_valid_v = head_dim_idx < KT::kVHeadSize;
    size_t const seq_len_loop_end
        = (max_input_seq_len + KT::kCpTokenPerBlock - 1) / KT::kCpTokenPerBlock * KT::kCpTokenPerBlock;
    size_t const token_row_size = static_cast<size_t>(kv_dim + rope_dim);
    // Elements in one (batch, k|v, cache-block) tile, and tiles per sequence.
    size_t const kv_cache_block_size = static_cast<size_t>(num_heads) * kv_cache_tokens_per_block * token_row_size;
    size_t const kv_cache_block_num = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    int64_t const global_token_offset = cu_seq_lens[batch_idx];
    int64_t const cache_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx];
    for (int local_token_idx = (threadIdx.x / KT::kThreadPerHead) + blockIdx.x * KT::kCpTokenPerBlock;
         local_token_idx < seq_len_loop_end; local_token_idx += KT::kCpTokenPerBlock * gridDim.x)
    {
        if (local_token_idx >= cache_kv_len)
        {
            // Token indices only grow from here on, so this thread has no more work.
            break;
        }
        int64_t const global_token_idx = global_token_offset + local_token_idx;
        // Row of this (batch, token, head) in the K half of `output`:
        // {b, 0, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
        size_t const st_k_global_offset = static_cast<size_t>(batch_idx) * 2 * kv_cache_block_num * kv_cache_block_size
            + static_cast<size_t>(local_token_idx / kv_cache_tokens_per_block) * kv_cache_block_size
            + static_cast<size_t>(head_idx) * kv_cache_tokens_per_block * token_row_size
            + static_cast<size_t>(local_token_idx % kv_cache_tokens_per_block) * token_row_size;
        if (is_valid_v)
        {
            // copy k and v
            int64_t const ld_kv_global_offset
                = global_token_idx * kv_token_stride + static_cast<int64_t>(head_idx) * kv_dim;
            auto k_data = reinterpret_cast<VecT const*>(k_ptr + ld_kv_global_offset)[head_dim_vec_idx];
            auto v_data = reinterpret_cast<VecT const*>(v_ptr + ld_kv_global_offset)[head_dim_vec_idx];
            // {b, 1, token / kv_cache_tokens_per_block, h, token % kv_cache_tokens_per_block, ...}
            size_t const st_v_global_offset = st_k_global_offset + kv_cache_block_num * kv_cache_block_size;
            reinterpret_cast<VecT*>(output + st_k_global_offset)[head_dim_vec_idx] = k_data;
            reinterpret_cast<VecT*>(output + st_v_global_offset)[head_dim_vec_idx] = v_data;
        }
        else
        {
            // copy k_pe: a single source head, written after the kv_dim part of this head's K row.
            int64_t const ld_rope_global_offset = global_token_idx * rope_dim;
            int const ld_rope_local_offset = head_dim_vec_idx - KT::kKVThreadPerHead;
            auto rope_data = reinterpret_cast<VecT const*>(k_pe_ptr + ld_rope_global_offset)[ld_rope_local_offset];
            reinterpret_cast<VecT*>(output + st_k_global_offset)[head_dim_vec_idx] = rope_data;
        }
    }
}
// q {total_uncached_tokens, h, d_nope + d_rope}
// latent_cache {total_uncached_tokens, d_k + d_rope}
template <typename T, typename TCache, int BLOCK_SIZE, int K_DIM, int ROPE_DIM>
__global__ void applyMLARopeAppendPagedKVAssignQKernel(KVBlockArray kv_cache, T* q_ptr, T const* latent_cache_ptr,
int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens, int const max_input_uncached_seq_len,
float2 const* cos_sin_cache, size_t head_num, int nope_size, float const* kv_scale_orig_quant_ptr)
{
static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
"TCache must be either the same type as T or __nv_fp8_e4m3");
// Constants.
using VecT = typename VecType<T>::Type;
using GPTJEltT = typename VecType<T>::GPTJEltType;
constexpr auto HEAD_SIZE = ROPE_DIM;
constexpr auto K_HEAD_SIZE = K_DIM;
constexpr auto BYTES_PER_ELT = sizeof(T);
constexpr auto BYTES_PER_LOAD = 16;
constexpr auto ELTS_PER_VEC = BYTES_PER_LOAD / BYTES_PER_ELT;
static_assert((HEAD_SIZE * BYTES_PER_ELT) % BYTES_PER_LOAD == 0, "Head size needs to be multiple of 16 bytes.");
constexpr auto VECS_PER_HEAD = HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
constexpr auto K_VECS_PER_HEAD = K_HEAD_SIZE * BYTES_PER_ELT / BYTES_PER_LOAD;
static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
constexpr auto TOKENS_PER_BLOCK = BLOCK_SIZE / VECS_PER_HEAD;
constexpr auto K_TOKENS_PER_BLOCK = BLOCK_SIZE / K_VECS_PER_HEAD;
constexpr auto TOTAL_VECS_PER_HEAD = VECS_PER_HEAD + K_VECS_PER_HEAD;
// Block/Head idx.
size_t const batch_idx = blockIdx.y;
size_t const head_idx = blockIdx.z;
int64_t const global_token_offset = cu_seq_lens[batch_idx] - cu_ctx_cached_kv_lens[batch_idx];
int64_t const cached_kv_len = cu_ctx_cached_kv_lens[batch_idx + 1] - cu_ctx_cached_kv_lens[batch_idx];
int64_t const uncached_kv_len = cu_seq_lens[batch_idx + 1] - cu_seq_lens[batch_idx] - cached_kv_len;
if (head_idx <= head_num)
{
size_t const head_dim_vec_idx = (threadIdx.x % VECS_PER_HEAD);
size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
size_t const seq_len_loop_end
= size_t((max_input_uncached_seq_len + TOKENS_PER_BLOCK - 1) / TOKENS_PER_BLOCK) * TOKENS_PER_BLOCK;
float quant_scale_kv_val = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.f;
// Mainloop.
for (int local_token_idx = (threadIdx.x / VECS_PER_HEAD) + blockIdx.x * TOKENS_PER_BLOCK;
local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
{
int token_idx_in_kv_cache = local_token_idx + cached_kv_len;
bool valid_token = local_token_idx < uncached_kv_len;
int const global_token_idx = local_token_idx + global_token_offset;
VecT data;
if (valid_token)
{
auto const position_id = token_idx_in_kv_cache;
float2 const* rotary_coef_cache_buffer
= cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);
if (head_idx == head_num)
{
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM) + K_DIM;
data = *reinterpret_cast<VecT const*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]);
}
else
{
auto const src_q_global_offset
= static_cast<size_t>(global_token_idx) * head_num * (nope_size + ROPE_DIM)
+ (nope_size + ROPE_DIM) * head_idx + nope_size;
data = *reinterpret_cast<VecT const*>(&q_ptr[src_q_global_offset + head_dim_idx]);
}
// Pack two elements into one for gptj rotary embedding.
#pragma unroll
for (int elt_id = 0; elt_id < ELTS_PER_VEC / 2; elt_id++)
{
GPTJEltT& data_ = reinterpret_cast<GPTJEltT*>(&data)[elt_id];
float2 rotary_coef_cache = rotary_coef_cache_buffer[elt_id];
data_ = mmha::rotary_embedding_transform(data_, rotary_coef_cache);
}
}
// do sync
__syncwarp();
if (valid_token)
{
if (head_idx == head_num)
{
auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
auto inBlockIdx = kv_cache.getKVLocalIdx(
token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, K_VECS_PER_HEAD + head_dim_vec_idx);
if constexpr (std::is_same_v<TCache, T>)
{
reinterpret_cast<VecT*>(kDst)[inBlockIdx] = data;
}
else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
reinterpret_cast<T const*>(&data), quant_scale_kv_val);
}
}
else
{
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * head_num * (nope_size + ROPE_DIM)
+ head_idx * (nope_size + ROPE_DIM) + nope_size + head_dim_idx;
reinterpret_cast<VecT*>(q_ptr)[dst_q_idx / ELTS_PER_VEC] = data;
}
}
}
}
else
{
int block_dim = gridDim.z - head_num - 1;
int block_id = head_idx - head_num - 1;
size_t const head_dim_vec_idx = (threadIdx.x % K_VECS_PER_HEAD);
size_t const head_dim_idx = head_dim_vec_idx * ELTS_PER_VEC;
size_t const seq_len_loop_end
= size_t((max_input_uncached_seq_len + K_TOKENS_PER_BLOCK - 1) / K_TOKENS_PER_BLOCK) * K_TOKENS_PER_BLOCK;
float quant_scale_kv_val = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.f;
// Mainloop.
for (int local_token_idx = (threadIdx.x / K_VECS_PER_HEAD) + gridDim.x * K_TOKENS_PER_BLOCK * block_id
+ blockIdx.x * K_TOKENS_PER_BLOCK;
local_token_idx < seq_len_loop_end; local_token_idx += block_dim * K_TOKENS_PER_BLOCK * gridDim.x)
{
int token_idx_in_kv_cache = local_token_idx + cached_kv_len;
bool valid_token = local_token_idx < uncached_kv_len;
int const global_token_idx = local_token_idx + global_token_offset;
if (valid_token)
{
auto const src_k_global_offset = static_cast<size_t>(global_token_idx) * (K_DIM + ROPE_DIM);
auto kDst = reinterpret_cast<T*>(kv_cache.getKBlockPtr(batch_idx, token_idx_in_kv_cache));
auto inBlockIdx
= kv_cache.getKVLocalIdx(token_idx_in_kv_cache, 0, TOTAL_VECS_PER_HEAD, head_dim_vec_idx);
if constexpr (std::is_same_v<TCache, T>)
{
reinterpret_cast<VecT*>(kDst)[inBlockIdx]
= *reinterpret_cast<VecT const*>(&latent_cache_ptr[src_k_global_offset + head_dim_idx]);
}
else if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
quantCopy<T, ELTS_PER_VEC>(reinterpret_cast<__nv_fp8_e4m3*>(kDst) + inBlockIdx * ELTS_PER_VEC,
latent_cache_ptr + src_k_global_offset + head_dim_idx, quant_scale_kv_val);
}
}
}
}
}
template <typename T, typename KVCacheBuffer>
void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
{
    // Launches applyMLARopeAndAssignQKVKernelOptContext on `stream` for the context (prefill) phase.
    //
    // Grid layout:
    //   x: the longest input sequence, tiled in chunks of kTokensPerGridX tokens,
    //   y: one slice per sequence in the batch,
    //   z: params.head_num slices plus kExtraZSlices extra slices (their role is defined by the kernel).
    // The 512 / 64 template arguments are forwarded unchanged to the kernel.
    constexpr int kThreadsPerBlock = 256;
    constexpr int kTokensPerGridX = 32;
    constexpr int kExtraZSlices = 8;

    int const gridX = int(tensorrt_llm::common::divUp(params.max_input_seq_len, kTokensPerGridX));
    dim3 const grid(gridX, params.batch_size, params.head_num + kExtraZSlices);
    auto const qkNopeHeadDim = params.meta.qk_nope_head_dim;

    applyMLARopeAndAssignQKVKernelOptContext<T, kThreadsPerBlock, 512, 64, KVCacheBuffer>
        <<<grid, kThreadsPerBlock, 0, stream>>>(params.attention_input_buf, params.latent_cache, kv_cache_buffer,
            params.cos_sin_cache, params.head_num, qkNopeHeadDim, params.meta.kv_lora_rank, params.cu_q_seqlens,
            params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv);
}
template <typename T, typename KVCacheBuffer>
void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream)
{
    // Launches applyMLARopeAndAssignQKVKernelGeneration on `stream` (256 threads per block) for the generation
    // phase, opting into programmatic stream serialization when enabled via the environment.
    //
    // The kernel receives a single per-request query length (seq_len), so the flattened query length
    // (acc_q_len) must split evenly across the batch. Validate the divisor first: a zero batch_size would make
    // the modulo/division below undefined behavior instead of a diagnosable error.
    TLLM_CHECK_WITH_INFO(params.batch_size > 0, "MLA generation requires batch_size > 0.");
    TLLM_CHECK_WITH_INFO(params.acc_q_len % params.batch_size == 0,
        "MLA can only support input sequences with the same sequence length.");
    auto seq_len = params.acc_q_len / params.batch_size;

    // Grid: x tiles the flattened query tokens in chunks of 32. y has head_num + 1 + 8 slices, plus
    // head_num * 8 more when the KV cache is FP8. The role of each y slice is defined by the kernel.
    dim3 grid(int(tensorrt_llm::common::divUp(params.acc_q_len, 32)), params.head_num + 1 + 8);
    if (params.cache_type == KvCacheDataType::FP8)
    {
        grid.y += params.head_num * 8;
    }

    auto* kernel_instance = &applyMLARopeAndAssignQKVKernelGeneration<T, 256, 512, 64, KVCacheBuffer>;
    cudaLaunchConfig_t config;
    config.gridDim = grid;
    config.blockDim = 256;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    // Programmatic stream serialization (PDL) is only allowed when the environment enables it.
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    // cudaLaunchKernelEx reports launch failures (and earlier sticky errors) through its return value;
    // surface them here instead of silently dropping them.
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.attention_input_buf, params.q_pe,
        params.latent_cache, params.quant_attention_input_buf, kv_cache_buffer, params.cos_sin_cache, params.head_num,
        params.meta.kv_lora_rank, params.acc_q_len, seq_len, params.seqQOffset, params.fmha_tile_counter,
        params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld, params.q_pe_stride, params.cache_type,
        params.bmm1_scale, params.bmm2_scale, params.quant_scale_o, params.quant_scale_q, params.quant_scale_kv,
        params.dequant_scale_q, params.dequant_scale_kv, params.host_bmm1_scale));
}
template <typename T, typename TCache>
void invokeMLALoadPagedKV(T* compressed_kv_ptr, T* k_pe_ptr, KVBlockArray& kv_cache, int const num_contexts,
    int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, int const lora_size, int const rope_size,
    float const* kv_scale_quant_orig_ptr, cudaStream_t stream)
{
    // Launches loadPagedKVCacheForMLAKernel on `stream` for `num_contexts` context requests. Judging by its name,
    // the kernel loads cached KV entries from the paged `kv_cache` into `compressed_kv_ptr` / `k_pe_ptr`; see the
    // kernel for the exact layout.
    //
    // lora_size / rope_size must match the compile-time sizes in loadPagedKVKernelTraits<TCache>.
    // kv_scale_quant_orig_ptr is forwarded to the kernel unchanged.
    using KT = typename tensorrt_llm::kernels::loadPagedKVKernelTraits<TCache>;
    TLLM_CHECK_WITH_INFO(lora_size == KT::kLoraSize, "lora_size should be equal to %d", KT::kLoraSize);
    TLLM_CHECK_WITH_INFO(rope_size == KT::kRopeSize, "rope_size should be equal to %d", KT::kRopeSize);
    TLLM_CHECK_WITH_INFO(lora_size + rope_size == KT::kHeadSize, "head dim should be equal to %d", KT::kHeadSize);

    // Nothing to load. Launching with a zero-sized grid would fail with cudaErrorInvalidConfiguration.
    if (num_contexts <= 0 || max_input_seq_len <= 0)
    {
        return;
    }

    // Grid: {ceil(max_input_seq_len / kTokenPerBlock), num_contexts, 1}. The z dimension is always 1.
    dim3 grid(static_cast<int>(tensorrt_llm::common::divUp(max_input_seq_len, KT::kTokenPerBlock)), num_contexts, 1);
    loadPagedKVCacheForMLAKernel<T, TCache><<<grid, KT::kBlockSize, 0, stream>>>(
        compressed_kv_ptr, k_pe_ptr, kv_cache, cu_ctx_cached_kv_lens, max_input_seq_len, kv_scale_quant_orig_ptr);
}
template <typename T>
void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, int const num_requests,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, int rope_dim,
    int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream)
{
    // Launches setPagedKVCacheForMLAKernel<T> on `stream`.
    //
    // Preconditions (checked):
    //   * kv_dim + rope_dim equals the traits' compile-time head size,
    //   * kv_cache_tokens_per_block is a multiple of the traits' kCpTokenPerBlock.
    // Grid: {ceil(max_input_seq_len / kCpTokenPerBlock), num_requests, num_heads}.
    using Traits = typename tensorrt_llm::kernels::setPagedKVKernelTraits<T>;

    TLLM_CHECK_WITH_INFO(
        kv_dim + rope_dim == Traits::kHeadSize, "head dim should be equal to %d", Traits::kHeadSize);
    TLLM_CHECK_WITH_INFO(kv_cache_tokens_per_block % Traits::kCpTokenPerBlock == 0,
        "kv_cache_tokens_per_block should be multiple of %d", Traits::kCpTokenPerBlock);

    auto const tokenTiles = tensorrt_llm::common::divUp(max_input_seq_len, Traits::kCpTokenPerBlock);
    dim3 const launchGrid(tokenTiles, num_requests, num_heads);
    setPagedKVCacheForMLAKernel<T><<<launchGrid, Traits::kBlockSize, 0, stream>>>(output, k_ptr, v_ptr, k_pe_ptr,
        cu_seq_lens, max_input_seq_len, num_heads, kv_dim, rope_dim, kv_cache_tokens_per_block, kv_token_stride);
}
template <typename T, typename TCache>
void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T const* latent_cache_ptr,
    int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
    int const max_input_uncached_seq_len, float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size,
    int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream)
{
    // Launches applyMLARopeAppendPagedKVAssignQKernel on `stream`. Only the lora = 512 / rope = 64
    // configuration is compiled, so any other size is rejected up front.
    constexpr int kLoraSize = 512;
    constexpr int kRopeSize = 64;
    constexpr int kThreadsPerBlock = 256;
    constexpr int kTokensPerGridX = 32;
    // Extra z-slices that cooperatively copy the non-rope part of each latent-cache row into the KV cache.
    constexpr int kKCopySlices = 8;

    TLLM_CHECK_WITH_INFO(lora_size == kLoraSize, "lora_size should be equal to %d", kLoraSize);
    TLLM_CHECK_WITH_INFO(rope_size == kRopeSize, "rope_size should be equal to %d", kRopeSize);

    // Grid:
    //   x: uncached tokens, tiled in chunks of kTokensPerGridX (the kernel loops over any remainder),
    //   y: one slice per request,
    //   z: [0, head_num)   apply RoPE to the rope part of each query head, writing it back into q_ptr;
    //      head_num        applies RoPE to the shared rope part of the latent cache and writes it to the KV cache;
    //      the remaining   kKCopySlices slices copy the leading (non-rope) part of each latent-cache row into
    //                      the KV cache.
    auto const tokenTiles = tensorrt_llm::common::divUp(max_input_uncached_seq_len, kTokensPerGridX);
    dim3 const launchGrid(int(tokenTiles), num_requests, head_num + 1 + kKCopySlices);

    applyMLARopeAppendPagedKVAssignQKernel<T, TCache, kThreadsPerBlock, kLoraSize, kRopeSize>
        <<<launchGrid, kThreadsPerBlock, 0, stream>>>(kv_cache, q_ptr, latent_cache_ptr, cu_ctx_cached_kv_lens,
            cu_seq_lens, max_input_uncached_seq_len, cos_sin_cache, head_num, nope_size, kv_scale_orig_quant_ptr);
}
// Explicit instantiations of the MLA RoPE launchers (context and generation) for every supported
// (element type, KV-cache buffer) pair. bf16 is only instantiated when ENABLE_BF16 is defined.
#define INSTANTIATE_MLA_ROPE(T, KVCacheBuffer) \
    template void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream); \
    template void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, cudaStream_t stream);
INSTANTIATE_MLA_ROPE(float, KVBlockArray);
INSTANTIATE_MLA_ROPE(float, KVLinearBuffer);
INSTANTIATE_MLA_ROPE(half, KVBlockArray);
INSTANTIATE_MLA_ROPE(half, KVLinearBuffer);
#ifdef ENABLE_BF16
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVBlockArray);
INSTANTIATE_MLA_ROPE(__nv_bfloat16, KVLinearBuffer);
#endif
#undef INSTANTIATE_MLA_ROPE
// Explicit instantiations of the paged-KV load and RoPE/append launchers for each (element type, cache type) pair.
// bf16 variants are guarded by ENABLE_BF16, consistent with the bf16 RoPE instantiations in this file.
#define INSTANTIATE_RW_KVCACHE_MLA(T, TCache) \
    template void invokeMLALoadPagedKV<T, TCache>(T * compressed_kv_ptr, T * k_pe_ptr, KVBlockArray & kv_cache, \
        int const num_contexts, int64_t const* cu_ctx_cached_kv_lens, int const max_input_seq_len, \
        int const lora_size, int const rope_size, float const* kv_scale_quant_orig_ptr, cudaStream_t stream); \
    template void invokeMLARopeAppendPagedKVAssignQ<T, TCache>(KVBlockArray & kv_cache, T * q_ptr, \
        T const* latent_cache_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, \
        int64_t const* cu_seq_lens, int const max_input_uncached_seq_len, float2 const* cos_sin_cache, \
        size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr, \
        cudaStream_t stream);
INSTANTIATE_RW_KVCACHE_MLA(float, float);
INSTANTIATE_RW_KVCACHE_MLA(float, __nv_fp8_e4m3);
INSTANTIATE_RW_KVCACHE_MLA(half, half);
INSTANTIATE_RW_KVCACHE_MLA(half, __nv_fp8_e4m3);
#ifdef ENABLE_BF16
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_bfloat16);
INSTANTIATE_RW_KVCACHE_MLA(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#undef INSTANTIATE_RW_KVCACHE_MLA
// Explicit instantiations of invokeMLASetPagedKV for each supported element type. The bf16 variant is guarded by
// ENABLE_BF16, consistent with the bf16 RoPE instantiations in this file.
#define INSTANTIATE_SET_KVCACHE_MLA(T) \
    template void invokeMLASetPagedKV(T* output, T const* k_ptr, T const* v_ptr, T const* k_pe_ptr, \
        int const num_requests, int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int kv_dim, \
        int rope_dim, int kv_cache_tokens_per_block, int64_t kv_token_stride, cudaStream_t stream);
INSTANTIATE_SET_KVCACHE_MLA(float);
INSTANTIATE_SET_KVCACHE_MLA(half);
#ifdef ENABLE_BF16
INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);
#endif
#undef INSTANTIATE_SET_KVCACHE_MLA
} // namespace kernels
} // namespace tensorrt_llm