TensorRT-LLMs/cpp/tests/unit_tests/kernels/mlaPreprocessTest.cu
zhhuang-nv 8452775db8
[TRTLLM-5070][feat] Support FP8 KV Cache Reuse for MLA (#4535)
* optimize kv cache reuse workflow for MLA

write kv cache first and only call up-projection GEMM once
relax contiguous requirements of k/v for setting paged kv cache
return two contiguous tensors when loading MLA KV Cache

Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>

* support fp8 kv cache for MLA kv cache reuse

Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>

* resolve comments

Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>

---------

Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
2025-05-23 19:47:50 +08:00

954 lines
53 KiB
Plaintext

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <gtest/gtest.h>
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include <random>
namespace
{
// copy matched kv cache data to compressed_kv_output and k_pe_output
// compressed_kv_output {total_cached_token, lora_size}
// k_pe_output {total_cached_token, rope_size}
// compressed_kv_cache {batch, 1 (ignore v), max_seq_len / tokens_per_block, num_head, tokens_per_block, (lora_size +
// rope_size)}
template <typename T, typename TCache>
void loadPagedKvKernelRef(T* compressed_kv_output, T* k_pe_output,
tensorrt_llm::kernels::KVBlockArray const& compressed_kv_cache, int num_contexts,
int64_t const* cu_ctx_cached_kv_lens, int const lora_size, int const rope_size,
float const* kv_scale_quant_orig_ptr)
{
static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
"TCache must be either the same type as T or __nv_fp8_e4m3");
int const head_dim = lora_size + rope_size;
float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
for (int b = 0; b < num_contexts; b++)
{
int const global_token_offset = cu_ctx_cached_kv_lens[b];
int const current_token_len = cu_ctx_cached_kv_lens[b + 1] - cu_ctx_cached_kv_lens[b];
for (int s = 0; s < current_token_len; s++)
{
int const global_token_idx = global_token_offset + s;
for (int d = 0; d < head_dim; d++)
{
auto const* kv_src = reinterpret_cast<TCache const*>(compressed_kv_cache.getKBlockPtr(b, s));
auto kv_block_idx = compressed_kv_cache.getKVLocalIdx(s, 0, head_dim, d);
auto src_data = kv_src[kv_block_idx];
T data;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
data = T(float(src_data) * kv_scale_quant_orig);
}
else
{
data = src_data;
}
if (d < lora_size)
{
compressed_kv_output[global_token_idx * lora_size + d] = data;
}
else
{
k_pe_output[global_token_idx * rope_size + (d - lora_size)] = data;
}
}
}
}
}
// Host reference for writing k / v / k_pe into a contiguous buffer laid out like the paged KV cache.
// k {total_token, h, uncompressed_h=128}, v {total_token, h, uncompressed_h}, k_pe {total_token, h=1, rope_h}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, (uncompressed_h + rope_h)}
//
// cu_seq_lens holds num_requests + 1 cumulative token counts. kv_token_stride is the distance, in elements,
// between two consecutive tokens of k_ptr / v_ptr (the tensors may be strided along the token axis).
// For every (token, head) the K slot receives [k | k_pe] and the V slot receives [v]; the trailing rope_size
// elements of each V row are left untouched. k_pe has a single head and is replicated into every head's K row.
// All offsets are computed in 64 bits so that large caches do not overflow 32-bit index arithmetic.
template <typename T>
void setPagedKvCacheForMLAKernelRef(T* output, T* const k_ptr, T* const v_ptr, T* const k_pe_ptr, int num_requests,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int uncompressed_head_size, int rope_size,
    int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
    int64_t const row_size = static_cast<int64_t>(uncompressed_head_size) + rope_size;
    int64_t const head_stride = kv_cache_tokens_per_block * row_size;
    int64_t const block_size = num_heads * head_stride;
    int64_t const blocks_per_seq
        = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;

    for (int b = 0; b < num_requests; b++)
    {
        int64_t const first_token = cu_seq_lens[b];
        int64_t const num_tokens = cu_seq_lens[b + 1] - cu_seq_lens[b];
        for (int64_t s = 0; s < num_tokens; s++)
        {
            int64_t const token = first_token + s;
            // K blocks of a request come first, its V blocks follow blocks_per_seq blocks later.
            int64_t const k_block_offset
                = ((b * 2 * blocks_per_seq) + (s / kv_cache_tokens_per_block)) * block_size;
            int64_t const v_block_offset = k_block_offset + blocks_per_seq * block_size;
            int64_t const row_in_block = (s % kv_cache_tokens_per_block) * row_size;
            T const* const k_pe_src = k_pe_ptr + token * rope_size;
            for (int h = 0; h < num_heads; h++)
            {
                int64_t const src_offset
                    = token * kv_token_stride + static_cast<int64_t>(h) * uncompressed_head_size;
                int64_t const k_dst = k_block_offset + h * head_stride + row_in_block;
                int64_t const v_dst = v_block_offset + h * head_stride + row_in_block;
                // copy k, v
                std::copy_n(k_ptr + src_offset, uncompressed_head_size, output + k_dst);
                std::copy_n(v_ptr + src_offset, uncompressed_head_size, output + v_dst);
                // copy k_pe, head_num = 1 (same source row for every head)
                std::copy_n(k_pe_src, rope_size, output + k_dst + uncompressed_head_size);
            }
        }
    }
}
// Host reference for writing cached + uncached k / v / k_pe into a contiguous buffer laid out like the paged
// KV cache.
// ck or cv {total_cached_token, h, uncompressed_h=128}, ck_pe {total_cached_token, h=1, rope_h}
// nk or nv {total_uncached_token, h, uncompressed_h}, nk_pe {total_uncached_token, h=1, rope_h}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, (uncompressed_h + rope_h)}
//
// cu_ctx_cached_kv_lens and cu_seq_lens each hold num_requests + 1 cumulative counts (cached tokens and all
// tokens respectively). Within a request the first `cached` positions are read from the c* tensors and the
// remaining ones from the n* tensors. The source tensors are dense: the token stride is
// num_heads * uncompressed_head_size. The trailing rope_size elements of each V row are left untouched.
// All offsets are computed in 64 bits so that large caches do not overflow 32-bit index arithmetic.
template <typename T>
void setPagedKvCacheForMLAKernelRefV2(T* output, T* const ck_ptr, T* const cv_ptr, T* const ck_pe_ptr, T* const nk_ptr,
    T* const nv_ptr, T* const nk_pe_ptr, int num_requests, int64_t const* cu_ctx_cached_kv_lens,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int uncompressed_head_size, int rope_size,
    int kv_cache_tokens_per_block)
{
    int64_t const row_size = static_cast<int64_t>(uncompressed_head_size) + rope_size;
    int64_t const head_stride = kv_cache_tokens_per_block * row_size;
    int64_t const block_size = num_heads * head_stride;
    int64_t const blocks_per_seq
        = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    int64_t const src_token_stride = static_cast<int64_t>(num_heads) * uncompressed_head_size;

    for (int b = 0; b < num_requests; b++)
    {
        int64_t const cached_token_offset = cu_ctx_cached_kv_lens[b];
        int64_t const uncached_token_offset = cu_seq_lens[b] - cu_ctx_cached_kv_lens[b];
        int64_t const num_tokens = cu_seq_lens[b + 1] - cu_seq_lens[b];
        int64_t const num_cached = cu_ctx_cached_kv_lens[b + 1] - cu_ctx_cached_kv_lens[b];
        for (int64_t s = 0; s < num_tokens; s++)
        {
            // Positions [0, num_cached) come from the cached tensors, the rest from the new (uncached) ones.
            bool const is_cached = (s < num_cached);
            int64_t const token = is_cached ? cached_token_offset + s : uncached_token_offset + (s - num_cached);
            T const* const k_src = is_cached ? ck_ptr : nk_ptr;
            T const* const v_src = is_cached ? cv_ptr : nv_ptr;
            T const* const k_pe_src = (is_cached ? ck_pe_ptr : nk_pe_ptr) + token * rope_size;

            // K blocks of a request come first, its V blocks follow blocks_per_seq blocks later.
            int64_t const k_block_offset
                = ((b * 2 * blocks_per_seq) + (s / kv_cache_tokens_per_block)) * block_size;
            int64_t const v_block_offset = k_block_offset + blocks_per_seq * block_size;
            int64_t const row_in_block = (s % kv_cache_tokens_per_block) * row_size;
            for (int h = 0; h < num_heads; h++)
            {
                int64_t const src_offset
                    = token * src_token_stride + static_cast<int64_t>(h) * uncompressed_head_size;
                int64_t const k_dst = k_block_offset + h * head_stride + row_in_block;
                int64_t const v_dst = v_block_offset + h * head_stride + row_in_block;
                // copy k, v
                std::copy_n(k_src + src_offset, uncompressed_head_size, output + k_dst);
                std::copy_n(v_src + src_offset, uncompressed_head_size, output + v_dst);
                // copy k_pe, head_num = 1 (same source row for every head)
                std::copy_n(k_pe_src, rope_size, output + k_dst + uncompressed_head_size);
            }
        }
    }
}
// Host reference for appending the uncached tokens of every request to the paged compressed KV cache.
// compressed_kv_cache {batch, 1 (ignore v), max_seq_len / tokens_per_block, num_head=1, tokens_per_block,
//                      (lora_size + rope_size)}
// kv {total_uncached_tokens, h_k=1, lora_d}, k_pe {total_uncached_tokens, h_kpe=1, rope_d}
// cu_ctx_cached_kv_lens / cu_seq_lens hold num_requests + 1 cumulative counts of cached / all tokens.
// Token s of a request's uncached part is written at sequence position (cached_len + s): channels
// [0, lora_size) come from compressed_kv_ptr and channels [lora_size, lora_size + rope_size) from k_pe_ptr.
// When TCache is __nv_fp8_e4m3 each value is multiplied by kv_scale_orig_quant_ptr[0] (1.0f if the pointer is
// null) before the conversion to FP8.
template <typename T, typename TCache>
void appendPagedKvForMLAKernelRef(tensorrt_llm::kernels::KVBlockArray& kv_cache, T* const compressed_kv_ptr,
    T* const k_pe_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
    int k_pe_head_num, int lora_size, int rope_size, float const* kv_scale_orig_quant_ptr)
{
    static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
        "TCache must be either the same type as T or __nv_fp8_e4m3");
    assert(k_pe_head_num == 1);

    int const row_width = lora_size + rope_size;
    float const quant_scale = (kv_scale_orig_quant_ptr != nullptr) ? *kv_scale_orig_quant_ptr : 1.0f;

    for (int req = 0; req < num_requests; ++req)
    {
        int const first_uncached_row = cu_seq_lens[req] - cu_ctx_cached_kv_lens[req];
        int const num_cached = cu_ctx_cached_kv_lens[req + 1] - cu_ctx_cached_kv_lens[req];
        int const num_uncached = cu_seq_lens[req + 1] - cu_seq_lens[req] - num_cached;
        for (int tok = 0; tok < num_uncached; ++tok)
        {
            int const src_row = first_uncached_row + tok;
            int const kv_row_offset = src_row * lora_size;
            int const k_pe_row_offset = src_row * k_pe_head_num * rope_size;
            int const cache_pos = num_cached + tok;
            auto* cache_block = reinterpret_cast<TCache*>(kv_cache.getKBlockPtr(req, cache_pos));

            // One pass over the cache row: compressed kv first, then the (single-head) rope part.
            for (int col = 0; col < row_width; ++col)
            {
                T const src_value = (col < lora_size) ? compressed_kv_ptr[kv_row_offset + col]
                                                      : k_pe_ptr[k_pe_row_offset + (col - lora_size)];
                int const dst_idx = kv_cache.getKVLocalIdx(cache_pos, 0, row_width, col);
                TCache stored;
                if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    stored = static_cast<__nv_fp8_e4m3>(static_cast<float>(src_value) * quant_scale);
                }
                else
                {
                    stored = src_value;
                }
                cache_block[dst_idx] = stored;
            }
        }
    }
}
// Tolerance comparison used to match kernel output against the host reference:
// true when |a - b| <= atol + rtol * |b|. Any NaN operand makes the comparison fail.
inline bool almostEqual(float a, float b, float atol = 1e-2, float rtol = 1e-3)
{
    bool const has_nan = isnan(a) || isnan(b);
    if (has_nan)
    {
        return false;
    }
    auto const diff = fabs(a - b);
    auto const bound = atol + rtol * fabs(b);
    return diff <= bound;
}
} // namespace
template <typename Typepair>
class MlaPreprocessTest : public testing::Test
{
protected:
using DataType = typename Typepair::first_type;
using TCache = typename Typepair::second_type;
static_assert(std::is_same_v<DataType, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
"TCache must be either the same type as DataType or __nv_fp8_e4m3");
std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
// kv_cache shape {batch, 2(k or v), max_seq_len / tokens_per_block, num_head, tokens_per_block, head_size}
// k, v, k_pe shape {total_token, num_head, head_size(lora_size or rope_size, or uncompressed_head_size)}
// offset shape {batch, 2, max_seq_len / tokens_per_block}
// for KVBlockArray, we only allocate primary pool.
// refer to the allocateBuffers function for more details.
tensorrt_llm::runtime::BufferManager::ITensorPtr h_kv_cache_tensor{nullptr}, h_kv_cache_tensor_ref{nullptr},
d_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor_ref{nullptr},
h_compressed_kv_cache_tensor{nullptr}, h_compressed_kv_cache_tensor_ref{nullptr}, d_offset_tensor{nullptr},
d_compressed_offset_tensor{nullptr}, d_cu_ctx_cached_kv_lens{nullptr}, d_cu_seq_lens{nullptr},
h_offset_tensor{nullptr}, h_compressed_offset_tensor{nullptr}, h_cu_ctx_cached_kv_lens{nullptr},
h_cu_seq_lens{nullptr}, h_kv_scale_orig_quant{nullptr}, d_kv_scale_orig_quant{nullptr},
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
// for kernel 1
d_compressed_kv_output{nullptr}, h_compressed_kv_output{nullptr}, h_compressed_kv_output_ref{nullptr},
d_k_pe_output{nullptr}, h_k_pe_output{nullptr}, h_k_pe_output_ref{nullptr},
// for kernel 2
d_k_tensor{nullptr}, d_v_tensor{nullptr}, d_k_pe_tensor{nullptr}, h_k_tensor{nullptr}, h_v_tensor{nullptr},
h_k_pe_tensor{nullptr},
// for kernel 2 (new)
d_k_tensor_cached{nullptr}, d_v_tensor_cached{nullptr}, d_k_pe_tensor_cached{nullptr},
d_k_tensor_uncached{nullptr}, d_v_tensor_uncached{nullptr}, d_k_pe_tensor_uncached{nullptr},
h_k_tensor_cached{nullptr}, h_v_tensor_cached{nullptr}, h_k_pe_tensor_cached{nullptr},
h_k_tensor_uncached{nullptr}, h_v_tensor_uncached{nullptr}, h_k_pe_tensor_uncached{nullptr},
// for kernel 3
d_compressed_kv_tensor{nullptr}, d_k_pe_one_head_tensor{nullptr}, h_compressed_kv_tensor{nullptr},
h_k_pe_one_head_tensor{nullptr};
int mNumRequests{};
int mMaxSeqLen{};
int mMaxCachedSeqLen{};
int mMaxUncachedSeqLen{};
int mMaxBlockPerSeq{};
int mTokensPerBlock{};
int mNumHeadsCompressed{};
int mNumHeadsUncompressed{};
int mTotalTokens{};
int mTotalCachedTokens{};
int mTotalUncachedTokens{};
int mLoraSize{};
int mRopeSize{};
int mUncompressedHeadSize{};
int64_t mKvTokenStride{};
std::mt19937 gen;
void SetUp() override
{
if (shouldSkip())
{
GTEST_SKIP() << "Skipping mla preprocess test";
}
mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
gen.seed(42U);
}
// Tells SetUp() whether the whole suite should be skipped; currently it never is.
static bool shouldSkip()
{
    constexpr bool kSkipSuite = false;
    return kSkipSuite;
}
// Resets the fixture to the default MLA geometry and clears all state that allocateBuffers() accumulates.
// Geometry: 64 tokens per block, 1 compressed head, 128 uncompressed heads, lora size 512, rope size 64,
// uncompressed head size 128, and a dense k/v token stride (num heads * head size).
// allocateBuffers() grows the max-length and total-token members with max() / +=, so all of them are
// zeroed here; otherwise a second setup on the same fixture would size buffers from stale totals.
void setDefaultParams()
{
    // cache geometry
    this->mTokensPerBlock = 64;
    this->mNumHeadsCompressed = 1;
    this->mNumHeadsUncompressed = 128;
    this->mLoraSize = 512;
    this->mRopeSize = 64;
    this->mUncompressedHeadSize = 128;
    this->mKvTokenStride = this->mNumHeadsUncompressed * this->mUncompressedHeadSize;

    // accumulators filled in by allocateBuffers()
    this->mMaxSeqLen = 0;
    this->mMaxCachedSeqLen = 0;
    this->mMaxUncachedSeqLen = 0;
    this->mTotalTokens = 0;
    this->mTotalCachedTokens = 0;
    this->mTotalUncachedTokens = 0;
}
// Fills a block-offset table of `size` entries, viewed as {batch, 2 (k or v), max_block_per_seq}.
//  - use_both_kv == true : entry i is set to i (identity mapping; max_block_per_seq is not used).
//  - use_both_kv == false: only K entries receive consecutive indices 0, 1, 2, ...; every V entry is set to 0.
//    An entry is a V entry when (i / max_block_per_seq) is odd, so max_block_per_seq must be positive in
//    this mode (the default of 0 is only valid together with use_both_kv == true or size == 0).
template <typename T>
void fillKVOffsetData(T* arr, size_t size, bool use_both_kv = true, int max_block_per_seq = 0)
{
    if (use_both_kv)
    {
        for (size_t i = 0; i < size; i++)
        {
            arr[i] = static_cast<T>(static_cast<int>(i));
        }
        return;
    }

    // Dividing by a zero block count below would be undefined behaviour; fail loudly in debug builds.
    assert((size == 0 || max_block_per_seq > 0) && "max_block_per_seq must be positive when use_both_kv is false");
    size_t const blocks_per_seq = static_cast<size_t>(max_block_per_seq);
    int next_k_block = 0;
    for (size_t i = 0; i < size; i++)
    {
        bool const is_v = (((i / blocks_per_seq) % 2) == 1);
        arr[i] = static_cast<T>(is_v ? 0 : next_k_block++);
    }
}
// Fills arr[0 .. size) with the repeating ramp 0, 1, ..., 447, 0, 1, ... converted to T.
// The index is a size_t so the loop is well-defined for any buffer size; the value is narrowed to int
// before the conversion to T so the same integer conversion is used for every element type.
template <typename T>
void fillArrayDataWithMod(T* arr, size_t size)
{
    constexpr size_t kPeriod = 448;
    for (size_t i = 0; i < size; i++)
    {
        arr[i] = static_cast<T>(static_cast<int>(i % kPeriod));
    }
}
// Draws a uniformly distributed integer from [0, a - 1] using the fixture's generator `gen`.
// Returns 0 without consuming any random numbers when a <= 0.
int generateRandomSizeSmallerThan(int a)
{
    if (a > 0)
    {
        std::uniform_int_distribution<> pick(0, a - 1);
        return int{pick(gen)};
    }
    return 0;
}
template <typename T>
void memsetZeroDevice(T* ptr, size_t size)
{
cudaMemset(ptr, 0, size * sizeof(T));
}
// Sets the bytes of `size` elements of type T starting at `ptr` to zero.
template <typename T>
void memsetZeroHost(T* ptr, size_t size)
{
    auto const num_bytes = size * sizeof(T);
    std::memset(ptr, 0, num_bytes);
}
bool allocateBuffers()
{
using tensorrt_llm::runtime::BufferManager;
using tensorrt_llm::runtime::CudaStream;
using tensorrt_llm::runtime::ITensor;
using tensorrt_llm::runtime::bufferCast;
auto dtype = nvinfer1::DataType::kHALF;
if constexpr (std::is_same_v<DataType, float>)
{
dtype = nvinfer1::DataType::kFLOAT;
}
else if constexpr (std::is_same_v<DataType, half>)
{
dtype = nvinfer1::DataType::kHALF;
}
else if constexpr (std::is_same_v<DataType, __nv_bfloat16>)
{
dtype = nvinfer1::DataType::kBF16;
}
else
{
return false;
}
auto cache_dtype = dtype;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
cache_dtype = nvinfer1::DataType::kFP8;
this->h_kv_scale_orig_quant
= tensorrt_llm::runtime::BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
this->d_kv_scale_orig_quant
= tensorrt_llm::runtime::BufferManager::gpuSync(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
this->h_kv_scale_quant_orig
= tensorrt_llm::runtime::BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
this->d_kv_scale_quant_orig
= tensorrt_llm::runtime::BufferManager::gpuSync(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
auto* kv_scale_orig_quant_ptr = bufferCast<float>(*(this->h_kv_scale_orig_quant));
auto* kv_scale_quant_orig_ptr = bufferCast<float>(*(this->h_kv_scale_quant_orig));
float kv_scale_orig_quant = 2.0f;
kv_scale_orig_quant_ptr[0] = kv_scale_orig_quant;
kv_scale_quant_orig_ptr[0] = 1.0 / kv_scale_orig_quant;
cudaMemcpy(this->d_kv_scale_orig_quant->data(), this->h_kv_scale_orig_quant->data(),
this->h_kv_scale_orig_quant->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_kv_scale_quant_orig->data(), this->h_kv_scale_quant_orig->data(),
this->h_kv_scale_quant_orig->getSizeInBytes(), cudaMemcpyHostToDevice);
}
else
{
static_assert(std::is_same_v<DataType, TCache>, "TCache must be the same type as DataType");
}
this->h_cu_seq_lens = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
this->h_cu_ctx_cached_kv_lens = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
this->d_cu_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
this->d_cu_ctx_cached_kv_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
{
// set random sequence length
auto* cu_seq_lens_temp_ptr = bufferCast<int64_t>(*(this->h_cu_seq_lens));
auto* cu_ctx_cached_kv_lens_temp_ptr = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
cu_seq_lens_temp_ptr[0] = 0;
cu_ctx_cached_kv_lens_temp_ptr[0] = 0;
for (int i = 1; i <= this->mNumRequests; i++)
{
int temp_seq_len = generateRandomSizeSmallerThan(512);
if (temp_seq_len <= 0)
{
temp_seq_len = 1; // at least 1 token
}
int cached_seq_len = generateRandomSizeSmallerThan(temp_seq_len);
this->mMaxSeqLen = std::max(temp_seq_len, this->mMaxSeqLen);
this->mMaxCachedSeqLen = std::max(cached_seq_len, this->mMaxCachedSeqLen);
this->mMaxUncachedSeqLen = std::max(temp_seq_len - cached_seq_len, this->mMaxUncachedSeqLen);
this->mTotalTokens += temp_seq_len;
this->mTotalCachedTokens += cached_seq_len;
this->mTotalUncachedTokens += temp_seq_len - cached_seq_len;
cu_seq_lens_temp_ptr[i] = cu_seq_lens_temp_ptr[i - 1] + temp_seq_len;
cu_ctx_cached_kv_lens_temp_ptr[i] = cu_ctx_cached_kv_lens_temp_ptr[i - 1] + cached_seq_len;
// std::cout << "batch " << i << "seq len: " << temp_seq_len << ", cached len: " << cached_seq_len
// << ", uncached len: " << temp_seq_len - cached_seq_len << std::endl;
}
cudaMemcpy(this->d_cu_seq_lens->data(), this->h_cu_seq_lens->data(), this->h_cu_seq_lens->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(this->d_cu_ctx_cached_kv_lens->data(), this->h_cu_ctx_cached_kv_lens->data(),
this->h_cu_ctx_cached_kv_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
}
// malloc kv_cache
this->mMaxBlockPerSeq = (this->mMaxSeqLen + this->mTokensPerBlock - 1) / this->mTokensPerBlock;
this->h_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq, this->mNumHeadsUncompressed,
this->mTokensPerBlock, this->mUncompressedHeadSize + this->mRopeSize}),
dtype);
this->h_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq, this->mNumHeadsUncompressed,
this->mTokensPerBlock, this->mUncompressedHeadSize + this->mRopeSize}),
dtype);
this->h_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
cache_dtype);
this->h_compressed_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
cache_dtype);
this->h_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
this->h_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
this->d_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq, this->mNumHeadsUncompressed,
this->mTokensPerBlock, this->mUncompressedHeadSize + this->mRopeSize}),
dtype);
this->d_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
cache_dtype);
this->d_compressed_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
cache_dtype);
this->d_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
this->d_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
{
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
auto* compressed_kv_cache_ref_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor_ref));
auto* offset_ptr = bufferCast<int32_t>(*(this->h_offset_tensor));
auto* compressed_offset_ptr = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
fillArrayDataWithMod(compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize());
fillArrayDataWithMod(compressed_kv_cache_ref_ptr, this->h_compressed_kv_cache_tensor_ref->getSize());
memsetZeroHost<DataType>(kv_cache_ptr, this->h_kv_cache_tensor->getSize());
memsetZeroHost<DataType>(kv_cache_ref_ptr, this->h_kv_cache_tensor_ref->getSize());
cudaMemcpy(this->d_kv_cache_tensor->data(), this->h_kv_cache_tensor->data(),
this->h_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
// fillArrayDataWithMod(offset_ptr, this->offset_tensor->getSize());
fillKVOffsetData(
compressed_offset_ptr, this->h_compressed_offset_tensor->getSize(), false, this->mMaxBlockPerSeq);
cudaMemcpy(this->d_compressed_kv_cache_tensor->data(), this->h_compressed_kv_cache_tensor->data(),
this->h_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_compressed_kv_cache_tensor_ref->data(), this->h_compressed_kv_cache_tensor_ref->data(),
this->h_compressed_kv_cache_tensor_ref->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_compressed_offset_tensor->data(), this->h_compressed_offset_tensor->data(),
this->h_compressed_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_offset_tensor->data(), this->h_offset_tensor->data(),
this->h_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
}
// compressed_kv_output + k_pe_output for loadPagedKvKernel (kernel 1)
this->h_compressed_kv_output = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mLoraSize}), dtype);
this->h_compressed_kv_output_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mLoraSize}), dtype);
this->d_compressed_kv_output = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalCachedTokens, this->mLoraSize}), dtype);
this->h_k_pe_output = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mRopeSize}), dtype);
this->h_k_pe_output_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mRopeSize}), dtype);
this->d_k_pe_output = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalCachedTokens, this->mRopeSize}), dtype);
{
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output));
auto* compressed_kv_output_ref_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
memsetZeroHost<DataType>(compressed_kv_output_ptr, this->h_compressed_kv_output->getSize());
memsetZeroHost<DataType>(compressed_kv_output_ref_ptr, this->h_compressed_kv_output_ref->getSize());
cudaMemcpy(this->d_compressed_kv_output->data(), this->h_compressed_kv_output->data(),
this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyHostToDevice);
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output));
auto* k_pe_output_ref_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
memsetZeroHost<DataType>(k_pe_output_ptr, this->h_k_pe_output->getSize());
memsetZeroHost<DataType>(k_pe_output_ref_ptr, this->h_k_pe_output_ref->getSize());
cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(),
cudaMemcpyHostToDevice);
}
// k, v, k_pe for setPagedKvCacheForMLAKernel (kernel 2)
this->h_k_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->h_v_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
this->d_k_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->d_v_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
{
auto* k_ptr = bufferCast<DataType>(*(this->h_k_tensor));
auto* v_ptr = bufferCast<DataType>(*(this->h_v_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
fillArrayDataWithMod(k_ptr, this->h_k_tensor->getSize());
fillArrayDataWithMod(v_ptr, this->h_v_tensor->getSize());
fillArrayDataWithMod(k_pe_ptr, this->h_k_pe_tensor->getSize());
cudaMemcpy(this->d_k_tensor->data(), this->h_k_tensor->data(), this->h_k_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(this->d_v_tensor->data(), this->h_v_tensor->data(), this->h_v_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_pe_tensor->data(), this->h_k_pe_tensor->data(), this->h_k_pe_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
}
// ck, cv, ck_pe, uk, uc, uk_pe for setPagedKvCacheForMLAKernelV2 (kernel 2)
this->h_k_tensor_cached = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->h_v_tensor_cached = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->h_k_pe_tensor_cached = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
this->h_k_tensor_uncached = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->h_v_tensor_uncached = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->h_k_pe_tensor_uncached = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
this->d_k_tensor_cached = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->d_v_tensor_cached = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->d_k_pe_tensor_cached = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
this->d_k_tensor_uncached = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->d_v_tensor_uncached = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
dtype);
this->d_k_pe_tensor_uncached = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
{
auto* k_cached_ptr = bufferCast<DataType>(*(this->h_k_tensor_cached));
auto* v_cached_ptr = bufferCast<DataType>(*(this->h_v_tensor_cached));
auto* k_pe_cached_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor_cached));
auto* k_uncached_ptr = bufferCast<DataType>(*(this->h_k_tensor_uncached));
auto* v_uncached_ptr = bufferCast<DataType>(*(this->h_v_tensor_uncached));
auto* k_pe_uncached_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor_uncached));
fillArrayDataWithMod(k_cached_ptr, this->h_k_tensor_cached->getSize());
fillArrayDataWithMod(v_cached_ptr, this->h_v_tensor_cached->getSize());
fillArrayDataWithMod(k_pe_cached_ptr, this->h_k_pe_tensor_cached->getSize());
fillArrayDataWithMod(k_uncached_ptr, this->h_k_tensor_uncached->getSize());
fillArrayDataWithMod(v_uncached_ptr, this->h_v_tensor_uncached->getSize());
fillArrayDataWithMod(k_pe_uncached_ptr, this->h_k_pe_tensor_uncached->getSize());
cudaMemcpy(this->d_k_tensor_cached->data(), this->h_k_tensor_cached->data(),
this->h_k_tensor_cached->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_v_tensor_cached->data(), this->h_v_tensor_cached->data(),
this->h_v_tensor_cached->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_pe_tensor_cached->data(), this->h_k_pe_tensor_cached->data(),
this->h_k_pe_tensor_cached->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_tensor_uncached->data(), this->h_k_tensor_uncached->data(),
this->h_k_tensor_uncached->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_v_tensor_uncached->data(), this->h_v_tensor_uncached->data(),
this->h_v_tensor_uncached->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_pe_tensor_uncached->data(), this->h_k_pe_tensor_uncached->data(),
this->h_k_pe_tensor_uncached->getSizeInBytes(), cudaMemcpyHostToDevice);
}
// compressed_kv, k_pe_one_head for appendPagedKvForMLAKernel (kernel 3)
this->h_compressed_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mLoraSize}), dtype);
this->h_k_pe_one_head_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mRopeSize}), dtype);
this->d_compressed_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mLoraSize}), dtype);
this->d_k_pe_one_head_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mRopeSize}), dtype);
{
auto* compressed_kv_ptr = bufferCast<DataType>(*(this->h_compressed_kv_tensor));
auto* k_pe_one_head_ptr = bufferCast<DataType>(*(this->h_k_pe_one_head_tensor));
fillArrayDataWithMod(compressed_kv_ptr, this->h_compressed_kv_tensor->getSize());
fillArrayDataWithMod(k_pe_one_head_ptr, this->h_k_pe_one_head_tensor->getSize());
cudaMemcpy(this->d_compressed_kv_tensor->data(), this->h_compressed_kv_tensor->data(),
this->h_compressed_kv_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_pe_one_head_tensor->data(), this->h_k_pe_one_head_tensor->data(),
this->h_k_pe_one_head_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
}
return true;
}
void PerformLoadPagedKV()
{
using tensorrt_llm::runtime::bufferCast;
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->d_compressed_kv_output));
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->d_k_pe_output));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->d_compressed_kv_cache_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->d_compressed_offset_tensor));
auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->d_cu_ctx_cached_kv_lens));
float* kv_scale_quant_orig_ptr = nullptr;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
kv_scale_quant_orig_ptr = bufferCast<float>(*(this->d_kv_scale_quant_orig));
}
tensorrt_llm::kernels::KVBlockArray kv_cache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock,
sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
tensorrt_llm::kernels::invokeMLALoadPagedKV<DataType, TCache>(compressed_kv_output_ptr, k_pe_output_ptr,
kv_cache, this->mNumRequests, cu_ctx_cached_kv_lens_ptr, this->mMaxCachedSeqLen, this->mLoraSize,
this->mRopeSize, kv_scale_quant_orig_ptr, this->mStream->get());
cudaStreamSynchronize(this->mStream->get());
cudaMemcpy(this->h_compressed_kv_output->data(), this->d_compressed_kv_output->data(),
this->d_compressed_kv_output->getSizeInBytes(), cudaMemcpyDeviceToHost);
cudaMemcpy(this->h_k_pe_output->data(), this->d_k_pe_output->data(), this->d_k_pe_output->getSizeInBytes(),
cudaMemcpyDeviceToHost);
}
void PerformLoadPagedKVRef()
{
using tensorrt_llm::runtime::bufferCast;
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
float* kv_scale_quant_orig_ptr = nullptr;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
kv_scale_quant_orig_ptr = bufferCast<float>(*(this->h_kv_scale_quant_orig));
}
tensorrt_llm::kernels::KVBlockArray kv_cache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock,
sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
loadPagedKvKernelRef<DataType, TCache>(compressed_kv_output_ptr, k_pe_output_ptr, kv_cache, this->mNumRequests,
cu_ctx_cached_kv_lens_ptr, this->mLoraSize, this->mRopeSize, kv_scale_quant_orig_ptr);
}
void PerformSetPagedKV()
{
using tensorrt_llm::runtime::bufferCast;
auto* k_ptr = bufferCast<DataType>(*(this->d_k_tensor));
auto* v_ptr = bufferCast<DataType>(*(this->d_v_tensor));
auto* k_pe_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor));
auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
tensorrt_llm::kernels::invokeMLASetPagedKV<DataType>(kv_cache_ptr, k_ptr, v_ptr, k_pe_ptr, this->mNumRequests,
cu_seq_lens_ptr, this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize,
this->mRopeSize, this->mTokensPerBlock, this->mKvTokenStride, this->mStream->get());
cudaStreamSynchronize(this->mStream->get());
cudaMemcpy(this->h_kv_cache_tensor->data(), this->d_kv_cache_tensor->data(),
this->d_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost);
}
// Host-side reference for the set-paged-KV step: setPagedKvCacheForMLAKernelRef consumes the host
// K / V / K_pe tensors and writes the expected cache layout into h_kv_cache_tensor_ref.
void PerformSetPagedKVRef()
{
    using tensorrt_llm::runtime::bufferCast;

    // Destination of the reference result.
    auto* refCache = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));

    // Host-resident inputs mirroring what the device kernel receives.
    auto* hostK = bufferCast<DataType>(*(this->h_k_tensor));
    auto* hostV = bufferCast<DataType>(*(this->h_v_tensor));
    auto* hostKPe = bufferCast<DataType>(*(this->h_k_pe_tensor));
    auto* hostCuSeqLens = bufferCast<int64_t>(*(this->h_cu_seq_lens));

    setPagedKvCacheForMLAKernelRef(refCache, hostK, hostV, hostKPe, this->mNumRequests, hostCuSeqLens,
        this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize, this->mRopeSize,
        this->mTokensPerBlock, this->mKvTokenStride);
}
void PerformSetPagedKVV2()
{
using tensorrt_llm::runtime::bufferCast;
auto* k_cached_ptr = bufferCast<DataType>(*(this->d_k_tensor_cached));
auto* v_cached_ptr = bufferCast<DataType>(*(this->d_v_tensor_cached));
auto* k_pe_cached_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor_cached));
auto* k_uncached_ptr = bufferCast<DataType>(*(this->d_k_tensor_uncached));
auto* v_uncached_ptr = bufferCast<DataType>(*(this->d_v_tensor_uncached));
auto* k_pe_uncached_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor_uncached));
auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->d_cu_ctx_cached_kv_lens));
auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
tensorrt_llm::kernels::invokeMLASetPagedKVV2<DataType>(kv_cache_ptr, k_cached_ptr, v_cached_ptr,
k_pe_cached_ptr, k_uncached_ptr, v_uncached_ptr, k_pe_uncached_ptr, this->mNumRequests,
cu_ctx_cached_kv_lens_ptr, cu_seq_lens_ptr, this->mMaxSeqLen, this->mNumHeadsUncompressed,
this->mUncompressedHeadSize, this->mRopeSize, this->mTokensPerBlock, this->mStream->get());
cudaStreamSynchronize(this->mStream->get());
cudaMemcpy(this->h_kv_cache_tensor->data(), this->d_kv_cache_tensor->data(),
this->d_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost);
}
// Host-side reference for the V2 set-paged-KV step: setPagedKvCacheForMLAKernelRefV2 consumes the
// cached and uncached host K / V / K_pe tensors and writes the expected cache into
// h_kv_cache_tensor_ref.
void PerformSetPagedKVV2Ref()
{
    using tensorrt_llm::runtime::bufferCast;

    // Destination of the reference result.
    auto* refCache = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));

    // Inputs taken from the "cached" host tensors.
    auto* cachedK = bufferCast<DataType>(*(this->h_k_tensor_cached));
    auto* cachedV = bufferCast<DataType>(*(this->h_v_tensor_cached));
    auto* cachedKPe = bufferCast<DataType>(*(this->h_k_pe_tensor_cached));

    // Inputs taken from the "uncached" host tensors.
    auto* uncachedK = bufferCast<DataType>(*(this->h_k_tensor_uncached));
    auto* uncachedV = bufferCast<DataType>(*(this->h_v_tensor_uncached));
    auto* uncachedKPe = bufferCast<DataType>(*(this->h_k_pe_tensor_uncached));

    // Cumulative length arrays passed through to the reference.
    auto* hostCuCachedLens = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
    auto* hostCuSeqLens = bufferCast<int64_t>(*(this->h_cu_seq_lens));

    setPagedKvCacheForMLAKernelRefV2(refCache, cachedK, cachedV, cachedKPe, uncachedK, uncachedV, uncachedKPe,
        this->mNumRequests, hostCuCachedLens, hostCuSeqLens, this->mMaxSeqLen, this->mNumHeadsUncompressed,
        this->mUncompressedHeadSize, this->mRopeSize, this->mTokensPerBlock);
}
void PerformAppendPagedKV()
{
using tensorrt_llm::runtime::bufferCast;
auto* compressed_kv_ptr = bufferCast<DataType>(*(this->d_compressed_kv_tensor));
auto* k_pe_one_head_ptr = bufferCast<DataType>(*(this->d_k_pe_one_head_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->d_compressed_offset_tensor));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->d_compressed_kv_cache_tensor));
auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->d_cu_ctx_cached_kv_lens));
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
float* kv_scale_orig_quant_ptr = nullptr;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
kv_scale_orig_quant_ptr = bufferCast<float>(*(this->d_kv_scale_orig_quant));
}
tensorrt_llm::kernels::KVBlockArray kv_cache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock,
sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
tensorrt_llm::kernels::invokeMLAAppendPagedKV<DataType, TCache>(kv_cache, compressed_kv_ptr, k_pe_one_head_ptr,
this->mNumRequests, cu_ctx_cached_kv_lens_ptr, cu_seq_lens_ptr, this->mMaxUncachedSeqLen,
this->mLoraSize + this->mRopeSize, kv_scale_orig_quant_ptr, this->mStream->get());
cudaStreamSynchronize(this->mStream->get());
cudaMemcpy(this->h_compressed_kv_cache_tensor->data(), this->d_compressed_kv_cache_tensor->data(),
this->d_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost);
}
void PerformAppendPagedKVRef()
{
using tensorrt_llm::runtime::bufferCast;
auto* compressed_kv_ptr = bufferCast<DataType>(*(this->h_compressed_kv_tensor));
auto* k_pe_one_head_ptr = bufferCast<DataType>(*(this->h_k_pe_one_head_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor_ref));
auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_seq_lens));
float* kv_scale_orig_quant_ptr = nullptr;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
kv_scale_orig_quant_ptr = bufferCast<float>(*(this->h_kv_scale_orig_quant));
}
tensorrt_llm::kernels::KVBlockArray kv_cache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock,
sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
// currently k_pe_head_num = 1
appendPagedKvForMLAKernelRef<DataType, TCache>(kv_cache, compressed_kv_ptr, k_pe_one_head_ptr,
this->mNumRequests, cu_ctx_cached_kv_lens_ptr, cu_seq_lens_ptr, 1, this->mLoraSize, this->mRopeSize,
kv_scale_orig_quant_ptr);
}
// Element-wise comparison of two arrays after converting each element to float.
//
// expected / output: arrays holding at least `size` elements of T, where T must be convertible to
//                    float via static_cast.
// size:              number of elements to compare.
// Returns true when every pair passes almostEqual(e, o, 1e-3, 1e-3); on the first failing pair the
// position and both values are logged and false is returned.
//
// The index is a size_t so that it matches the type of `size` (no signed/unsigned comparison and no
// wrap-around for arrays larger than INT_MAX elements).
template <typename T>
bool CheckEqual(T const* expected, T const* output, size_t size)
{
    for (size_t i = 0; i < size; ++i)
    {
        auto const e = static_cast<float>(expected[i]);
        auto const o = static_cast<float>(output[i]);
        if (!almostEqual(e, o, 1e-3, 1e-3))
        {
            TLLM_LOG_ERROR(
                "Mismatch input value. Position of inputs: %zu, expected value: %f, output value: %f", i, e, o);
            return false;
        }
    }
    return true;
}
};
// Type pairs the typed test suite is instantiated with.
// NOTE(review): presumably each pair is (DataType, TCache), i.e. the tensor element type followed by
// the KV-cache storage type — confirm against the fixture's DataType / TCache aliases.
// The first three pairs use the same type for both members; the last three pair each element type
// with __nv_fp8_e4m3.
using MLATypes
= ::testing::Types<std::pair<half, half>, std::pair<__nv_bfloat16, __nv_bfloat16>, std::pair<float, float>,
std::pair<half, __nv_fp8_e4m3>, std::pair<__nv_bfloat16, __nv_fp8_e4m3>, std::pair<float, __nv_fp8_e4m3>>;
// Registers MlaPreprocessTest as a typed suite over every pair in MLATypes.
TYPED_TEST_SUITE(MlaPreprocessTest, MLATypes);
// Runs each MLA preprocessing kernel next to its host reference and compares the results:
//   1. load paged KV      -> compressed_kv output and k_pe output
//   2. set paged KV       -> kv cache
//   3. set paged KV (V2)  -> kv cache
//   4. append paged KV    -> compressed kv cache (compared as TCache)
// Uses 8 requests and the fixture's default parameters.
TYPED_TEST(MlaPreprocessTest, MLAPreprocessDefault)
{
    using tensorrt_llm::runtime::bufferCast;
    using DataType = typename TestFixture::DataType;
    using TCache = typename TestFixture::TCache;
    this->mNumRequests = 8;
    this->setDefaultParams();
    // Every later step dereferences the buffers, so a failed allocation has to stop the test here
    // rather than merely being recorded.
    ASSERT_TRUE(this->allocateBuffers());
    sync_check_cuda_error(this->mStream->get());
    {
        this->PerformLoadPagedKV();
        sync_check_cuda_error(this->mStream->get());
        this->PerformLoadPagedKVRef();
        auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output));
        auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output));
        auto* compressed_kv_output_ref_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
        auto* k_pe_output_ref_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
        EXPECT_TRUE(this->CheckEqual(
            compressed_kv_output_ref_ptr, compressed_kv_output_ptr, this->h_compressed_kv_output->getSize()))
            << "load paged KV: compressed_kv output differs from reference";
        EXPECT_TRUE(this->CheckEqual(k_pe_output_ref_ptr, k_pe_output_ptr, this->h_k_pe_output->getSize()))
            << "load paged KV: k_pe output differs from reference";
    }
    {
        this->PerformSetPagedKV();
        sync_check_cuda_error(this->mStream->get());
        this->PerformSetPagedKVRef();
        auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
        auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
        EXPECT_TRUE(this->CheckEqual(kv_cache_ref_ptr, kv_cache_ptr, this->h_kv_cache_tensor->getSize()))
            << "set paged KV: kv cache differs from reference";
    }
    {
        this->PerformSetPagedKVV2();
        sync_check_cuda_error(this->mStream->get());
        this->PerformSetPagedKVV2Ref();
        auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
        auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
        EXPECT_TRUE(this->CheckEqual(kv_cache_ref_ptr, kv_cache_ptr, this->h_kv_cache_tensor->getSize()))
            << "set paged KV V2: kv cache differs from reference";
    }
    {
        this->PerformAppendPagedKV();
        sync_check_cuda_error(this->mStream->get());
        this->PerformAppendPagedKVRef();
        auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
        auto* compressed_kv_cache_ref_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor_ref));
        EXPECT_TRUE(this->CheckEqual(
            compressed_kv_cache_ref_ptr, compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize()))
            << "append paged KV: compressed kv cache differs from reference";
    }
}