mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Optimize the KV cache reuse workflow for MLA: write the KV cache first and call the up-projection GEMM only once; relax the contiguity requirements on k/v when setting the paged KV cache; return two contiguous tensors when loading the MLA KV cache. Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com> * Support FP8 KV cache for MLA KV cache reuse. Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com> * Resolve review comments. Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com> --------- Signed-off-by: Zhen Huang <145532724+zhhuang-nv@users.noreply.github.com>
954 lines
53 KiB
Plaintext
954 lines
53 KiB
Plaintext
/*
|
|
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <random>
#include <type_traits>

#include <gtest/gtest.h>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"

#include "tensorrt_llm/kernels/mlaKernels.h"
|
|
|
|
namespace
|
|
{
|
|
|
|
// copy matched kv cache data to compressed_kv_output and k_pe_output
|
|
// compressed_kv_output {total_cached_token, lora_size}
|
|
// k_pe_output {total_cached_token, rope_size}
|
|
// compressed_kv_cache {batch, 1 (ignore v), max_seq_len / tokens_per_block, num_head, tokens_per_block, (lora_size +
|
|
// rope_size)}
|
|
template <typename T, typename TCache>
|
|
void loadPagedKvKernelRef(T* compressed_kv_output, T* k_pe_output,
|
|
tensorrt_llm::kernels::KVBlockArray const& compressed_kv_cache, int num_contexts,
|
|
int64_t const* cu_ctx_cached_kv_lens, int const lora_size, int const rope_size,
|
|
float const* kv_scale_quant_orig_ptr)
|
|
{
|
|
static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
|
|
"TCache must be either the same type as T or __nv_fp8_e4m3");
|
|
int const head_dim = lora_size + rope_size;
|
|
float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
|
|
for (int b = 0; b < num_contexts; b++)
|
|
{
|
|
int const global_token_offset = cu_ctx_cached_kv_lens[b];
|
|
int const current_token_len = cu_ctx_cached_kv_lens[b + 1] - cu_ctx_cached_kv_lens[b];
|
|
for (int s = 0; s < current_token_len; s++)
|
|
{
|
|
int const global_token_idx = global_token_offset + s;
|
|
for (int d = 0; d < head_dim; d++)
|
|
{
|
|
auto const* kv_src = reinterpret_cast<TCache const*>(compressed_kv_cache.getKBlockPtr(b, s));
|
|
auto kv_block_idx = compressed_kv_cache.getKVLocalIdx(s, 0, head_dim, d);
|
|
|
|
auto src_data = kv_src[kv_block_idx];
|
|
T data;
|
|
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
|
|
{
|
|
data = T(float(src_data) * kv_scale_quant_orig);
|
|
}
|
|
else
|
|
{
|
|
data = src_data;
|
|
}
|
|
if (d < lora_size)
|
|
{
|
|
compressed_kv_output[global_token_idx * lora_size + d] = data;
|
|
}
|
|
else
|
|
{
|
|
k_pe_output[global_token_idx * rope_size + (d - lora_size)] = data;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Host reference that packs k, v and k_pe into a contiguous paged-KV buffer (later packed into kv_cache).
// k {total_token, h, uncompressed_h=128}, v {total_token, h, uncompressed_h}, k_pe {total_token, h=1, rope_h}
// Consecutive tokens of k/v are kv_token_stride elements apart (so k/v rows need not be dense); k_pe rows are dense.
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, (uncompressed_h + rope_h)}
// Every K slot receives [k | k_pe] (the single k_pe head is shared by all heads); every V slot receives v in its first
// uncompressed_h channels and its rope tail is left untouched.
// cu_seq_lens holds num_requests + 1 prefix sums of the per-request token counts.
template <typename T>
void setPagedKvCacheForMLAKernelRef(T* output, T* const k_ptr, T* const v_ptr, T* const k_pe_ptr, int num_requests,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int uncompressed_head_size, int rope_size,
    int kv_cache_tokens_per_block, int64_t kv_token_stride)
{
    // All offsets are 64-bit: the flattened output index grows as
    // num_requests * 2 * blocks * num_heads * tokens_per_block * head_dim and can exceed INT_MAX.
    int64_t const token_dim = uncompressed_head_size + rope_size;
    int64_t const head_stride = static_cast<int64_t>(kv_cache_tokens_per_block) * token_dim;
    int64_t const kv_cache_size_per_block = num_heads * head_stride;
    int64_t const kv_cache_block_num_per_seq
        = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    // Distance from a request's K half to its V half.
    int64_t const v_plane_offset = kv_cache_block_num_per_seq * kv_cache_size_per_block;

    for (int b = 0; b < num_requests; b++)
    {
        int64_t const global_token_offset = cu_seq_lens[b];
        int64_t const current_token_len = cu_seq_lens[b + 1] - cu_seq_lens[b];
        for (int64_t s = 0; s < current_token_len; s++)
        {
            int64_t const global_token_idx = global_token_offset + s;
            // Start of this token's head-0 slot inside its K block.
            int64_t const k_token_base
                = (b * 2 * kv_cache_block_num_per_seq + s / kv_cache_tokens_per_block) * kv_cache_size_per_block
                + (s % kv_cache_tokens_per_block) * token_dim;
            int64_t const ld_k_pe_offset = global_token_idx * rope_size;
            for (int h = 0; h < num_heads; h++)
            {
                int64_t const st_k_base = k_token_base + h * head_stride;
                int64_t const st_v_base = st_k_base + v_plane_offset;
                int64_t const ld_kv_offset
                    = global_token_idx * kv_token_stride + static_cast<int64_t>(h) * uncompressed_head_size;
                // copy k, v
                for (int d = 0; d < uncompressed_head_size; d++)
                {
                    output[st_k_base + d] = k_ptr[ld_kv_offset + d];
                    output[st_v_base + d] = v_ptr[ld_kv_offset + d];
                }
                // copy k_pe (head_num = 1) into the rope tail of this K head
                for (int d = 0; d < rope_size; d++)
                {
                    output[st_k_base + uncompressed_head_size + d] = k_pe_ptr[ld_k_pe_offset + d];
                }
            }
        }
    }
}
|
|
|
|
// Host reference like setPagedKvCacheForMLAKernelRef, but each request's tokens come from two input sets: its first
// (cached) tokens from ck/cv/ck_pe and the remaining (uncached) tokens from nk/nv/nk_pe.
// ck or cv {total_cached_token, h, uncompressed_h=128}, ck_pe {total_cached_token, h=1, rope_h}
// nk or nv {total_uncached_token, h, uncompressed_h}, nk_pe {total_uncached_token, h=1, rope_h}
// output {b, 2, ceil(max_seq / kv_cache_tokens_per_block), h, kv_cache_tokens_per_block, (uncompressed_h + rope_h)}
// cu_ctx_cached_kv_lens / cu_seq_lens hold num_requests + 1 prefix sums of cached / total token counts.
// Every K slot receives [k | k_pe]; every V slot receives v and its rope tail is left untouched.
template <typename T>
void setPagedKvCacheForMLAKernelRefV2(T* output, T* const ck_ptr, T* const cv_ptr, T* const ck_pe_ptr, T* const nk_ptr,
    T* const nv_ptr, T* const nk_pe_ptr, int num_requests, int64_t const* cu_ctx_cached_kv_lens,
    int64_t const* cu_seq_lens, int const max_input_seq_len, int num_heads, int uncompressed_head_size, int rope_size,
    int kv_cache_tokens_per_block)
{
    // All offsets are 64-bit: the flattened output index can exceed INT_MAX for larger batches.
    int64_t const token_dim = uncompressed_head_size + rope_size;
    int64_t const head_stride = static_cast<int64_t>(kv_cache_tokens_per_block) * token_dim;
    int64_t const kv_cache_size_per_block = num_heads * head_stride;
    int64_t const kv_cache_block_num_per_seq
        = (max_input_seq_len + kv_cache_tokens_per_block - 1) / kv_cache_tokens_per_block;
    // Distance from a request's K half to its V half.
    int64_t const v_plane_offset = kv_cache_block_num_per_seq * kv_cache_size_per_block;
    // Inputs are dense {tokens, num_heads, uncompressed_head_size}.
    int64_t const kv_token_stride = static_cast<int64_t>(num_heads) * uncompressed_head_size;

    for (int b = 0; b < num_requests; b++)
    {
        int64_t const global_cached_token_offset = cu_ctx_cached_kv_lens[b];
        // Uncached tokens of earlier requests = all earlier tokens minus the earlier cached ones.
        int64_t const global_uncached_token_offset = cu_seq_lens[b] - cu_ctx_cached_kv_lens[b];
        int64_t const current_token_len = cu_seq_lens[b + 1] - cu_seq_lens[b];
        int64_t const current_cached_token_len = cu_ctx_cached_kv_lens[b + 1] - cu_ctx_cached_kv_lens[b];

        for (int64_t s = 0; s < current_token_len; s++)
        {
            // Positions [0, current_cached_token_len) come from the cached inputs, the rest from the new ones.
            bool const is_cached = (s < current_cached_token_len);
            int64_t const global_token_idx = is_cached ? global_cached_token_offset + s
                                                       : global_uncached_token_offset + (s - current_cached_token_len);
            T const* const k_src = is_cached ? ck_ptr : nk_ptr;
            T const* const v_src = is_cached ? cv_ptr : nv_ptr;
            T const* const k_pe_src = is_cached ? ck_pe_ptr : nk_pe_ptr;
            // Start of this token's head-0 slot inside its K block.
            int64_t const k_token_base
                = (b * 2 * kv_cache_block_num_per_seq + s / kv_cache_tokens_per_block) * kv_cache_size_per_block
                + (s % kv_cache_tokens_per_block) * token_dim;
            int64_t const ld_k_pe_offset = global_token_idx * rope_size;
            for (int h = 0; h < num_heads; h++)
            {
                int64_t const st_k_base = k_token_base + h * head_stride;
                int64_t const st_v_base = st_k_base + v_plane_offset;
                int64_t const ld_kv_offset
                    = global_token_idx * kv_token_stride + static_cast<int64_t>(h) * uncompressed_head_size;
                // copy k, v
                for (int d = 0; d < uncompressed_head_size; d++)
                {
                    output[st_k_base + d] = k_src[ld_kv_offset + d];
                    output[st_v_base + d] = v_src[ld_kv_offset + d];
                }
                // copy k_pe (head_num = 1) into the rope tail of this K head
                for (int d = 0; d < rope_size; d++)
                {
                    output[st_k_base + uncompressed_head_size + d] = k_pe_src[ld_k_pe_offset + d];
                }
            }
        }
    }
}
|
|
|
|
// compressed_kv_cache {batch, 1 (ignore v), max_seq_len / tokens_per_block, num_head=1, tokens_per_block, (lora_size +
|
|
// rope_size)}
|
|
// kv {total_uncached_tokens, h_k=1, lora_d}, k_pe {total_uncached_tokens, h_kpe=1, rope_d}
|
|
template <typename T, typename TCache>
|
|
void appendPagedKvForMLAKernelRef(tensorrt_llm::kernels::KVBlockArray& kv_cache, T* const compressed_kv_ptr,
|
|
T* const k_pe_ptr, int const num_requests, int64_t const* cu_ctx_cached_kv_lens, int64_t const* cu_seq_lens,
|
|
int k_pe_head_num, int lora_size, int rope_size, float const* kv_scale_orig_quant_ptr)
|
|
{
|
|
static_assert(std::is_same_v<T, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
|
|
"TCache must be either the same type as T or __nv_fp8_e4m3");
|
|
assert(k_pe_head_num == 1);
|
|
float const kv_scale_orig_quant = kv_scale_orig_quant_ptr ? kv_scale_orig_quant_ptr[0] : 1.0f;
|
|
for (int b = 0; b < num_requests; b++)
|
|
{
|
|
int const global_token_offset = cu_seq_lens[b] - cu_ctx_cached_kv_lens[b];
|
|
int const cached_kv_len = cu_ctx_cached_kv_lens[b + 1] - cu_ctx_cached_kv_lens[b];
|
|
int const uncached_token_len = cu_seq_lens[b + 1] - cu_seq_lens[b] - cached_kv_len;
|
|
for (int s = 0; s < uncached_token_len; s++)
|
|
{
|
|
int const ld_kv_offset = (global_token_offset + s) * lora_size;
|
|
int const ld_k_pe_offset = (global_token_offset + s) * k_pe_head_num * rope_size;
|
|
auto* kv_cache_ptr = reinterpret_cast<TCache*>(kv_cache.getKBlockPtr(b, cached_kv_len + s));
|
|
// copy kv
|
|
for (int d = 0; d < lora_size; d++)
|
|
{
|
|
int const ld_kv_idx = ld_kv_offset + d;
|
|
int const kv_cache_idx_in_block
|
|
= kv_cache.getKVLocalIdx(cached_kv_len + s, 0, lora_size + rope_size, d);
|
|
auto src_data = compressed_kv_ptr[ld_kv_idx];
|
|
TCache data;
|
|
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
|
|
{
|
|
data = static_cast<__nv_fp8_e4m3>(static_cast<float>(src_data) * kv_scale_orig_quant);
|
|
}
|
|
else
|
|
{
|
|
data = src_data;
|
|
}
|
|
kv_cache_ptr[kv_cache_idx_in_block] = data;
|
|
}
|
|
// copy k_pe (we only copy the first head)
|
|
for (int d = 0; d < rope_size; d++)
|
|
{
|
|
int const ld_k_pe_idx = ld_k_pe_offset + d;
|
|
int const kv_cache_idx_in_block
|
|
= kv_cache.getKVLocalIdx(cached_kv_len + s, 0, lora_size + rope_size, d + lora_size);
|
|
auto src_data = k_pe_ptr[ld_k_pe_idx];
|
|
TCache data;
|
|
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
|
|
{
|
|
data = static_cast<__nv_fp8_e4m3>(static_cast<float>(src_data) * kv_scale_orig_quant);
|
|
}
|
|
else
|
|
{
|
|
data = src_data;
|
|
}
|
|
kv_cache_ptr[kv_cache_idx_in_block] = data;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Tolerance comparison of a value `a` against a reference `b`: |a - b| <= atol + rtol * |b|.
// NaN on either side never matches. Infinities match only an identical infinity: previously an infinite
// reference made the tolerance itself infinite, so any finite `a` (and even the opposite infinity) passed,
// while equal infinities failed because inf - inf is NaN.
inline bool almostEqual(float a, float b, float atol = 1e-2, float rtol = 1e-3)
{
    if (std::isnan(a) || std::isnan(b))
    {
        return false;
    }
    // Exact matches always pass (this also covers inf == inf).
    if (a == b)
    {
        return true;
    }
    // Unequal values where either is infinite can never be "close".
    if (std::isinf(a) || std::isinf(b))
    {
        return false;
    }
    return std::fabs(a - b) <= (atol + rtol * std::fabs(b));
}
|
|
|
|
} // namespace
|
|
|
|
template <typename Typepair>
|
|
class MlaPreprocessTest : public testing::Test
|
|
{
|
|
protected:
|
|
using DataType = typename Typepair::first_type;
|
|
using TCache = typename Typepair::second_type;
|
|
static_assert(std::is_same_v<DataType, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
|
|
"TCache must be either the same type as DataType or __nv_fp8_e4m3");
|
|
std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
|
|
std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
|
|
// kv_cache shape {batch, 2(k or v), max_seq_len / tokens_per_block, num_head, tokens_per_block, head_size}
|
|
// k, v, k_pe shape {total_token, num_head, head_size(lora_size or rope_size, or uncompressed_head_size)}
|
|
// offset shape {batch, 2, max_seq_len / tokens_per_block}
|
|
// for KVBlockArray, we only allocate primary pool.
|
|
// you can infer the allocateBuffers function for more details.
|
|
tensorrt_llm::runtime::BufferManager::ITensorPtr h_kv_cache_tensor{nullptr}, h_kv_cache_tensor_ref{nullptr},
|
|
d_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor_ref{nullptr},
|
|
h_compressed_kv_cache_tensor{nullptr}, h_compressed_kv_cache_tensor_ref{nullptr}, d_offset_tensor{nullptr},
|
|
d_compressed_offset_tensor{nullptr}, d_cu_ctx_cached_kv_lens{nullptr}, d_cu_seq_lens{nullptr},
|
|
h_offset_tensor{nullptr}, h_compressed_offset_tensor{nullptr}, h_cu_ctx_cached_kv_lens{nullptr},
|
|
h_cu_seq_lens{nullptr}, h_kv_scale_orig_quant{nullptr}, d_kv_scale_orig_quant{nullptr},
|
|
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
|
|
// for kernel 1
|
|
d_compressed_kv_output{nullptr}, h_compressed_kv_output{nullptr}, h_compressed_kv_output_ref{nullptr},
|
|
d_k_pe_output{nullptr}, h_k_pe_output{nullptr}, h_k_pe_output_ref{nullptr},
|
|
// for kernel 2
|
|
d_k_tensor{nullptr}, d_v_tensor{nullptr}, d_k_pe_tensor{nullptr}, h_k_tensor{nullptr}, h_v_tensor{nullptr},
|
|
h_k_pe_tensor{nullptr},
|
|
// for kernel 2 (new)
|
|
d_k_tensor_cached{nullptr}, d_v_tensor_cached{nullptr}, d_k_pe_tensor_cached{nullptr},
|
|
d_k_tensor_uncached{nullptr}, d_v_tensor_uncached{nullptr}, d_k_pe_tensor_uncached{nullptr},
|
|
h_k_tensor_cached{nullptr}, h_v_tensor_cached{nullptr}, h_k_pe_tensor_cached{nullptr},
|
|
h_k_tensor_uncached{nullptr}, h_v_tensor_uncached{nullptr}, h_k_pe_tensor_uncached{nullptr},
|
|
// for kernel 3
|
|
d_compressed_kv_tensor{nullptr}, d_k_pe_one_head_tensor{nullptr}, h_compressed_kv_tensor{nullptr},
|
|
h_k_pe_one_head_tensor{nullptr};
|
|
|
|
int mNumRequests{};
|
|
int mMaxSeqLen{};
|
|
int mMaxCachedSeqLen{};
|
|
int mMaxUncachedSeqLen{};
|
|
int mMaxBlockPerSeq{};
|
|
int mTokensPerBlock{};
|
|
int mNumHeadsCompressed{};
|
|
int mNumHeadsUncompressed{};
|
|
int mTotalTokens{};
|
|
int mTotalCachedTokens{};
|
|
int mTotalUncachedTokens{};
|
|
int mLoraSize{};
|
|
int mRopeSize{};
|
|
int mUncompressedHeadSize{};
|
|
int64_t mKvTokenStride{};
|
|
|
|
std::mt19937 gen;
|
|
|
|
void SetUp() override
|
|
{
|
|
if (shouldSkip())
|
|
{
|
|
GTEST_SKIP() << "Skipping mla preprocess test";
|
|
}
|
|
mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
|
|
mBufferManager = std::make_shared<tensorrt_llm::runtime::BufferManager>(mStream);
|
|
gen.seed(42U);
|
|
}
|
|
|
|
// Hook for disabling the whole suite; the MLA preprocess tests currently always run.
static bool shouldSkip()
{
    constexpr bool kSkipMlaPreprocessTests = false;
    return kSkipMlaPreprocessTests;
}
|
|
|
|
void setDefaultParams()
|
|
{
|
|
this->mTokensPerBlock = 64;
|
|
this->mNumHeadsCompressed = 1;
|
|
this->mNumHeadsUncompressed = 128;
|
|
this->mLoraSize = 512;
|
|
this->mRopeSize = 64;
|
|
this->mUncompressedHeadSize = 128;
|
|
this->mMaxSeqLen = 0;
|
|
this->mMaxCachedSeqLen = 0;
|
|
this->mMaxUncachedSeqLen = 0;
|
|
this->mKvTokenStride = this->mNumHeadsUncompressed * this->mUncompressedHeadSize;
|
|
}
|
|
|
|
// Fill a block-offset table laid out as {batch, 2 (K/V), max_block_per_seq}.
//  - use_both_kv == true:  every entry receives its own flat index 0, 1, 2, ...
//  - use_both_kv == false: K entries are numbered consecutively and every V entry is 0, matching a cache
//    that stores only K (the compressed MLA cache ignores V).
// max_block_per_seq is consulted only when use_both_kv is false; it must then be positive whenever size > 0,
// because it is the divisor that tells K rows from V rows.
template <typename T>
void fillKVOffsetData(T* arr, size_t size, bool use_both_kv = true, int max_block_per_seq = 0)
{
    if (use_both_kv)
    {
        for (size_t i = 0; i < size; i++)
        {
            arr[i] = static_cast<T>(i);
        }
        return;
    }
    if (size == 0)
    {
        return;
    }
    if (max_block_per_seq <= 0)
    {
        // Without a positive block count K and V rows cannot be distinguished (and i / 0 would be UB).
        assert(false && "fillKVOffsetData: max_block_per_seq must be positive when use_both_kv is false");
        return;
    }

    size_t const blocks_per_row = static_cast<size_t>(max_block_per_seq);
    int next_k_idx = 0;
    for (size_t i = 0; i < size; i++)
    {
        bool const is_v = ((i / blocks_per_row) % 2) == 1;
        if (is_v)
        {
            arr[i] = static_cast<T>(0);
        }
        else
        {
            arr[i] = static_cast<T>(next_k_idx);
            next_k_idx++;
        }
    }
}
|
|
|
|
// Fill arr with the repeating ramp 0, 1, ..., 447, 0, 1, ... (presumably 448 was chosen as the largest finite
// __nv_fp8_e4m3 value so the data stays in range for FP8 caches — confirm if the period is changed).
// The index is size_t so tensors with more than INT_MAX elements do not overflow the loop counter.
template <typename T>
void fillArrayDataWithMod(T* arr, size_t size)
{
    constexpr size_t kPeriod = 448;
    for (size_t i = 0; i < size; i++)
    {
        // Convert via int (as before) so T only needs a converting constructor from int.
        arr[i] = static_cast<T>(static_cast<int>(i % kPeriod));
    }
}
|
|
|
|
int generateRandomSizeSmallerThan(int a)
|
|
{
|
|
if (a <= 0)
|
|
{
|
|
return 0;
|
|
}
|
|
std::uniform_int_distribution<> distrib(0, a - 1);
|
|
// Generate and return the random number
|
|
return int{distrib(gen)};
|
|
}
|
|
|
|
template <typename T>
|
|
void memsetZeroDevice(T* ptr, size_t size)
|
|
{
|
|
cudaMemset(ptr, 0, size * sizeof(T));
|
|
}
|
|
|
|
// Zero `size` elements of type T at host pointer `ptr`.
// An empty range returns early: empty buffers may expose a null data pointer, and std::memset with a
// null pointer is undefined behaviour even when the byte count is 0.
template <typename T>
void memsetZeroHost(T* ptr, size_t size)
{
    if (size == 0)
    {
        return;
    }
    std::memset(ptr, 0, size * sizeof(T));
}
|
|
|
|
bool allocateBuffers()
|
|
{
|
|
using tensorrt_llm::runtime::BufferManager;
|
|
using tensorrt_llm::runtime::CudaStream;
|
|
using tensorrt_llm::runtime::ITensor;
|
|
using tensorrt_llm::runtime::bufferCast;
|
|
|
|
auto dtype = nvinfer1::DataType::kHALF;
|
|
if constexpr (std::is_same_v<DataType, float>)
|
|
{
|
|
dtype = nvinfer1::DataType::kFLOAT;
|
|
}
|
|
else if constexpr (std::is_same_v<DataType, half>)
|
|
{
|
|
dtype = nvinfer1::DataType::kHALF;
|
|
}
|
|
else if constexpr (std::is_same_v<DataType, __nv_bfloat16>)
|
|
{
|
|
dtype = nvinfer1::DataType::kBF16;
|
|
}
|
|
else
|
|
{
|
|
return false;
|
|
}
|
|
auto cache_dtype = dtype;
|
|
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
|
|
{
|
|
cache_dtype = nvinfer1::DataType::kFP8;
|
|
this->h_kv_scale_orig_quant
|
|
= tensorrt_llm::runtime::BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
|
|
this->d_kv_scale_orig_quant
|
|
= tensorrt_llm::runtime::BufferManager::gpuSync(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
|
|
this->h_kv_scale_quant_orig
|
|
= tensorrt_llm::runtime::BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
|
|
this->d_kv_scale_quant_orig
|
|
= tensorrt_llm::runtime::BufferManager::gpuSync(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
|
|
auto* kv_scale_orig_quant_ptr = bufferCast<float>(*(this->h_kv_scale_orig_quant));
|
|
auto* kv_scale_quant_orig_ptr = bufferCast<float>(*(this->h_kv_scale_quant_orig));
|
|
float kv_scale_orig_quant = 2.0f;
|
|
kv_scale_orig_quant_ptr[0] = kv_scale_orig_quant;
|
|
kv_scale_quant_orig_ptr[0] = 1.0 / kv_scale_orig_quant;
|
|
cudaMemcpy(this->d_kv_scale_orig_quant->data(), this->h_kv_scale_orig_quant->data(),
|
|
this->h_kv_scale_orig_quant->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_kv_scale_quant_orig->data(), this->h_kv_scale_quant_orig->data(),
|
|
this->h_kv_scale_quant_orig->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
}
|
|
else
|
|
{
|
|
static_assert(std::is_same_v<DataType, TCache>, "TCache must be the same type as DataType");
|
|
}
|
|
this->h_cu_seq_lens = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
|
|
this->h_cu_ctx_cached_kv_lens = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
|
|
this->d_cu_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
|
|
this->d_cu_ctx_cached_kv_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests + 1}), nvinfer1::DataType::kINT64);
|
|
{
|
|
// set random sequence length
|
|
auto* cu_seq_lens_temp_ptr = bufferCast<int64_t>(*(this->h_cu_seq_lens));
|
|
auto* cu_ctx_cached_kv_lens_temp_ptr = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
|
|
cu_seq_lens_temp_ptr[0] = 0;
|
|
cu_ctx_cached_kv_lens_temp_ptr[0] = 0;
|
|
for (int i = 1; i <= this->mNumRequests; i++)
|
|
{
|
|
int temp_seq_len = generateRandomSizeSmallerThan(512);
|
|
if (temp_seq_len <= 0)
|
|
{
|
|
temp_seq_len = 1; // at least 1 token
|
|
}
|
|
int cached_seq_len = generateRandomSizeSmallerThan(temp_seq_len);
|
|
this->mMaxSeqLen = std::max(temp_seq_len, this->mMaxSeqLen);
|
|
this->mMaxCachedSeqLen = std::max(cached_seq_len, this->mMaxCachedSeqLen);
|
|
this->mMaxUncachedSeqLen = std::max(temp_seq_len - cached_seq_len, this->mMaxUncachedSeqLen);
|
|
this->mTotalTokens += temp_seq_len;
|
|
this->mTotalCachedTokens += cached_seq_len;
|
|
this->mTotalUncachedTokens += temp_seq_len - cached_seq_len;
|
|
cu_seq_lens_temp_ptr[i] = cu_seq_lens_temp_ptr[i - 1] + temp_seq_len;
|
|
cu_ctx_cached_kv_lens_temp_ptr[i] = cu_ctx_cached_kv_lens_temp_ptr[i - 1] + cached_seq_len;
|
|
// std::cout << "batch " << i << "seq len: " << temp_seq_len << ", cached len: " << cached_seq_len
|
|
// << ", uncached len: " << temp_seq_len - cached_seq_len << std::endl;
|
|
}
|
|
cudaMemcpy(this->d_cu_seq_lens->data(), this->h_cu_seq_lens->data(), this->h_cu_seq_lens->getSizeInBytes(),
|
|
cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_cu_ctx_cached_kv_lens->data(), this->h_cu_ctx_cached_kv_lens->data(),
|
|
this->h_cu_ctx_cached_kv_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
}
|
|
|
|
// malloc kv_cache
|
|
this->mMaxBlockPerSeq = (this->mMaxSeqLen + this->mTokensPerBlock - 1) / this->mTokensPerBlock;
|
|
this->h_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq, this->mNumHeadsUncompressed,
|
|
this->mTokensPerBlock, this->mUncompressedHeadSize + this->mRopeSize}),
|
|
dtype);
|
|
this->h_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq, this->mNumHeadsUncompressed,
|
|
this->mTokensPerBlock, this->mUncompressedHeadSize + this->mRopeSize}),
|
|
dtype);
|
|
this->h_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
|
|
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
|
|
cache_dtype);
|
|
this->h_compressed_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
|
|
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
|
|
cache_dtype);
|
|
this->h_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
|
|
this->h_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
|
|
this->d_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq, this->mNumHeadsUncompressed,
|
|
this->mTokensPerBlock, this->mUncompressedHeadSize + this->mRopeSize}),
|
|
dtype);
|
|
this->d_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
|
|
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
|
|
cache_dtype);
|
|
this->d_compressed_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests, 1, this->mMaxBlockPerSeq, this->mNumHeadsCompressed,
|
|
this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}),
|
|
cache_dtype);
|
|
this->d_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
|
|
this->d_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mNumRequests, 2, this->mMaxBlockPerSeq}), nvinfer1::DataType::kINT32);
|
|
{
|
|
auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
|
|
auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
|
|
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
|
|
auto* compressed_kv_cache_ref_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor_ref));
|
|
auto* offset_ptr = bufferCast<int32_t>(*(this->h_offset_tensor));
|
|
auto* compressed_offset_ptr = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
|
|
fillArrayDataWithMod(compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize());
|
|
fillArrayDataWithMod(compressed_kv_cache_ref_ptr, this->h_compressed_kv_cache_tensor_ref->getSize());
|
|
memsetZeroHost<DataType>(kv_cache_ptr, this->h_kv_cache_tensor->getSize());
|
|
memsetZeroHost<DataType>(kv_cache_ref_ptr, this->h_kv_cache_tensor_ref->getSize());
|
|
cudaMemcpy(this->d_kv_cache_tensor->data(), this->h_kv_cache_tensor->data(),
|
|
this->h_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
// fillArrayDataWithMod(offset_ptr, this->offset_tensor->getSize());
|
|
fillKVOffsetData(
|
|
compressed_offset_ptr, this->h_compressed_offset_tensor->getSize(), false, this->mMaxBlockPerSeq);
|
|
cudaMemcpy(this->d_compressed_kv_cache_tensor->data(), this->h_compressed_kv_cache_tensor->data(),
|
|
this->h_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_compressed_kv_cache_tensor_ref->data(), this->h_compressed_kv_cache_tensor_ref->data(),
|
|
this->h_compressed_kv_cache_tensor_ref->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_compressed_offset_tensor->data(), this->h_compressed_offset_tensor->data(),
|
|
this->h_compressed_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_offset_tensor->data(), this->h_offset_tensor->data(),
|
|
this->h_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
}
|
|
|
|
// compressed_kv_output + k_pe_output for loadPagedKvKernel (kernel 1)
|
|
this->h_compressed_kv_output = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mLoraSize}), dtype);
|
|
this->h_compressed_kv_output_ref = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mLoraSize}), dtype);
|
|
this->d_compressed_kv_output = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mLoraSize}), dtype);
|
|
this->h_k_pe_output = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mRopeSize}), dtype);
|
|
this->h_k_pe_output_ref = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mRopeSize}), dtype);
|
|
this->d_k_pe_output = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mRopeSize}), dtype);
|
|
{
|
|
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output));
|
|
auto* compressed_kv_output_ref_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
|
|
memsetZeroHost<DataType>(compressed_kv_output_ptr, this->h_compressed_kv_output->getSize());
|
|
memsetZeroHost<DataType>(compressed_kv_output_ref_ptr, this->h_compressed_kv_output_ref->getSize());
|
|
cudaMemcpy(this->d_compressed_kv_output->data(), this->h_compressed_kv_output->data(),
|
|
this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
|
|
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output));
|
|
auto* k_pe_output_ref_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
|
|
memsetZeroHost<DataType>(k_pe_output_ptr, this->h_k_pe_output->getSize());
|
|
memsetZeroHost<DataType>(k_pe_output_ref_ptr, this->h_k_pe_output_ref->getSize());
|
|
cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(),
|
|
cudaMemcpyHostToDevice);
|
|
}
|
|
// k, v, k_pe for setPagedKvCacheForMLAKernel (kernel 2)
|
|
this->h_k_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
|
this->h_v_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
|
this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
|
this->d_k_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
|
this->d_v_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}), dtype);
|
|
this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
|
{
|
|
auto* k_ptr = bufferCast<DataType>(*(this->h_k_tensor));
|
|
auto* v_ptr = bufferCast<DataType>(*(this->h_v_tensor));
|
|
auto* k_pe_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor));
|
|
fillArrayDataWithMod(k_ptr, this->h_k_tensor->getSize());
|
|
fillArrayDataWithMod(v_ptr, this->h_v_tensor->getSize());
|
|
fillArrayDataWithMod(k_pe_ptr, this->h_k_pe_tensor->getSize());
|
|
cudaMemcpy(this->d_k_tensor->data(), this->h_k_tensor->data(), this->h_k_tensor->getSizeInBytes(),
|
|
cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_v_tensor->data(), this->h_v_tensor->data(), this->h_v_tensor->getSizeInBytes(),
|
|
cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_k_pe_tensor->data(), this->h_k_pe_tensor->data(), this->h_k_pe_tensor->getSizeInBytes(),
|
|
cudaMemcpyHostToDevice);
|
|
}
|
|
// ck, cv, ck_pe, uk, uc, uk_pe for setPagedKvCacheForMLAKernelV2 (kernel 2)
|
|
this->h_k_tensor_cached = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->h_v_tensor_cached = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->h_k_pe_tensor_cached = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
|
this->h_k_tensor_uncached = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->h_v_tensor_uncached = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->h_k_pe_tensor_uncached = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
|
this->d_k_tensor_cached = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->d_v_tensor_cached = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->d_k_pe_tensor_cached = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalCachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
|
this->d_k_tensor_uncached = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->d_v_tensor_uncached = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsUncompressed, this->mUncompressedHeadSize}),
|
|
dtype);
|
|
this->d_k_pe_tensor_uncached = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, this->mNumHeadsCompressed, this->mRopeSize}), dtype);
|
|
{
|
|
auto* k_cached_ptr = bufferCast<DataType>(*(this->h_k_tensor_cached));
|
|
auto* v_cached_ptr = bufferCast<DataType>(*(this->h_v_tensor_cached));
|
|
auto* k_pe_cached_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor_cached));
|
|
auto* k_uncached_ptr = bufferCast<DataType>(*(this->h_k_tensor_uncached));
|
|
auto* v_uncached_ptr = bufferCast<DataType>(*(this->h_v_tensor_uncached));
|
|
auto* k_pe_uncached_ptr = bufferCast<DataType>(*(this->h_k_pe_tensor_uncached));
|
|
fillArrayDataWithMod(k_cached_ptr, this->h_k_tensor_cached->getSize());
|
|
fillArrayDataWithMod(v_cached_ptr, this->h_v_tensor_cached->getSize());
|
|
fillArrayDataWithMod(k_pe_cached_ptr, this->h_k_pe_tensor_cached->getSize());
|
|
fillArrayDataWithMod(k_uncached_ptr, this->h_k_tensor_uncached->getSize());
|
|
fillArrayDataWithMod(v_uncached_ptr, this->h_v_tensor_uncached->getSize());
|
|
fillArrayDataWithMod(k_pe_uncached_ptr, this->h_k_pe_tensor_uncached->getSize());
|
|
cudaMemcpy(this->d_k_tensor_cached->data(), this->h_k_tensor_cached->data(),
|
|
this->h_k_tensor_cached->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_v_tensor_cached->data(), this->h_v_tensor_cached->data(),
|
|
this->h_v_tensor_cached->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_k_pe_tensor_cached->data(), this->h_k_pe_tensor_cached->data(),
|
|
this->h_k_pe_tensor_cached->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_k_tensor_uncached->data(), this->h_k_tensor_uncached->data(),
|
|
this->h_k_tensor_uncached->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_v_tensor_uncached->data(), this->h_v_tensor_uncached->data(),
|
|
this->h_v_tensor_uncached->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_k_pe_tensor_uncached->data(), this->h_k_pe_tensor_uncached->data(),
|
|
this->h_k_pe_tensor_uncached->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
}
|
|
// compressed_kv, k_pe_one_head for appendPagedKvForMLAKernel (kernel 3)
|
|
this->h_compressed_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mLoraSize}), dtype);
|
|
this->h_k_pe_one_head_tensor = tensorrt_llm::runtime::BufferManager::pinned(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mRopeSize}), dtype);
|
|
this->d_compressed_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mLoraSize}), dtype);
|
|
this->d_k_pe_one_head_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
|
|
ITensor::makeShape({this->mTotalUncachedTokens, 1, this->mRopeSize}), dtype);
|
|
|
|
{
|
|
auto* compressed_kv_ptr = bufferCast<DataType>(*(this->h_compressed_kv_tensor));
|
|
auto* k_pe_one_head_ptr = bufferCast<DataType>(*(this->h_k_pe_one_head_tensor));
|
|
fillArrayDataWithMod(compressed_kv_ptr, this->h_compressed_kv_tensor->getSize());
|
|
fillArrayDataWithMod(k_pe_one_head_ptr, this->h_k_pe_one_head_tensor->getSize());
|
|
cudaMemcpy(this->d_compressed_kv_tensor->data(), this->h_compressed_kv_tensor->data(),
|
|
this->h_compressed_kv_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
cudaMemcpy(this->d_k_pe_one_head_tensor->data(), this->h_k_pe_one_head_tensor->data(),
|
|
this->h_k_pe_one_head_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Device path for invokeMLALoadPagedKV: reads the compressed paged cache (d_compressed_kv_cache_tensor)
/// and writes compressed_kv / k_pe into d_compressed_kv_output / d_k_pe_output, then copies both outputs
/// into h_compressed_kv_output / h_k_pe_output for comparison against PerformLoadPagedKVRef.
/// All CUDA runtime calls are checked so a failed launch or copy fails the test instead of leaving
/// stale host buffers to be compared.
void PerformLoadPagedKV()
{
    using tensorrt_llm::runtime::bufferCast;
    auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->d_compressed_kv_output));
    auto* k_pe_output_ptr = bufferCast<DataType>(*(this->d_k_pe_output));
    auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->d_compressed_kv_cache_tensor));
    auto* offset_ptr = bufferCast<int32_t>(*(this->d_compressed_offset_tensor));
    auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->d_cu_ctx_cached_kv_lens));
    // Only an FP8 cache gets a quant->orig scale (presumably the dequantization scale); otherwise nullptr.
    float* kv_scale_quant_orig_ptr = nullptr;
    if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
    {
        kv_scale_quant_orig_ptr = bufferCast<float>(*(this->d_kv_scale_quant_orig));
    }
    // Paged view over the compressed cache; the size argument is one head of
    // (mLoraSize + mRopeSize) TCache elements, in bytes.
    tensorrt_llm::kernels::KVBlockArray kv_cache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock,
        sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
        reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
    tensorrt_llm::kernels::invokeMLALoadPagedKV<DataType, TCache>(compressed_kv_output_ptr, k_pe_output_ptr,
        kv_cache, this->mNumRequests, cu_ctx_cached_kv_lens_ptr, this->mMaxCachedSeqLen, this->mLoraSize,
        this->mRopeSize, kv_scale_quant_orig_ptr, this->mStream->get());
    // Surface launch-configuration errors here rather than at some later, unrelated CUDA call.
    TLLM_CUDA_CHECK(cudaGetLastError());
    TLLM_CUDA_CHECK(cudaStreamSynchronize(this->mStream->get()));
    TLLM_CUDA_CHECK(cudaMemcpy(this->h_compressed_kv_output->data(), this->d_compressed_kv_output->data(),
        this->d_compressed_kv_output->getSizeInBytes(), cudaMemcpyDeviceToHost));
    TLLM_CUDA_CHECK(cudaMemcpy(this->h_k_pe_output->data(), this->d_k_pe_output->data(),
        this->d_k_pe_output->getSizeInBytes(), cudaMemcpyDeviceToHost));
}
|
|
|
|
/// CPU reference for PerformLoadPagedKV: walks the host copy of the compressed paged cache and fills
/// h_compressed_kv_output_ref / h_k_pe_output_ref through loadPagedKvKernelRef, using the same paging
/// geometry as the device path.
void PerformLoadPagedKVRef()
{
    namespace tk = tensorrt_llm::kernels;
    using tensorrt_llm::runtime::bufferCast;

    // Reference outputs.
    auto* compressedKvRef = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
    auto* kPeRef = bufferCast<DataType>(*(this->h_k_pe_output_ref));
    // Host-side cache, its block offsets and the cumulative cached lengths per request.
    auto* cacheBase = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
    auto* blockOffsets = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
    auto* cuCachedLens = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));

    // Only the FP8 cache carries a quant->orig scale; every other cache type passes nullptr.
    float* scaleQuantOrig = nullptr;
    if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
    {
        scaleQuantOrig = bufferCast<float>(*(this->h_kv_scale_quant_orig));
    }

    // One head of (lora + rope) cache elements, expressed in bytes.
    auto const bytesPerToken = sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize);
    auto* offsetsView = reinterpret_cast<tk::KVBlockArrayForContextFMHA::DataType*>(blockOffsets);
    tk::KVBlockArray pagedCache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock, bytesPerToken,
        0, 0, 0, 0, cacheBase, nullptr, offsetsView);

    loadPagedKvKernelRef<DataType, TCache>(compressedKvRef, kPeRef, pagedCache, this->mNumRequests, cuCachedLens,
        this->mLoraSize, this->mRopeSize, scaleQuantOrig);
}
|
|
|
|
/// Device path for invokeMLASetPagedKV: writes the device k, v and k_pe tensors into d_kv_cache_tensor,
/// then copies the resulting cache into h_kv_cache_tensor for comparison against PerformSetPagedKVRef.
/// All CUDA runtime calls are checked so a failure aborts the test instead of comparing stale data.
void PerformSetPagedKV()
{
    using tensorrt_llm::runtime::bufferCast;
    auto* k_ptr = bufferCast<DataType>(*(this->d_k_tensor));
    auto* v_ptr = bufferCast<DataType>(*(this->d_v_tensor));
    auto* k_pe_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor));
    auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
    auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
    tensorrt_llm::kernels::invokeMLASetPagedKV<DataType>(kv_cache_ptr, k_ptr, v_ptr, k_pe_ptr, this->mNumRequests,
        cu_seq_lens_ptr, this->mMaxSeqLen, this->mNumHeadsUncompressed, this->mUncompressedHeadSize,
        this->mRopeSize, this->mTokensPerBlock, this->mKvTokenStride, this->mStream->get());
    // Surface launch-configuration errors here rather than at some later, unrelated CUDA call.
    TLLM_CUDA_CHECK(cudaGetLastError());
    TLLM_CUDA_CHECK(cudaStreamSynchronize(this->mStream->get()));
    TLLM_CUDA_CHECK(cudaMemcpy(this->h_kv_cache_tensor->data(), this->d_kv_cache_tensor->data(),
        this->d_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost));
}
|
|
|
|
/// CPU reference for PerformSetPagedKV: fills h_kv_cache_tensor_ref from the host copies of k, v and
/// k_pe via setPagedKvCacheForMLAKernelRef, using the same head/paging parameters as the device call.
void PerformSetPagedKVRef()
{
    using tensorrt_llm::runtime::bufferCast;

    // Host inputs.
    auto* hostK = bufferCast<DataType>(*(this->h_k_tensor));
    auto* hostV = bufferCast<DataType>(*(this->h_v_tensor));
    auto* hostKPe = bufferCast<DataType>(*(this->h_k_pe_tensor));
    // Host reference cache (output) and per-request cumulative sequence lengths.
    auto* cacheRef = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
    auto* cuSeqLens = bufferCast<int64_t>(*(this->h_cu_seq_lens));

    setPagedKvCacheForMLAKernelRef(cacheRef, hostK, hostV, hostKPe,
        this->mNumRequests, cuSeqLens, this->mMaxSeqLen,
        this->mNumHeadsUncompressed, this->mUncompressedHeadSize, this->mRopeSize,
        this->mTokensPerBlock, this->mKvTokenStride);
}
|
|
|
|
/// Device path for invokeMLASetPagedKVV2: k / v / k_pe are passed as separate "cached" and "uncached"
/// tensors together with cu_ctx_cached_kv_lens and cu_seq_lens, written into d_kv_cache_tensor, and the
/// resulting cache is copied into h_kv_cache_tensor for comparison against PerformSetPagedKVV2Ref.
/// Unlike PerformSetPagedKV, no mKvTokenStride argument is passed to the kernel.
/// All CUDA runtime calls are checked so a failure aborts the test instead of comparing stale data.
void PerformSetPagedKVV2()
{
    using tensorrt_llm::runtime::bufferCast;
    auto* k_cached_ptr = bufferCast<DataType>(*(this->d_k_tensor_cached));
    auto* v_cached_ptr = bufferCast<DataType>(*(this->d_v_tensor_cached));
    auto* k_pe_cached_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor_cached));
    auto* k_uncached_ptr = bufferCast<DataType>(*(this->d_k_tensor_uncached));
    auto* v_uncached_ptr = bufferCast<DataType>(*(this->d_v_tensor_uncached));
    auto* k_pe_uncached_ptr = bufferCast<DataType>(*(this->d_k_pe_tensor_uncached));
    auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->d_cu_ctx_cached_kv_lens));
    auto* kv_cache_ptr = bufferCast<DataType>(*(this->d_kv_cache_tensor));
    auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
    tensorrt_llm::kernels::invokeMLASetPagedKVV2<DataType>(kv_cache_ptr, k_cached_ptr, v_cached_ptr,
        k_pe_cached_ptr, k_uncached_ptr, v_uncached_ptr, k_pe_uncached_ptr, this->mNumRequests,
        cu_ctx_cached_kv_lens_ptr, cu_seq_lens_ptr, this->mMaxSeqLen, this->mNumHeadsUncompressed,
        this->mUncompressedHeadSize, this->mRopeSize, this->mTokensPerBlock, this->mStream->get());
    // Surface launch-configuration errors here rather than at some later, unrelated CUDA call.
    TLLM_CUDA_CHECK(cudaGetLastError());
    TLLM_CUDA_CHECK(cudaStreamSynchronize(this->mStream->get()));
    TLLM_CUDA_CHECK(cudaMemcpy(this->h_kv_cache_tensor->data(), this->d_kv_cache_tensor->data(),
        this->d_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost));
}
|
|
|
|
/// CPU reference for PerformSetPagedKVV2: feeds the host cached/uncached k, v and k_pe tensors to
/// setPagedKvCacheForMLAKernelRefV2, which writes into h_kv_cache_tensor_ref.
void PerformSetPagedKVV2Ref()
{
    using tensorrt_llm::runtime::bufferCast;

    // Tokens already present in the cache.
    auto* cachedK = bufferCast<DataType>(*(this->h_k_tensor_cached));
    auto* cachedV = bufferCast<DataType>(*(this->h_v_tensor_cached));
    auto* cachedKPe = bufferCast<DataType>(*(this->h_k_pe_tensor_cached));
    // Newly computed (uncached) tokens.
    auto* freshK = bufferCast<DataType>(*(this->h_k_tensor_uncached));
    auto* freshV = bufferCast<DataType>(*(this->h_v_tensor_uncached));
    auto* freshKPe = bufferCast<DataType>(*(this->h_k_pe_tensor_uncached));
    // Length prefixes and the reference output cache.
    auto* cuCachedLens = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
    auto* cacheRef = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
    auto* cuSeqLens = bufferCast<int64_t>(*(this->h_cu_seq_lens));

    setPagedKvCacheForMLAKernelRefV2(cacheRef,
        cachedK, cachedV, cachedKPe,
        freshK, freshV, freshKPe,
        this->mNumRequests, cuCachedLens, cuSeqLens, this->mMaxSeqLen,
        this->mNumHeadsUncompressed, this->mUncompressedHeadSize, this->mRopeSize,
        this->mTokensPerBlock);
}
|
|
|
|
/// Device path for invokeMLAAppendPagedKV: appends the uncached compressed_kv tensor and the single-head
/// k_pe tensor to the compressed paged cache (d_compressed_kv_cache_tensor), then copies the cache into
/// h_compressed_kv_cache_tensor for comparison against PerformAppendPagedKVRef.
/// All CUDA runtime calls are checked so a failure aborts the test instead of comparing stale data.
void PerformAppendPagedKV()
{
    using tensorrt_llm::runtime::bufferCast;
    auto* compressed_kv_ptr = bufferCast<DataType>(*(this->d_compressed_kv_tensor));
    auto* k_pe_one_head_ptr = bufferCast<DataType>(*(this->d_k_pe_one_head_tensor));
    auto* offset_ptr = bufferCast<int32_t>(*(this->d_compressed_offset_tensor));
    auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->d_compressed_kv_cache_tensor));
    auto* cu_ctx_cached_kv_lens_ptr = bufferCast<int64_t>(*(this->d_cu_ctx_cached_kv_lens));
    auto* cu_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_seq_lens));
    // Only an FP8 cache gets an orig->quant scale (presumably used to quantize on write); otherwise nullptr.
    float* kv_scale_orig_quant_ptr = nullptr;
    if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
    {
        kv_scale_orig_quant_ptr = bufferCast<float>(*(this->d_kv_scale_orig_quant));
    }
    // Paged view over the compressed cache; the size argument is one head of
    // (mLoraSize + mRopeSize) TCache elements, in bytes.
    tensorrt_llm::kernels::KVBlockArray kv_cache(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock,
        sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
        reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
    tensorrt_llm::kernels::invokeMLAAppendPagedKV<DataType, TCache>(kv_cache, compressed_kv_ptr, k_pe_one_head_ptr,
        this->mNumRequests, cu_ctx_cached_kv_lens_ptr, cu_seq_lens_ptr, this->mMaxUncachedSeqLen,
        this->mLoraSize + this->mRopeSize, kv_scale_orig_quant_ptr, this->mStream->get());
    // Surface launch-configuration errors here rather than at some later, unrelated CUDA call.
    TLLM_CUDA_CHECK(cudaGetLastError());
    TLLM_CUDA_CHECK(cudaStreamSynchronize(this->mStream->get()));
    TLLM_CUDA_CHECK(cudaMemcpy(this->h_compressed_kv_cache_tensor->data(), this->d_compressed_kv_cache_tensor->data(),
        this->d_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost));
}
|
|
|
|
/// CPU reference for PerformAppendPagedKV: appends the host compressed_kv and single-head k_pe tensors
/// into h_compressed_kv_cache_tensor_ref through appendPagedKvForMLAKernelRef.
void PerformAppendPagedKVRef()
{
    namespace tk = tensorrt_llm::kernels;
    using tensorrt_llm::runtime::bufferCast;

    // Host inputs to be appended.
    auto* hostCompressedKv = bufferCast<DataType>(*(this->h_compressed_kv_tensor));
    auto* hostKPe = bufferCast<DataType>(*(this->h_k_pe_one_head_tensor));
    // Paging metadata and the reference cache being written.
    auto* blockOffsets = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
    auto* cacheRefBase = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor_ref));
    auto* cuCachedLens = bufferCast<int64_t>(*(this->h_cu_ctx_cached_kv_lens));
    auto* cuSeqLens = bufferCast<int64_t>(*(this->h_cu_seq_lens));

    // Only the FP8 cache carries an orig->quant scale; every other cache type passes nullptr.
    float* scaleOrigQuant = nullptr;
    if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
    {
        scaleOrigQuant = bufferCast<float>(*(this->h_kv_scale_orig_quant));
    }

    // One head of (lora + rope) cache elements, expressed in bytes.
    auto const bytesPerToken = sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize);
    auto* offsetsView = reinterpret_cast<tk::KVBlockArrayForContextFMHA::DataType*>(blockOffsets);
    tk::KVBlockArray pagedCacheRef(this->mNumRequests, this->mMaxBlockPerSeq, this->mTokensPerBlock, bytesPerToken,
        0, 0, 0, 0, cacheRefBase, nullptr, offsetsView);

    // The literal 1 is the k_pe head count: only a single k_pe head is supported at the moment.
    appendPagedKvForMLAKernelRef<DataType, TCache>(pagedCacheRef, hostCompressedKv, hostKPe, this->mNumRequests,
        cuCachedLens, cuSeqLens, 1, this->mLoraSize, this->mRopeSize, scaleOrigQuant);
}
|
|
|
|
/// Element-wise comparison of two host arrays of length `size`.
/// Each element is converted to float and compared with almostEqual using tolerances of 1e-3.
/// On the first mismatch the position and both values are logged and false is returned.
/// @return true when all `size` elements compare equal (trivially true for size == 0).
template <typename T>
bool CheckEqual(T const* expected, T const* output, size_t size)
{
    // size_t index: matches `size`, avoiding a signed/unsigned comparison and int overflow on huge buffers.
    for (size_t i = 0; i < size; ++i)
    {
        auto const e = static_cast<float>(expected[i]);
        auto const o = static_cast<float>(output[i]);
        if (!almostEqual(e, o, 1e-3, 1e-3))
        {
            TLLM_LOG_ERROR(
                "Mismatch input value. Position of inputs: %zu, expected value: %f, output value: %f", i, e, o);
            return false;
        }
    }
    return true;
}
|
|
};
|
|
|
|
// Type combinations instantiated for MlaPreprocessTest, presumably (DataType, TCache) pairs per the
// fixture's DataType/TCache aliases: three non-quantized caches first, then the same three compute
// types with an FP8 e4m3 cache. Order is significant: it determines the typed-test instance indices.
using MLATypes = ::testing::Types<
    std::pair<half, half>,
    std::pair<__nv_bfloat16, __nv_bfloat16>,
    std::pair<float, float>,
    std::pair<half, __nv_fp8_e4m3>,
    std::pair<__nv_bfloat16, __nv_fp8_e4m3>,
    std::pair<float, __nv_fp8_e4m3>>;
TYPED_TEST_SUITE(MlaPreprocessTest, MLATypes);
|
|
|
|
// Runs each MLA KV-cache helper exercised by the fixture on the device and on its CPU reference, then
// compares the host copies of both results element-wise with CheckEqual.
TYPED_TEST(MlaPreprocessTest, MLAPreprocessDefault)
{
    using tensorrt_llm::runtime::bufferCast;
    using DataType = typename TestFixture::DataType;
    using TCache = typename TestFixture::TCache;

    this->mNumRequests = 8;
    this->setDefaultParams();
    // Every step below dereferences buffers created by allocateBuffers(), so a failed allocation must
    // stop the test (fatal assertion) instead of continuing into unallocated buffers.
    ASSERT_TRUE(this->allocateBuffers());

    sync_check_cuda_error(this->mStream->get());

    // Load cached compressed_kv / k_pe out of the compressed paged cache.
    {
        this->PerformLoadPagedKV();
        sync_check_cuda_error(this->mStream->get());
        this->PerformLoadPagedKVRef();
        auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output));
        auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output));
        auto* compressed_kv_output_ref_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
        auto* k_pe_output_ref_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
        EXPECT_TRUE(this->CheckEqual(
            compressed_kv_output_ref_ptr, compressed_kv_output_ptr, this->h_compressed_kv_output->getSize()));
        EXPECT_TRUE(this->CheckEqual(k_pe_output_ref_ptr, k_pe_output_ptr, this->h_k_pe_output->getSize()));
    }

    // Write k / v / k_pe into the paged KV cache.
    {
        this->PerformSetPagedKV();
        sync_check_cuda_error(this->mStream->get());
        this->PerformSetPagedKVRef();
        auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
        auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
        EXPECT_TRUE(this->CheckEqual(kv_cache_ref_ptr, kv_cache_ptr, this->h_kv_cache_tensor->getSize()));
    }

    // V2 variant: cached and uncached tokens are supplied as separate tensors.
    // NOTE(review): this step reuses the kv-cache buffers written by the previous step (device result in
    // h_kv_cache_tensor, reference in h_kv_cache_tensor_ref); if V2 does not overwrite every element,
    // leftover V1 output is compared too, so a V1 mismatch may also be reported here.
    {
        this->PerformSetPagedKVV2();
        sync_check_cuda_error(this->mStream->get());
        this->PerformSetPagedKVV2Ref();
        auto* kv_cache_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor));
        auto* kv_cache_ref_ptr = bufferCast<DataType>(*(this->h_kv_cache_tensor_ref));
        EXPECT_TRUE(this->CheckEqual(kv_cache_ref_ptr, kv_cache_ptr, this->h_kv_cache_tensor->getSize()));
    }

    // Append the uncached compressed_kv / single-head k_pe into the compressed paged cache.
    {
        this->PerformAppendPagedKV();
        sync_check_cuda_error(this->mStream->get());
        this->PerformAppendPagedKVRef();
        auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
        auto* compressed_kv_cache_ref_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor_ref));
        EXPECT_TRUE(this->CheckEqual(
            compressed_kv_cache_ref_ptr, compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize()));
    }
}
|